from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union

import torch
import triton
import triton.language as tl

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import get_bool_env_var, get_device_core_count

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput


@triton.jit
def get_num_kv_splits_triton(
    num_kv_splits_ptr,
    seq_lens_ptr,
    num_seq,
    num_group,
    num_head,
    num_kv_head,
    max_kv_splits,
    device_core_count,
    MAX_NUM_SEQ: tl.constexpr,
):
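    # Heuristically pick a per-sequence number of KV splits for split-K decode
    # attention: one bound comes from the max/min sequence-length ratio in the
    # batch, another from how many CTAs the device can keep busy; the per-sequence
    # maximum of the two is broadcast to every group entry of that sequence.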
    # TODO: this method is tunable; we need more online serving data to tune it
    offs_seq = tl.arange(0, MAX_NUM_SEQ)
    mask_seq = offs_seq < num_seq

    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)
    max_seq_len = tl.max(seq_lens)
    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)
    min_seq_len = tl.min(seq_lens)
    if max_seq_len * 8 < min_seq_len * 10:
        min_seq_len = max_seq_len
    max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
    kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)

    # NOTE: this is a hack to let num_kv_split grow gradually with seqlen
    ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0
    ext_device_core_count = tl.cast(
        device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32
    )
    block_h, num_kv_group = 16, num_head // num_kv_head
    if num_kv_group == 1:
        token_grid = num_seq * num_group * num_head
    else:
        # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
        block_h = tl.minimum(block_h, num_kv_group)
        token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)
    max_kv_splits_2 = tl.minimum(
        tl.cdiv(ext_device_core_count, token_grid), max_kv_splits
    )
    kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)

    num_kv_splits = tl.maximum(
        tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
    )

    offs_token = offs_seq * num_group
    mask_token = offs_token < num_seq * num_group
    for i in range(0, num_group):
        tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)


@dataclass
class ForwardMetadata:
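    """Per-batch metadata shared by the decode and extend attention kernels."""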
    attn_logits: torch.Tensor
    attn_lse: torch.Tensor
    max_extend_len: int
    num_kv_splits: torch.Tensor
    kv_indptr: torch.Tensor
    kv_indices: torch.Tensor
    qo_indptr: torch.Tensor
    custom_mask: torch.Tensor
    mask_indptr: torch.Tensor


class TritonAttnBackend(AttentionBackend):
    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
    ):
        # Lazy import to avoid the initialization of cuda context
        from sglang.srt.layers.attention.triton_ops.decode_attention import (
            decode_attention_fwd,
        )
        from sglang.srt.layers.attention.triton_ops.extend_attention import (
            extend_attention_fwd,
        )

        super().__init__()

        self.decode_attention_fwd = decode_attention_fwd
        self.extend_attention_fwd = extend_attention_fwd

        self.skip_prefill = skip_prefill

        max_bs = model_runner.req_to_token_pool.size

        if kv_indptr_buf is None:
            self.kv_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )
        else:
            self.kv_indptr = kv_indptr_buf

        self.req_to_token = model_runner.req_to_token_pool.req_to_token

        if not self.skip_prefill:
            self.qo_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )

            self.mask_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int64, device=model_runner.device
            )

        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens

        self.num_head = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.num_kv_head = model_runner.model_config.get_num_kv_heads(
            get_attention_tp_size()
        )

        self.static_kv_splits = get_bool_env_var(
            "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
        )
        self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]

        self.forward_metadata: ForwardMetadata = None

        self.max_context_len = model_runner.model_config.context_len

        self.device = model_runner.device
        self.device_core_count = get_device_core_count(model_runner.gpu_id)

    def get_num_kv_splits(
        self,
        num_kv_splits: torch.Tensor,
        seq_lens: torch.Tensor,
    ):
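        """Heuristically size num_kv_splits in place for every token in the batch."""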
        num_token, num_seq = num_kv_splits.shape[0], seq_lens.shape[0]
        num_group = num_token // num_seq

        assert (
            num_group * num_seq == num_token
        ), f"num_seq({num_seq}), num_token({num_token}), something goes wrong!"

        if self.static_kv_splits or self.device_core_count <= 0:
            num_kv_splits.fill_(self.max_kv_splits)
            return

        if num_seq < 256:
            SCHEDULE_SEQ = 256
        else:
            SCHEDULE_SEQ = triton.next_power_of_2(num_seq)
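        # MAX_NUM_SEQ feeds tl.arange, which requires a power-of-two bound, and it is
        # a tl.constexpr; flooring it at 256 avoids re-specializing the kernel for
        # every small batch size.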

        get_num_kv_splits_triton[(1,)](
            num_kv_splits,
            seq_lens,
            num_seq,
            num_group,
            self.num_head,
            self.num_kv_head,
            self.max_kv_splits,
            self.device_core_count,
            MAX_NUM_SEQ=SCHEDULE_SEQ,
        )

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init auxiliary variables for triton attention backend."""

        bs = forward_batch.batch_size
        kv_indptr = self.kv_indptr
        spec_info = forward_batch.spec_info

        if forward_batch.forward_mode.is_decode_or_idle():
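            # Decode/idle: gather flat kv_indices from req_to_token and size the
            # split-K intermediate buffers (attn_logits / attn_lse) for this batch.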
            if spec_info is None:
                kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                kv_indices = torch.empty(
                    forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
                )
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
            else:
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
                bs = kv_indptr.shape[0] - 1

            attn_logits = torch.empty(
                (bs, self.num_head, self.max_kv_splits, self.v_head_dim),
                dtype=torch.float32,
                device=self.device,
            )
            attn_lse = torch.empty(
                (bs, self.num_head, self.max_kv_splits),
                dtype=torch.float32,
                device=self.device,
            )
            num_kv_splits = torch.empty((bs,), dtype=torch.int32, device=self.device)

            self.get_num_kv_splits(num_kv_splits, forward_batch.seq_lens)

            qo_indptr = None
            custom_mask = None
            mask_indptr = None
            max_extend_len = None
        elif forward_batch.forward_mode.is_target_verify():
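            # Target verify: each request contributes num_draft_tokens queries, and
            # spec_info.custom_mask encodes which KV positions each draft token may see.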
            bs = len(forward_batch.req_pool_indices)
            qo_indptr = torch.arange(
                0,
                (1 + bs) * self.num_draft_tokens,
                step=self.num_draft_tokens,
                dtype=torch.int32,
                device=self.device,
            )
            # Different from the flashinfer kv_indptr and kv_indices construction
            kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = torch.empty(
                kv_indptr[-1], dtype=torch.int32, device=self.device
            )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )

            custom_mask = spec_info.custom_mask
            seq_mask_len = self.num_draft_tokens * (
                forward_batch.seq_lens + self.num_draft_tokens
            )
            mask_indptr = self.mask_indptr
            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
            mask_indptr = mask_indptr[: bs + 1]
            max_extend_len = self.num_draft_tokens
            num_kv_splits = None
            attn_logits = None
            attn_lse = None
        elif forward_batch.forward_mode.is_draft_extend():
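            # Draft extend: spec_info builds the prefill-style indices directly.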
            kv_indices, kv_indptr, qo_indptr, custom_mask = (
                spec_info.generate_attn_arg_prefill(
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    None,
                    self.req_to_token,
                )
            )
            mask_indptr = None
            # TODO(FIXME): This will trigger an invalid Eagle tree when using
            # `max(spec_info.accept_length_cpu)`; some update may have been missed
            # elsewhere.
            max_extend_len = torch.max(spec_info.accept_length).item()
            num_kv_splits = None
            attn_logits = None
            attn_lse = None
        else:
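            # Regular extend/prefill: cached prefix tokens go through kv_indices,
            # while the newly extended tokens are described by qo_indptr.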
            kv_indptr[1 : bs + 1] = torch.cumsum(
                forward_batch.extend_prefix_lens, dim=0
            )
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = torch.empty(
                forward_batch.extend_prefix_lens.sum().item(),
                dtype=torch.int32,
                device=self.device,
            )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                forward_batch.req_pool_indices,
                forward_batch.extend_prefix_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )

            qo_indptr = self.qo_indptr
            qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
            qo_indptr = qo_indptr[: bs + 1]
            custom_mask = None
            mask_indptr = None
            attn_logits = None
            attn_lse = None
            max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
            num_kv_splits = None

        self.forward_metadata = ForwardMetadata(
            attn_logits,
            attn_lse,
            max_extend_len,
            num_kv_splits,
            kv_indptr,
            kv_indices,
            qo_indptr,
            custom_mask,
            mask_indptr,
        )

    def init_cuda_graph_state(
        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
    ):
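        # Pre-allocate worst-case buffers so CUDA graph capture/replay can reuse
        # fixed memory addresses instead of allocating per batch.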
        self.cuda_graph_attn_logits = torch.zeros(
            (max_bs, self.num_head, self.max_kv_splits, self.v_head_dim),
            dtype=torch.float32,
            device=self.device,
        )
        self.cuda_graph_attn_lse = torch.zeros(
            (max_bs, self.num_head, self.max_kv_splits),
            dtype=torch.float32,
            device=self.device,
        )
        self.cuda_graph_num_kv_splits = torch.full(
            (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
        )
        if kv_indices_buf is None:
            self.cuda_graph_kv_indices = torch.zeros(
                (max_bs * self.max_context_len),
                dtype=torch.int32,
                device=self.device,
            )
        else:
            self.cuda_graph_kv_indices = kv_indices_buf

        if not self.skip_prefill:
            self.cuda_graph_custom_mask = torch.zeros(
                (max_bs * self.max_context_len),
                dtype=torch.uint8,
                device=self.device,
            )

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        assert encoder_lens is None, "Not supported"

        if forward_mode.is_decode_or_idle():
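            # Capture mirrors init_forward_metadata for decode, but writes into the
            # pre-allocated CUDA-graph buffers so addresses stay stable.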
            if spec_info is None:
                kv_indptr = self.kv_indptr
                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                kv_indices = self.cuda_graph_kv_indices
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    req_pool_indices,
                    seq_lens,
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
            else:
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices

            attn_logits = self.cuda_graph_attn_logits
            attn_lse = self.cuda_graph_attn_lse
            max_extend_len = None
            num_kv_splits = self.cuda_graph_num_kv_splits
            qo_indptr = None
            custom_mask = None
            mask_indptr = None
        elif forward_mode.is_target_verify():
            qo_indptr = self.qo_indptr[: bs + 1]
            qo_indptr[: bs + 1] = torch.arange(
                0,
                (1 + bs) * self.num_draft_tokens,
                step=self.num_draft_tokens,
                dtype=torch.int32,
                device=self.device,
            )
            kv_indptr = self.kv_indptr[: bs + 1]
            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
            kv_indices = self.cuda_graph_kv_indices
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )

            custom_mask = self.cuda_graph_custom_mask
            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
            mask_indptr = self.mask_indptr[: bs + 1]
            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
            max_extend_len = self.num_draft_tokens
            num_kv_splits = None
            attn_logits = None
            attn_lse = None
        else:
            raise ValueError(
                f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
            )

        self.forward_metadata = ForwardMetadata(
            attn_logits,
            attn_lse,
            max_extend_len,
            num_kv_splits,
            kv_indptr,
            kv_indices,
            qo_indptr,
            custom_mask,
            mask_indptr,
        )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        # NOTE: encoder_lens expected to be zeros or None
        if forward_mode.is_decode_or_idle():
            # Update kv_indptr, kv_indices
            kv_indptr = self.kv_indptr
            kv_indices = self.cuda_graph_kv_indices
            num_kv_splits = self.cuda_graph_num_kv_splits
            if spec_info is None:
                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    req_pool_indices[:bs],
                    seq_lens[:bs],
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
                num_token = bs
            else:
                kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
                kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
                num_token = spec_info.kv_indptr.shape[0] - 1
            self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs])
        elif forward_mode.is_target_verify():
            # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
            bs = len(req_pool_indices)
            qo_indptr = self.qo_indptr[: bs + 1]
            qo_indptr[: bs + 1] = torch.arange(
                0,
                (1 + bs) * self.num_draft_tokens,
                step=self.num_draft_tokens,
                dtype=torch.int32,
                device=self.device,
            )
            kv_indptr = self.kv_indptr[: bs + 1]
            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
            kv_indices = self.cuda_graph_kv_indices
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )
            custom_mask = self.cuda_graph_custom_mask
            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
            mask_indptr = self.mask_indptr[: bs + 1]
            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
        else:
            raise ValueError(
                f"Invalid forward mode: {forward_mode=} for CUDA Graph replay."
            )

    def get_cuda_graph_seq_len_fill_value(self):
        return 1

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        self.extend_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            k.contiguous(),
            v.contiguous(),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            self.forward_metadata.qo_indptr,
            self.forward_metadata.kv_indptr,
            self.forward_metadata.kv_indices,
            self.forward_metadata.custom_mask,
            self.forward_metadata.mask_indptr,
            self.forward_metadata.max_extend_len,
            layer.scaling,
            layer.logit_cap,
        )
        return o

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        # During torch.compile, there is a bug in rotary_emb that causes the
        # output value to have a 3D tensor shape. This reshapes the output correctly.
        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        self.decode_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            self.forward_metadata.kv_indptr,
            self.forward_metadata.kv_indices,
            self.forward_metadata.attn_logits,
            self.forward_metadata.attn_lse,
            self.forward_metadata.num_kv_splits,
            self.max_kv_splits,
            layer.scaling,
            layer.logit_cap,
        )
        return o


class TritonMultiStepDraftBackend:
    """
    Wrap multiple triton attention backends as one for multiple consecutive
    draft decoding steps.
    """

    def __init__(
        self,
        model_runner: ModelRunner,
        topk: int,
        speculative_num_steps: int,
    ):
        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices

        self.topk = topk
        self.speculative_num_steps = speculative_num_steps
        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
        max_bs = model_runner.req_to_token_pool.size * self.topk
        self.kv_indptr = torch.zeros(
            (
                self.speculative_num_steps,
                max_bs + 1,
            ),
            dtype=torch.int32,
            device=model_runner.device,
        )
        self.attn_backends = []
        for i in range(self.speculative_num_steps):
            self.attn_backends.append(
                TritonAttnBackend(
                    model_runner,
                    skip_prefill=True,
                    kv_indptr_buf=self.kv_indptr[i],
                )
            )
        self.max_context_len = self.attn_backends[0].max_context_len
        self.num_head = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.device = model_runner.device
        # Cached variables for generate_draft_decode_kv_indices
        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]

    def common_template(
        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
    ):
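        # Fill kv_indptr / kv_indices for every speculative step with one kernel
        # launch, then run call_fn once per step on that step's slices.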
        num_seqs = forward_batch.batch_size
        bs = self.topk * num_seqs
        seq_lens_sum = forward_batch.seq_lens_sum

        self.generate_draft_decode_kv_indices[
            (self.speculative_num_steps, num_seqs, self.topk)
        ](
            forward_batch.req_pool_indices,
            forward_batch.req_to_token_pool.req_to_token,
            forward_batch.seq_lens,
            kv_indices_buffer,
            self.kv_indptr,
            forward_batch.positions,
            num_seqs,
            self.topk,
            self.pool_len,
            kv_indices_buffer.shape[1],
            self.kv_indptr.shape[1],
            triton.next_power_of_2(num_seqs),
            triton.next_power_of_2(self.speculative_num_steps),
            triton.next_power_of_2(bs),
        )

        for i in range(self.speculative_num_steps):
            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                : seq_lens_sum * self.topk + bs * (i + 1)
            ]
            call_fn(i, forward_batch)

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        kv_indices = torch.empty(
            (
                self.speculative_num_steps,
                forward_batch.batch_size * self.topk * self.max_context_len,
            ),
            dtype=torch.int32,
            device=self.device,
        )

        def call_fn(i, forward_batch):
            forward_batch.spec_info.kv_indptr = (
                forward_batch.spec_info.kv_indptr.clone()
            )
            forward_batch.spec_info.kv_indices = (
                forward_batch.spec_info.kv_indices.clone()
            )
            self.attn_backends[i].init_forward_metadata(forward_batch)

        self.common_template(forward_batch, kv_indices, call_fn)

    def init_cuda_graph_state(self, max_bs: int):
        self.cuda_graph_kv_indices = torch.zeros(
            (self.speculative_num_steps, max_bs * self.max_context_len),
            dtype=torch.int32,
            device=self.device,
        )
        for i in range(self.speculative_num_steps):
            self.attn_backends[i].init_cuda_graph_state(
                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
            )

    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                forward_batch.batch_size,
                forward_batch.batch_size * self.topk,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
            )

        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)

    def init_forward_metadata_replay_cuda_graph(
        self, forward_batch: ForwardBatch, bs: int
    ):
        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                bs,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                seq_lens_sum=-1,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
                seq_lens_cpu=None,
            )

        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)