triton_backend.py 42.6 KB
Newer Older
1
2
from __future__ import annotations

3
from dataclasses import dataclass
4
from typing import TYPE_CHECKING, Optional, Union
5
6

import torch
7
import triton
8
import triton.language as tl
9

10
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
11
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
12
from sglang.srt.layers.dp_attention import get_attention_tp_size
13
from sglang.srt.layers.radix_attention import AttentionType
14
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
15
16
17
18
19
20
from sglang.srt.utils import (
    get_bool_env_var,
    get_device_core_count,
    get_int_env_var,
    next_power_of_2,
)
21
22

if TYPE_CHECKING:
23
    from sglang.srt.layers.radix_attention import RadixAttention
24
    from sglang.srt.model_executor.model_runner import ModelRunner
25
    from sglang.srt.speculative.spec_info import SpecInput
26
27


28
29
30
31
32
33
34
35
def logit_capping_mod(logit_capping_method, logit_cap):
    """Validate the logit-capping method and return the cap value.

    Args:
        logit_capping_method: Name of the capping scheme. Only ``"tanh"`` is
            accepted.
        logit_cap: Cap value; returned unchanged when the method is valid.

    Returns:
        ``logit_cap`` unchanged.

    Raises:
        ValueError: If ``logit_capping_method`` is anything other than
            ``"tanh"``.
    """
    # positive logit_cap -> tanh cap
    if logit_capping_method == "tanh":
        return logit_cap
    # Name the offending value so a misconfigured layer is easy to track down.
    raise ValueError(
        f"Unsupported logit_capping_method: {logit_capping_method!r}; "
        'only "tanh" is supported.'
    )


36
37
38
39
40
41
42
43
44
45
46
@dataclass
class ForwardMetadata:
    """Per-batch auxiliary tensors prepared by the backend before a forward pass.

    Fields that a given forward mode does not use are set to ``None`` by the
    code that builds this object, hence the ``Optional`` annotations.
    """

    # Scratch buffers for split-KV decode; None outside decode.
    attn_logits: Optional[torch.Tensor]
    attn_lse: Optional[torch.Tensor]
    # Longest query extension in the batch; None in decode.
    max_extend_len: Optional[int]
    # Number of KV splits per token (int32); None outside decode.
    num_kv_splits: Optional[torch.Tensor]
    # Prefix-sum offsets into kv_indices, one entry per request plus a leading 0.
    kv_indptr: torch.Tensor
    # Flattened KV-cache slot indices for all requests.
    kv_indices: torch.Tensor
    # Prefix-sum offsets of query tokens per request; None in decode.
    qo_indptr: Optional[torch.Tensor]
    # Custom attention mask taken from the speculative-decoding info, if any.
    custom_mask: Optional[torch.Tensor]
    # Prefix-sum offsets into custom_mask; None when no mask offsets are needed.
    mask_indptr: Optional[torch.Tensor]
    # Sliding window
    window_kv_indptr: Optional[torch.Tensor]
    window_kv_indices: Optional[torch.Tensor]
    window_num_kv_splits: Optional[torch.Tensor]
    # Used to calculate the start position in the custom mask.
    window_kv_offsets: Optional[torch.Tensor]
52
53


54
class TritonAttnBackend(AttentionBackend):
55
56
57
58
59
60
    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
    ):
        """Load the Triton kernels, read configuration and allocate index buffers.

        Args:
            model_runner: Provides the model config, server args, KV pools and
                device information read below.
            skip_prefill: When True, the ``qo_indptr`` and ``mask_indptr``
                buffers are not allocated.
            kv_indptr_buf: Optional pre-allocated buffer used as
                ``self.kv_indptr`` instead of allocating a new one.

        Raises:
            AssertionError: If a sliding window is configured on an
                encoder-decoder model.
        """
        # Lazy import to avoid the initialization of cuda context
        from sglang.srt.layers.attention.triton_ops.decode_attention import (
            decode_attention_fwd,
        )
        from sglang.srt.layers.attention.triton_ops.extend_attention import (
            extend_attention_fwd,
        )

        super().__init__()

        # Keep torch.compile from tracing into the Triton kernel launchers.
        self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
        self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)

        # Parse args
        self.skip_prefill = skip_prefill
        max_bs = model_runner.req_to_token_pool.size
        self.sliding_window_size = model_runner.sliding_window_size
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
        self.speculative_num_steps = model_runner.server_args.speculative_num_steps
        self.num_head = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.num_kv_head = model_runner.model_config.get_num_kv_heads(
            get_attention_tp_size()
        )
        if model_runner.hybrid_gdn_config is not None:
            # For hybrid linear models, layer_id = 0 may not be full attention
            self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
        else:
            self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[
                -1
            ]
        self.max_context_len = model_runner.model_config.context_len
        self.device = model_runner.device
        self.device_core_count = get_device_core_count(model_runner.gpu_id)
        self.static_kv_splits = get_bool_env_var(
            "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
        )
        self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits

        # Decide whether enable deterministic inference with batch-invariant operations
        self.enable_deterministic = (
            model_runner.server_args.enable_deterministic_inference
        )

        # Configure deterministic inference settings
        if self.enable_deterministic:
            # Use fixed split tile size for batch invariance
            self.split_tile_size = get_int_env_var(
                "SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE", 256
            )
            # Set static_kv_splits to False to use deterministic logic instead
            self.static_kv_splits = False
        else:
            self.split_tile_size = (
                model_runner.server_args.triton_attention_split_tile_size
            )

        if self.split_tile_size is not None:
            # With a fixed tile size, the split count is bounded by
            # ceil(max_context_len / split_tile_size).
            self.max_kv_splits = (
                self.max_context_len + self.split_tile_size - 1
            ) // self.split_tile_size

        # Check arguments
        assert not (
            model_runner.sliding_window_size is not None
            and model_runner.model_config.is_encoder_decoder
        ), "Sliding window and cross attention are not supported together"

        # Initialize buffers
        # TODO(Jianan Ji): Make sure it behaves as expected when kv_indptr_buf is provided and sliding window is enabled
        if kv_indptr_buf is None:
            self.kv_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )
        else:
            self.kv_indptr = kv_indptr_buf

        # If sliding window is enabled, we might need two sets of buffers
        # because of interleaved attention types (e.g. for Gemma3)
        self.window_kv_indptr = None
        if self.sliding_window_size is not None and self.sliding_window_size > 0:
            if kv_indptr_buf is None:
                self.window_kv_indptr = torch.zeros(
                    (max_bs + 1,), dtype=torch.int32, device=model_runner.device
                )
            else:
                # When provided a buffer, allocate a separate zero-filled buffer
                # with the same shape/dtype/device for the window indptr
                # (contents are not copied).
                self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)

        if not self.skip_prefill:
            self.qo_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )

            self.mask_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int64, device=model_runner.device
            )

        # Initialize forward metadata
        self.forward_metadata: Optional[ForwardMetadata] = None
164

165
166
167
168
169
    def get_num_kv_splits(
        self,
        num_kv_splits: torch.Tensor,
        seq_lens: torch.Tensor,
    ):
170
        num_token, num_seq = num_kv_splits.shape[0], seq_lens.shape[0]
171
172
173
        # NOTE(alcanderian): Considering speculative_decodeing,
        # num_kv_splits.shape[0] will be topk * real_num_token.
        # And the real_num_token is num_seq in decoding phase.
174
175
176
177
178
179
        num_group = num_token // num_seq

        assert (
            num_group * num_seq == num_token
        ), f"num_seq({num_seq}), num_token({num_token}), something goes wrong!"

180
181
182
183
        # Legacy dynamic splitting logic (non-deterministic)
        if (
            self.static_kv_splits or self.device_core_count <= 0
        ) and not self.enable_deterministic:
184
185
186
            num_kv_splits.fill_(self.max_kv_splits)
            return

187
188
189
190
191
192
193
194
        # deterministic
        if self.split_tile_size is not None and self.enable_deterministic:
            # expand seq_lens to match num_token
            if num_group > 1:
                expanded_seq_lens = seq_lens.repeat_interleave(num_group)
            else:
                expanded_seq_lens = seq_lens

195
            num_kv_splits[:] = (
196
                expanded_seq_lens + self.split_tile_size - 1
197
198
199
            ) // self.split_tile_size
            return

200
201
202
203
204
        if num_seq < 256:
            SCHEDULE_SEQ = 256
        else:
            SCHEDULE_SEQ = triton.next_power_of_2(num_seq)

205
206
207
        get_num_kv_splits_triton[(1,)](
            num_kv_splits,
            seq_lens,
208
209
            num_seq,
            num_group,
210
            self.num_head,
211
            self.num_kv_head,
212
213
            self.max_kv_splits,
            self.device_core_count,
214
            MAX_NUM_SEQ=SCHEDULE_SEQ,
215
        )
216

217
218
219
    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init auxiliary variables for triton attention backend.

        Builds the index/offset tensors for the batch's forward mode (decode or
        idle, target verify, draft extend, or regular extend) and stores them
        in ``self.forward_metadata``. Values a mode does not use are ``None``.

        Args:
            forward_batch: The batch to prepare; its forward mode, sequence
                lengths, request pool indices and speculative info are read.
        """

        bs = forward_batch.batch_size
        kv_indptr = self.kv_indptr
        window_kv_indptr = self.window_kv_indptr
        window_kv_indices = None
        window_num_kv_splits = None
        window_kv_offsets = None
        spec_info = forward_batch.spec_info

        if forward_batch.forward_mode.is_decode_or_idle():
            if spec_info is None:
                # kv_indptr[0] stays 0; the rest is the running sum of seq_lens.
                kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                kv_indices = torch.empty(
                    forward_batch.seq_lens_sum, dtype=torch.int64, device=self.device
                )
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
                # Sliding window
                if (
                    self.sliding_window_size is not None
                    and self.sliding_window_size > 0
                ):
                    window_kv_indptr, window_kv_indices, window_kv_lens, _ = (
                        update_sliding_window_buffer(
                            self.window_kv_indptr,
                            self.req_to_token,
                            self.sliding_window_size,
                            forward_batch.seq_lens,
                            forward_batch.req_pool_indices,
                            bs,
                            self.device,
                            self.token_to_kv_pool_allocator,
                        )
                    )
                    window_num_kv_splits = torch.empty(
                        (bs,), dtype=torch.int32, device=self.device
                    )
                    self.get_num_kv_splits(window_num_kv_splits, window_kv_lens)
            else:
                # Speculative decoding supplies its own indices; the effective
                # batch size is derived from the indptr length.
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
                bs = kv_indptr.shape[0] - 1

            attn_logits = torch.empty(
                (bs, self.num_head, self.max_kv_splits, self.v_head_dim),
                dtype=torch.float32,
                device=self.device,
            )
            attn_lse = torch.empty(
                (bs, self.num_head, self.max_kv_splits),
                dtype=torch.float32,
                device=self.device,
            )
            num_kv_splits = torch.empty((bs,), dtype=torch.int32, device=self.device)
            self.get_num_kv_splits(num_kv_splits, forward_batch.seq_lens)

            qo_indptr = None
            custom_mask = None
            mask_indptr = None
            max_extend_len = None
        elif forward_batch.forward_mode.is_target_verify():
            bs = len(forward_batch.req_pool_indices)
            # Every request contributes exactly num_draft_tokens query tokens.
            qo_indptr = torch.arange(
                0,
                (1 + bs) * self.num_draft_tokens,
                step=self.num_draft_tokens,
                dtype=torch.int32,
                device=self.device,
            )
            # Different with flashinfer kv_indptr and kv_indices construction
            kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = torch.empty(
                kv_indptr[-1], dtype=torch.int64, device=self.device
            )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )

            if self.sliding_window_size is not None and self.sliding_window_size > 0:
                # window_kv_offsets is used to calculate the start position in custom mask
                (
                    window_kv_indptr,
                    window_kv_indices,
                    _,
                    window_kv_offsets,
                ) = update_sliding_window_buffer(
                    self.window_kv_indptr,
                    self.req_to_token,
                    self.sliding_window_size,
                    forward_batch.seq_lens,
                    forward_batch.req_pool_indices,
                    bs,
                    self.device,
                    self.token_to_kv_pool_allocator,
                )

            custom_mask = spec_info.custom_mask
            # Mask entries per request: num_draft_tokens * (seq_len + num_draft_tokens).
            seq_mask_len = self.num_draft_tokens * (
                forward_batch.seq_lens + self.num_draft_tokens
            )
            mask_indptr = self.mask_indptr
            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
            mask_indptr = mask_indptr[: bs + 1]
            max_extend_len = self.num_draft_tokens
            num_kv_splits = None
            attn_logits = None
            attn_lse = None

        elif forward_batch.forward_mode.is_draft_extend():
            kv_indices, kv_indptr, qo_indptr, custom_mask = (
                spec_info.generate_attn_arg_prefill(
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    None,
                    self.req_to_token,
                )
            )
            kv_indices = kv_indices.to(torch.int64)
            mask_indptr = None
            # TODO(FIXME): This will trigger an invalid Eagle tree when using
            # `max(spec_info.accept_length_cpu)`.
            # It might have been forgotten to update somewhere.
            max_extend_len = torch.max(spec_info.accept_length).item()
            num_kv_splits = None
            attn_logits = None
            attn_lse = None
        else:
            # Regular extend: the KV indices cover only the prefix lengths.
            kv_indptr[1 : bs + 1] = torch.cumsum(
                forward_batch.extend_prefix_lens, dim=0
            )
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = torch.empty(
                forward_batch.extend_prefix_lens.sum().item(),
                dtype=torch.int64,
                device=self.device,
            )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                forward_batch.req_pool_indices,
                forward_batch.extend_prefix_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )
            # Sliding window
            if self.sliding_window_size is not None and self.sliding_window_size > 0:
                window_kv_indptr, window_kv_indices, _, _ = (
                    update_sliding_window_buffer(
                        self.window_kv_indptr,
                        self.req_to_token,
                        self.sliding_window_size,
                        forward_batch.extend_prefix_lens,
                        forward_batch.req_pool_indices,
                        bs,
                        self.device,
                        self.token_to_kv_pool_allocator,
                    )
                )

            qo_indptr = self.qo_indptr
            qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
            qo_indptr = qo_indptr[: bs + 1]
            custom_mask = None
            mask_indptr = None
            attn_logits = None
            attn_lse = None
            max_extend_len = max(forward_batch.extend_seq_lens_cpu)
            num_kv_splits = None

        self.forward_metadata = ForwardMetadata(
            attn_logits=attn_logits,
            attn_lse=attn_lse,
            max_extend_len=max_extend_len,
            num_kv_splits=num_kv_splits,
            kv_indptr=kv_indptr,
            kv_indices=kv_indices,
            qo_indptr=qo_indptr,
            custom_mask=custom_mask,
            mask_indptr=mask_indptr,
            window_kv_indptr=window_kv_indptr,
            window_kv_indices=window_kv_indices,
            window_num_kv_splits=window_num_kv_splits,
            window_kv_offsets=window_kv_offsets,
        )
418

419
    def init_cuda_graph_state(
420
421
422
423
        self,
        max_bs: int,
        max_num_tokens: int,
        kv_indices_buf: Optional[torch.Tensor] = None,
424
    ):
425
        self.cuda_graph_attn_logits = torch.zeros(
426
            (max_num_tokens, self.num_head, self.max_kv_splits, self.v_head_dim),
427
428
429
430
            dtype=torch.float32,
            device=self.device,
        )
        self.cuda_graph_attn_lse = torch.zeros(
431
            (max_num_tokens, self.num_head, self.max_kv_splits),
432
433
434
            dtype=torch.float32,
            device=self.device,
        )
435
        self.cuda_graph_num_kv_splits = torch.full(
436
            (max_num_tokens,), self.max_kv_splits, dtype=torch.int32, device=self.device
437
        )
438
439
        if kv_indices_buf is None:
            self.cuda_graph_kv_indices = torch.zeros(
440
                (max_num_tokens * self.max_context_len),
441
                dtype=torch.int64,
442
443
444
445
446
447
448
                device=self.device,
            )
        else:
            self.cuda_graph_kv_indices = kv_indices_buf

        if not self.skip_prefill:
            self.cuda_graph_custom_mask = torch.zeros(
449
                (max_num_tokens * self.max_context_len),
450
451
452
                dtype=torch.uint8,
                device=self.device,
            )
453

454
455
456
        if self.sliding_window_size is not None and self.sliding_window_size > 0:
            if kv_indices_buf is None:
                self.cuda_graph_window_kv_indices = torch.zeros(
457
                    (max_num_tokens * self.sliding_window_size),
458
                    dtype=torch.int64,
459
460
461
462
463
464
                    device=self.device,
                )
            else:
                self.cuda_graph_window_kv_indices = torch.zeros_like(kv_indices_buf)

            self.cuda_graph_window_num_kv_splits = torch.full(
465
466
467
468
                (max_num_tokens,),
                self.max_kv_splits,
                dtype=torch.int32,
                device=self.device,
469
470
            )

471
472
473
474
475
476
            self.cuda_graph_window_kv_offsets = torch.zeros(
                (max_bs,),
                dtype=torch.int32,
                device=self.device,
            )

477
    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInput],
    ):
        """Build ``self.forward_metadata`` from the persistent CUDA-graph buffers.

        Unlike ``init_forward_metadata``, index tensors are written into the
        pre-allocated ``self.cuda_graph_*`` / ``self.*_indptr`` buffers rather
        than freshly allocated ones.

        Args:
            bs: Batch size being captured.
            num_tokens: Not read by this method.
            req_pool_indices: Request pool indices for the batch.
            seq_lens: Sequence lengths for the batch.
            encoder_lens: Must be ``None``.
            forward_mode: Decode/idle, target verify or draft extend.
            spec_info: Speculative-decoding info, if any.

        Raises:
            AssertionError: If ``encoder_lens`` is not ``None``.
            ValueError: If ``forward_mode`` is not one of the supported modes.
        """
        assert encoder_lens is None, "Not supported"
        window_kv_indptr = self.window_kv_indptr
        window_kv_indices = None
        window_num_kv_splits = None
        window_kv_offsets = None

        if forward_mode.is_decode_or_idle():
            if spec_info is None:
                kv_indptr = self.kv_indptr
                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                kv_indices = self.cuda_graph_kv_indices
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    req_pool_indices,
                    seq_lens,
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
                if (
                    self.sliding_window_size is not None
                    and self.sliding_window_size > 0
                ):
                    window_kv_indices = self.cuda_graph_window_kv_indices
                    window_num_kv_splits = self.cuda_graph_window_num_kv_splits
                    window_kv_indptr, window_kv_indices, _, _ = (
                        update_sliding_window_buffer_cuda_graph(
                            self.window_kv_indptr,
                            window_kv_indices,
                            self.req_to_token,
                            self.sliding_window_size,
                            seq_lens[:bs],
                            req_pool_indices,
                            bs,
                            self.token_to_kv_pool_allocator,
                        )
                    )
            else:
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices

            attn_logits = self.cuda_graph_attn_logits
            attn_lse = self.cuda_graph_attn_lse
            max_extend_len = None
            num_kv_splits = self.cuda_graph_num_kv_splits
            qo_indptr = None
            custom_mask = None
            mask_indptr = None
        elif forward_mode.is_target_verify():
            # Slices are views, so the writes below land in the persistent buffers.
            qo_indptr = self.qo_indptr[: bs + 1]
            qo_indptr[: bs + 1] = torch.arange(
                0,
                (1 + bs) * self.num_draft_tokens,
                step=self.num_draft_tokens,
                dtype=torch.int32,
                device=self.device,
            )
            kv_indptr = self.kv_indptr[: bs + 1]
            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
            kv_indices = self.cuda_graph_kv_indices
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )

            if self.sliding_window_size is not None and self.sliding_window_size > 0:
                window_kv_indices = self.cuda_graph_window_kv_indices
                window_num_kv_splits = self.cuda_graph_window_num_kv_splits
                window_kv_offsets = self.cuda_graph_window_kv_offsets
                # The returned offsets are copied into the first bs slots of
                # the persistent offsets buffer.
                window_kv_indptr, window_kv_indices, _, window_kv_offsets[:bs] = (
                    update_sliding_window_buffer_cuda_graph(
                        self.window_kv_indptr,
                        window_kv_indices,
                        self.req_to_token,
                        self.sliding_window_size,
                        seq_lens[:bs],
                        req_pool_indices,
                        bs,
                        self.token_to_kv_pool_allocator,
                    )
                )

            custom_mask = self.cuda_graph_custom_mask
            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
            mask_indptr = self.mask_indptr[: bs + 1]
            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
            max_extend_len = self.num_draft_tokens
            num_kv_splits = None
            attn_logits = None
            attn_lse = None
        elif forward_mode.is_draft_extend():
            # During capture every request is given the same number of query
            # tokens: speculative_num_steps + 1.
            num_tokens_per_bs = self.speculative_num_steps + 1
            qo_indptr = self.qo_indptr[: bs + 1]
            qo_indptr[: bs + 1] = torch.arange(
                0,
                bs * num_tokens_per_bs + 1,
                step=num_tokens_per_bs,
                dtype=torch.int32,
                device=self.device,
            )
            kv_indptr = self.kv_indptr[: bs + 1]
            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
            kv_indices = self.cuda_graph_kv_indices
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )
            custom_mask = None
            mask_indptr = None
            max_extend_len = num_tokens_per_bs
            num_kv_splits = None
            attn_logits = None
            attn_lse = None
        else:
            raise ValueError(
                f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
            )

        self.forward_metadata = ForwardMetadata(
            attn_logits=attn_logits,
            attn_lse=attn_lse,
            max_extend_len=max_extend_len,
            num_kv_splits=num_kv_splits,
            kv_indptr=kv_indptr,
            kv_indices=kv_indices,
            qo_indptr=qo_indptr,
            custom_mask=custom_mask,
            mask_indptr=mask_indptr,
            window_kv_indptr=window_kv_indptr,
            window_kv_indices=window_kv_indices,
            window_num_kv_splits=window_num_kv_splits,
            window_kv_offsets=window_kv_offsets,
        )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInput],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        """Refresh the persistent CUDA-graph buffers in place before a replay.

        Nothing is returned and ``self.forward_metadata`` is not reassigned
        here; all updates are in-place writes into the ``self.cuda_graph_*``
        and ``self.*_indptr`` buffers.

        Args:
            bs: Batch size of the graph being replayed.
            req_pool_indices: Request pool indices for the batch.
            seq_lens: Sequence lengths for the batch.
            seq_lens_sum: Not read by this method.
            encoder_lens: Not read by this method.
            forward_mode: Decode/idle, target verify or draft extend.
            spec_info: Speculative-decoding info, if any.
            seq_lens_cpu: Not read by this method.

        Raises:
            ValueError: If ``forward_mode`` is not one of the supported modes.
        """
        # NOTE: encoder_lens expected to be zeros or None
        if forward_mode.is_decode_or_idle():
            # Update kv_indptr, kv_indices
            kv_indptr = self.kv_indptr
            kv_indices = self.cuda_graph_kv_indices
            num_kv_splits = self.cuda_graph_num_kv_splits
            if spec_info is None:
                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    req_pool_indices[:bs],
                    seq_lens[:bs],
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
                num_token = bs
                if (
                    self.sliding_window_size is not None
                    and self.sliding_window_size > 0
                ):
                    window_num_kv_splits = self.cuda_graph_window_num_kv_splits
                    window_kv_indices = self.cuda_graph_window_kv_indices
                    _, _, window_kv_lens, _ = update_sliding_window_buffer_cuda_graph(
                        self.window_kv_indptr,
                        window_kv_indices,
                        self.req_to_token,
                        self.sliding_window_size,
                        seq_lens[:bs],
                        req_pool_indices[:bs],
                        bs,
                        self.token_to_kv_pool_allocator,
                    )
                    self.get_num_kv_splits(
                        window_num_kv_splits[:num_token], window_kv_lens[:bs]
                    )

            else:
                # Copy the speculative indices into the persistent buffers.
                kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
                kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
                num_token = spec_info.kv_indptr.shape[0] - 1
            self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs])

        elif forward_mode.is_target_verify():
            # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
            bs = len(req_pool_indices)
            qo_indptr = self.qo_indptr[: bs + 1]
            qo_indptr[: bs + 1] = torch.arange(
                0,
                (1 + bs) * self.num_draft_tokens,
                step=self.num_draft_tokens,
                dtype=torch.int32,
                device=self.device,
            )
            kv_indptr = self.kv_indptr[: bs + 1]
            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
            kv_indices = self.cuda_graph_kv_indices
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )
            if self.sliding_window_size is not None and self.sliding_window_size > 0:
                window_num_kv_splits = self.cuda_graph_window_num_kv_splits
                window_kv_indices = self.cuda_graph_window_kv_indices
                window_kv_offsets = self.cuda_graph_window_kv_offsets
                # Only the offsets are kept from the return value; they are
                # copied into the first bs slots of the persistent buffer.
                _, _, _, window_kv_offsets[:bs] = (
                    update_sliding_window_buffer_cuda_graph(
                        self.window_kv_indptr,
                        window_kv_indices,
                        self.req_to_token,
                        self.sliding_window_size,
                        seq_lens[:bs],
                        req_pool_indices,
                        bs,
                        self.token_to_kv_pool_allocator,
                    )
                )
            custom_mask = self.cuda_graph_custom_mask
            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
            mask_indptr = self.mask_indptr[: bs + 1]
            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
        elif forward_mode.is_draft_extend():
            seq_lens = seq_lens[:bs]
            accept_lens = spec_info.accept_length[:bs]
            # Query offsets follow the per-request accepted lengths.
            qo_indptr = self.qo_indptr[: bs + 1]
            qo_indptr[1 : bs + 1] = torch.cumsum(accept_lens, dim=0)
            kv_indptr = self.kv_indptr[: bs + 1]
            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
            kv_indices = self.cuda_graph_kv_indices
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )
        else:
            raise ValueError(
                f"Invalid forward mode: {forward_mode=} for CUDA Graph replay."
            )
    def get_cuda_graph_seq_len_fill_value(self):
        """Return the sequence-length fill value for CUDA graphs (always ``1``).

        NOTE(review): presumably used by the graph runner to pad unused
        ``seq_lens`` slots; the caller is not visible here.
        """
        fill_value = 1
        return fill_value

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
        sinks=None,
    ):
        """Compute attention for an extend batch via ``self.extend_attention_fwd``.

        Args:
            q: Query tensor; viewed as
                ``(-1, layer.tp_q_head_num, layer.qk_head_dim)`` for the kernel.
            k: Key tensor of the tokens being extended.
            v: Value tensor of the tokens being extended.
            layer: Attention layer supplying head counts/dims, scaling, logit
                cap, attention type and sliding-window size.
            forward_batch: Batch info; provides ``token_to_kv_pool`` and
                ``out_cache_loc``.
            save_kv_cache: If true, ``k``/``v`` are written into the KV pool at
                ``forward_batch.out_cache_loc`` before attention runs.
            sinks: Passed through unchanged to the attention kernel.

        Returns:
            The freshly allocated output tensor. It has ``q.shape[0]`` rows and
            ``layer.tp_q_head_num * layer.v_head_dim`` columns when the QK and V
            head dims differ; otherwise it has the same shape as ``q``.

        Raises:
            ValueError: Propagated from ``logit_capping_mod`` for an
                unsupported ``layer.logit_capping_method``.
        """
        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)

        # Encoder-only layers attend bidirectionally; everything else is causal.
        causal = layer.attn_type != AttentionType.ENCODER_ONLY

        if layer.sliding_window_size is not None and layer.sliding_window_size > -1:
            # Needed for sliding window mask
            sliding_window_size = layer.sliding_window_size
            kv_indptr = self.forward_metadata.window_kv_indptr
            kv_indices = self.forward_metadata.window_kv_indices
            window_kv_offsets = self.forward_metadata.window_kv_offsets
        else:
            # -1 tells the kernel that no sliding window is in effect.
            sliding_window_size = -1
            kv_indptr = self.forward_metadata.kv_indptr
            kv_indices = self.forward_metadata.kv_indices
            window_kv_offsets = None

        self.extend_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            k.contiguous(),
            v.contiguous(),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            self.forward_metadata.qo_indptr,
            kv_indptr,
            kv_indices,
            self.forward_metadata.custom_mask,
            causal,
            self.forward_metadata.mask_indptr,
            self.forward_metadata.max_extend_len,
            layer.scaling,
            logit_cap=logits_soft_cap,
            sliding_window_size=sliding_window_size,
            sinks=sinks,
            window_kv_offsets=window_kv_offsets,
            xai_temperature_len=layer.xai_temperature_len,
        )
        return o

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
        sinks=None,
    ):
        """Compute attention for a decode batch via ``self.decode_attention_fwd``.

        Args:
            q: Query tensor; reshaped to
                ``(-1, layer.tp_q_head_num * layer.qk_head_dim)`` on entry.
            k: Key tensor of the tokens being decoded.
            v: Value tensor of the tokens being decoded.
            layer: Attention layer supplying head counts/dims, scaling, logit
                cap and sliding-window size.
            forward_batch: Batch info; provides ``token_to_kv_pool`` and
                ``out_cache_loc``.
            save_kv_cache: If true, ``k``/``v`` are written into the KV pool at
                ``forward_batch.out_cache_loc`` before attention runs.
            sinks: Passed through unchanged to the attention kernel.

        Returns:
            The freshly allocated 2-D output tensor with ``q.shape[0]`` rows
            and ``layer.tp_q_head_num * layer.v_head_dim`` columns.

        Raises:
            ValueError: Propagated from ``logit_capping_mod`` for an
                unsupported ``layer.logit_capping_method``.
        """
        # During torch.compile, there is a bug in rotary_emb that causes the
        # output value to have a 3D tensor shape. This reshapes the output correctly.
        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        # Resolved before the KV write so an unsupported capping method raises
        # without having touched the cache.
        logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)

        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        # Sliding-window layers read the windowed index buffers instead of the
        # full-context ones.
        if layer.sliding_window_size is not None and layer.sliding_window_size > -1:
            kv_indptr = self.forward_metadata.window_kv_indptr
            kv_indices = self.forward_metadata.window_kv_indices
        else:
            kv_indptr = self.forward_metadata.kv_indptr
            kv_indices = self.forward_metadata.kv_indices

        self.decode_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            kv_indptr,
            kv_indices,
            self.forward_metadata.attn_logits,
            self.forward_metadata.attn_lse,
            self.forward_metadata.num_kv_splits,
            self.max_kv_splits,
            layer.scaling,
            logit_cap=logits_soft_cap,
            sinks=sinks,
            xai_temperature_len=layer.xai_temperature_len,
        )
        return o


class TritonMultiStepDraftBackend:
    """
    Wrap multiple triton attention backends as one for multiple consecutive
    draft decoding steps.

    One ``TritonAttnBackend`` is created per speculative step, and each one is
    handed its own row of the shared ``kv_indptr`` buffer.
    """

    def __init__(
        self,
        model_runner: ModelRunner,
        topk: int,
        speculative_num_steps: int,
    ):
        """Create the per-step backends and the shared ``kv_indptr`` buffer.

        Args:
            model_runner: Source of the device, request-to-token pool, model
                config and server args read below.
            topk: Multiplier applied to the number of sequences
                (``bs = topk * num_seqs`` in ``common_template``).
            speculative_num_steps: Number of draft steps, i.e. the number of
                wrapped backends.
        """
        # Imported inside the constructor rather than at module level.
        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices

        self.topk = topk
        self.speculative_num_steps = speculative_num_steps
        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices

        max_bs = model_runner.req_to_token_pool.size * self.topk
        # One indptr row per draft step; row i is shared with backend i.
        self.kv_indptr = torch.zeros(
            (self.speculative_num_steps, max_bs + 1),
            dtype=torch.int32,
            device=model_runner.device,
        )
        self.attn_backends = [
            TritonAttnBackend(
                model_runner,
                skip_prefill=True,
                kv_indptr_buf=self.kv_indptr[step],
            )
            for step in range(self.speculative_num_steps)
        ]
        self.max_context_len = self.attn_backends[0].max_context_len
        self.num_head = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.device = model_runner.device

        # Cached variables for generate_draft_decode_kv_indices
        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
        self.page_size = model_runner.server_args.page_size

    def common_template(
        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
    ):
        """Fill the per-step KV index buffers, then run ``call_fn`` for every step.

        Args:
            forward_batch: Batch whose ``spec_info.kv_indptr`` and
                ``spec_info.kv_indices`` are overwritten before each callback.
            kv_indices_buffer: 2-D buffer indexed as ``[step]``; receives the
                KV indices produced by ``generate_draft_decode_kv_indices``.
            call_fn: Callback invoked as ``call_fn(step, forward_batch)``.
                NOTE(review): the ``int`` annotation does not match this usage;
                it is left untouched here because ``Callable`` is not imported
                in this module.
        """
        num_seqs = forward_batch.batch_size
        bs = self.topk * num_seqs
        seq_lens_sum = forward_batch.seq_lens_sum

        # Launch grid: (draft step, sequence, top-k branch).
        self.generate_draft_decode_kv_indices[
            (self.speculative_num_steps, num_seqs, self.topk)
        ](
            forward_batch.req_pool_indices,
            forward_batch.req_to_token_pool.req_to_token,
            forward_batch.seq_lens,
            kv_indices_buffer,
            self.kv_indptr,
            forward_batch.positions,
            self.pool_len,
            kv_indices_buffer.shape[1],
            self.kv_indptr.shape[1],
            next_power_of_2(num_seqs),
            next_power_of_2(self.speculative_num_steps),
            next_power_of_2(bs),
            self.page_size,
        )

        for i in range(self.speculative_num_steps):
            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
            # The slice grows by ``bs`` entries with each additional step.
            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                : seq_lens_sum * self.topk + bs * (i + 1)
            ]
            call_fn(i, forward_batch)

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Initialise forward metadata on every wrapped backend (non-graph path)."""
        kv_indices = torch.empty(
            (
                self.speculative_num_steps,
                forward_batch.batch_size * self.topk * self.max_context_len,
            ),
            dtype=torch.int64,
            device=self.device,
        )

        def call_fn(i, forward_batch):
            # Clone the views so the backend does not alias the shared buffers.
            # NOTE(review): presumably because the buffers are reused for the
            # next step; confirm against the backend.
            forward_batch.spec_info.kv_indptr = (
                forward_batch.spec_info.kv_indptr.clone()
            )
            forward_batch.spec_info.kv_indices = (
                forward_batch.spec_info.kv_indices.clone()
            )
            self.attn_backends[i].init_forward_metadata(forward_batch)

        self.common_template(forward_batch, kv_indices, call_fn)

    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        """Allocate the persistent KV index buffer and pass one row to each backend."""
        self.cuda_graph_kv_indices = torch.zeros(
            (self.speculative_num_steps, max_num_tokens * self.max_context_len),
            dtype=torch.int64,
            device=self.device,
        )
        for i in range(self.speculative_num_steps):
            self.attn_backends[i].init_cuda_graph_state(
                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
            )

    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
        """Prepare every wrapped backend for CUDA graph capture in DECODE mode."""

        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                forward_batch.batch_size,
                forward_batch.batch_size * self.topk,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
            )

        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)

    def init_forward_metadata_replay_cuda_graph(
        self, forward_batch: ForwardBatch, bs: int
    ):
        """Refresh every wrapped backend's metadata before a CUDA graph replay.

        Args:
            forward_batch: Batch supplying request indices, sequence lengths
                and ``spec_info``.
            bs: Batch size forwarded to each backend's replay initialiser.
        """

        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                bs,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                seq_lens_sum=-1,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
                seq_lens_cpu=None,
            )

        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)


@triton.jit
def get_num_kv_splits_triton(
    num_kv_splits_ptr,
    seq_lens_ptr,
    num_seq,
    num_group,
    num_head,
    num_kv_head,
    max_kv_splits,
    device_core_count,
    MAX_NUM_SEQ: tl.constexpr,
):
    # Heuristically choose a KV split count per sequence and write it to
    # `num_kv_splits_ptr`, repeated `num_group` times per sequence.
    #
    #   num_kv_splits_ptr: output; entry [s * num_group + i] receives the split
    #                      count of sequence s for every i in [0, num_group).
    #   seq_lens_ptr:      input sequence lengths; the first `num_seq` are read.
    #   num_seq:           number of valid sequences (must be <= MAX_NUM_SEQ).
    #   num_group:         number of consecutive output slots per sequence.
    #   num_head / num_kv_head: head counts used to estimate the launch grid.
    #   max_kv_splits:     upper bound applied to both split estimates.
    #   device_core_count: core count used by the occupancy-based estimate.
    #   MAX_NUM_SEQ:       compile-time vector width of the sequence range.
    #
    # TODO: this method is tunable, we need more online serving data to tune it
    offs_seq = tl.arange(0, MAX_NUM_SEQ)
    mask_seq = offs_seq < num_seq

    # First load pads with 0 so padding lanes cannot raise the max; the second
    # load pads with the max so padding lanes cannot lower the min.
    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)
    max_seq_len = tl.max(seq_lens)
    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)
    min_seq_len = tl.min(seq_lens)
    # If max < 1.25 * min the lengths are treated as uniform (ratio becomes 1).
    if max_seq_len * 8 < min_seq_len * 10:
        min_seq_len = max_seq_len
    # Estimate 1: split by the max/min length ratio, capped at max_kv_splits.
    max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
    kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)

    # NOTE: this is a hack to let num_kv_split grow gradually with seqlen
    # Scales the core count by max(log2(max_seq_len / 64), 1).
    ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0
    ext_device_core_count = tl.cast(
        device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32
    )
    block_h, num_kv_group = 16, num_head // num_kv_head
    if num_kv_group == 1:
        token_grid = num_seq * num_group * num_head
    else:
        # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
        block_h = tl.minimum(block_h, num_kv_group)
        token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)
    # Estimate 2: scaled core count divided by the estimated grid size, capped
    # at max_kv_splits.
    max_kv_splits_2 = tl.minimum(
        tl.cdiv(ext_device_core_count, token_grid), max_kv_splits
    )
    kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)

    # Per sequence, take the larger split count implied by the two chunk sizes.
    num_kv_splits = tl.maximum(
        tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
    )

    # Each sequence owns `num_group` consecutive output slots; the mask is
    # equivalent to offs_seq < num_seq, so padding lanes are never stored.
    offs_token = offs_seq * num_group
    mask_token = offs_token < num_seq * num_group
    for i in range(0, num_group):
        tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)


def update_sliding_window_buffer(
    window_kv_indptr,
    req_to_token,
    sliding_window_size,
    seq_lens,
    req_pool_indices,
    bs,
    device,
    token_to_kv_pool_allocator=None,
):
    """Build sliding-window KV index metadata, allocating the index buffer.

    Args:
        window_kv_indptr: Preallocated indptr buffer; entries ``1..bs`` are
            overwritten in place with the cumulative window lengths.
        req_to_token: Request-to-token table handed to the index kernel.
        sliding_window_size: Maximum number of KV entries kept per sequence.
        seq_lens: Per-sequence lengths (``bs`` entries).
        req_pool_indices: Row of ``req_to_token`` for each sequence.
        bs: Number of sequences.
        device: Device on which ``window_kv_indices`` is allocated.
        token_to_kv_pool_allocator: Optional allocator. If it exposes
            ``translate_loc_from_full_to_swa``, the indices are mapped through
            it in place.

    Returns:
        Tuple ``(window_kv_indptr[: bs + 1], window_kv_indices,
        window_kv_lens, window_kv_start_idx)``, where ``window_kv_start_idx``
        is ``seq_lens - window_kv_lens``.
    """
    # Each window length is min(seq_len, sliding_window_size). clamp avoids
    # building a throwaway CPU scalar tensor on every call.
    window_kv_lens = torch.clamp(seq_lens, max=sliding_window_size)
    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
    window_kv_indptr = window_kv_indptr[: bs + 1]

    # Read the total once; it is reused for the allocation and the slice below.
    total_window_len = int(window_kv_indptr[-1])
    window_kv_indices = torch.empty(
        total_window_len, dtype=torch.int64, device=device
    )

    # Offset of the first in-window token of each sequence.
    window_kv_start_idx = seq_lens - window_kv_lens
    create_flashinfer_kv_indices_triton[(bs,)](
        req_to_token,
        req_pool_indices,
        window_kv_lens,
        window_kv_indptr,
        window_kv_start_idx,
        window_kv_indices,
        req_to_token.stride(0),
    )

    # full to swa index mapping
    if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
        window_kv_indices[:total_window_len] = (
            token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
                window_kv_indices[:total_window_len]
            )
        )
    return window_kv_indptr, window_kv_indices, window_kv_lens, window_kv_start_idx


def update_sliding_window_buffer_cuda_graph(
    window_kv_indptr,
    window_kv_indices,
    req_to_token,
    sliding_window_size,
    seq_lens,
    req_pool_indices,
    bs,
    token_to_kv_pool_allocator=None,
):
    """Build sliding-window KV index metadata into a caller-provided buffer.

    Unlike ``update_sliding_window_buffer``, no index tensor is allocated
    here: ``window_kv_indices`` is supplied by the caller and handed to the
    index kernel.

    Args:
        window_kv_indptr: Preallocated indptr buffer; entries ``1..bs`` are
            overwritten in place with the cumulative window lengths.
        window_kv_indices: Preallocated index buffer passed to the kernel and,
            if applicable, remapped in place.
        req_to_token: Request-to-token table handed to the index kernel.
        sliding_window_size: Maximum number of KV entries kept per sequence.
        seq_lens: Per-sequence lengths (``bs`` entries).
        req_pool_indices: Row of ``req_to_token`` for each sequence.
        bs: Number of sequences.
        token_to_kv_pool_allocator: Optional allocator. If it exposes
            ``translate_loc_from_full_to_swa``, the first
            ``window_kv_indptr[-1]`` indices are mapped through it in place.

    Returns:
        Tuple ``(window_kv_indptr[: bs + 1], window_kv_indices,
        window_kv_lens, window_kv_start_idx)``, where ``window_kv_start_idx``
        is ``seq_lens - window_kv_lens``.
    """
    # Each window length is min(seq_len, sliding_window_size). clamp avoids
    # building a throwaway CPU scalar tensor on every call.
    window_kv_lens = torch.clamp(seq_lens, max=sliding_window_size)
    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
    window_kv_indptr = window_kv_indptr[: bs + 1]

    # Offset of the first in-window token of each sequence.
    window_kv_start_idx = seq_lens - window_kv_lens
    create_flashinfer_kv_indices_triton[(bs,)](
        req_to_token,
        req_pool_indices,
        window_kv_lens,
        window_kv_indptr,
        window_kv_start_idx,
        window_kv_indices,
        req_to_token.stride(0),
    )

    # full to swa index mapping
    if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
        # Only the prefix covered by the indptr is remapped; the remainder of
        # the buffer is left untouched.
        kv_last_index = window_kv_indptr[-1]
        window_kv_indices[:kv_last_index] = (
            token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
                window_kv_indices[:kv_last_index]
            )
        )
    return window_kv_indptr, window_kv_indices, window_kv_lens, window_kv_start_idx