# triton_backend.py
from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Optional

import torch
import triton

from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.attention.flashinfer_backend import (
    create_flashinfer_kv_indices_triton,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.speculative.spec_info import SpecInfo


class TritonAttnBackend(AttentionBackend):
    """Attention backend that runs the Triton decode/extend attention kernels.

    Per-batch metadata is prepared either by ``init_forward_metadata`` (eager
    mode) or by the ``*_cuda_graph`` methods (CUDA-graph mode). It is stored
    in ``self.forward_metadata`` as the 7-tuple
    ``(attn_logits, max_extend_len, kv_indptr, kv_indices, qo_indptr,
    custom_mask, mask_indptr)``. ``forward_extend`` / ``forward_decode``
    unpack that tuple.
    """

    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
    ):
        """Allocate the reusable index buffers for this backend.

        Args:
            model_runner: Source of the request-to-token pool, KV pool,
                model config, server args and device.
            skip_prefill: Recorded on the instance; not otherwise used by
                this class.
            kv_indptr_buf: Optional externally owned buffer to use as
                ``self.kv_indptr`` instead of allocating ``max_bs + 1`` int32
                entries. ``TritonMultiStepDraftBackend`` passes one row of its
                own 2-D buffer here, so its index-generation kernel writes
                straight into this backend's pointer array.
        """
        # Lazy import to avoid the initialization of cuda context
        from sglang.srt.layers.attention.triton_ops.decode_attention import (
            decode_attention_fwd,
        )
        from sglang.srt.layers.attention.triton_ops.extend_attention import (
            extend_attention_fwd,
        )

        super().__init__()

        self.decode_attention_fwd = decode_attention_fwd
        self.extend_attention_fwd = extend_attention_fwd

        self.skip_prefill = skip_prefill

        max_bs = model_runner.req_to_token_pool.size

        if kv_indptr_buf is None:
            self.kv_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )
        else:
            self.kv_indptr = kv_indptr_buf

        self.req_to_token = model_runner.req_to_token_pool.req_to_token

        self.qo_indptr = torch.zeros(
            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
        )

        self.mask_indptr = torch.zeros(
            (max_bs + 1,), dtype=torch.int64, device=model_runner.device
        )

        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens

        # Query heads handled by this attention-TP rank.
        self.num_head = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )

        self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]

        self.forward_metadata = None

        self.max_context_len = model_runner.model_config.context_len

        self.device = model_runner.device

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init auxiliary variables for triton attention backend.

        Builds ``self.forward_metadata`` for the batch's forward mode:

        - decode / idle: ``kv_indptr``, ``kv_indices`` and ``attn_logits``.
        - target verify: also ``qo_indptr``, ``custom_mask``,
          ``mask_indptr`` and ``max_extend_len``.
        - draft extend: indices and mask come from ``spec_info``.
        - extend: indices over the cached prefix, plus ``qo_indptr``.

        Entries a mode does not use are ``None``.
        """
        bs = forward_batch.batch_size
        kv_indptr = self.kv_indptr
        spec_info = forward_batch.spec_info

        if forward_batch.forward_mode.is_decode_or_idle():
            if spec_info is None:
                # kv_indptr[i]:kv_indptr[i + 1] delimits request i's entries in
                # kv_indices.
                kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                kv_indices = torch.zeros(
                    forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
                )
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
            else:
                # Speculative decoding: the caller (e.g.
                # TritonMultiStepDraftBackend.common_template) already filled
                # spec_info.kv_indptr / kv_indices.
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
                bs = kv_indptr.shape[0] - 1

            # Scratch buffer for the split-KV decode kernel.
            attn_logits = torch.zeros(
                (bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
                dtype=torch.float32,
                device=self.device,
            )

            qo_indptr = None
            custom_mask = None
            mask_indptr = None
            max_extend_len = None
        elif forward_batch.forward_mode.is_target_verify():
            bs = len(forward_batch.req_pool_indices)
            # Every request contributes exactly num_draft_tokens query tokens.
            qo_indptr = torch.arange(
                0,
                (1 + bs) * self.num_draft_tokens,
                step=self.num_draft_tokens,
                dtype=torch.int32,
                device=self.device,
            )
            # Different with flashinfer kv_indptr and kv_indices construction
            kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = torch.zeros(
                kv_indptr[-1], dtype=torch.int32, device=self.device
            )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )

            custom_mask = spec_info.custom_mask
            # Mask entries per request: num_draft_tokens * (seq_len + num_draft_tokens).
            seq_mask_len = self.num_draft_tokens * (
                forward_batch.seq_lens + self.num_draft_tokens
            )
            mask_indptr = self.mask_indptr
            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
            mask_indptr = mask_indptr[: bs + 1]
            max_extend_len = self.num_draft_tokens
            attn_logits = None
        elif forward_batch.forward_mode.is_draft_extend():
            kv_indices, kv_indptr, qo_indptr, custom_mask = (
                spec_info.generate_attn_arg_prefill(
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    self.req_to_token,
                )
            )
            mask_indptr = None
            max_extend_len = torch.max(spec_info.accept_length).item()
            attn_logits = None
        else:
            # Regular extend: kv_indices covers only each request's cached
            # prefix; the new tokens' k/v are passed directly in forward_extend.
            kv_indptr[1 : bs + 1] = torch.cumsum(
                forward_batch.extend_prefix_lens, dim=0
            )
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = torch.zeros(
                forward_batch.extend_prefix_lens.sum().item(),
                dtype=torch.int32,
                device=self.device,
            )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                forward_batch.req_pool_indices,
                forward_batch.extend_prefix_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )

            qo_indptr = self.qo_indptr
            qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
            qo_indptr = qo_indptr[: bs + 1]
            custom_mask = None
            mask_indptr = None
            attn_logits = None
            max_extend_len = torch.max(forward_batch.extend_seq_lens).item()

        self.forward_metadata = (
            attn_logits,
            max_extend_len,
            kv_indptr,
            kv_indices,
            qo_indptr,
            custom_mask,
            mask_indptr,
        )

    def init_cuda_graph_state(
        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
    ):
        """Allocate the fixed-address buffers used by CUDA-graph capture/replay.

        Args:
            max_bs: Largest batch size that will be captured.
            kv_indices_buf: Optional externally owned KV-index buffer.
                ``TritonMultiStepDraftBackend.init_cuda_graph_state`` passes one
                row of its shared 2-D buffer. When omitted, a buffer of
                ``max_bs * max_context_len`` int32 entries is allocated.
        """
        self.cuda_graph_max_total_num_tokens = max_bs * self.max_context_len

        self.cuda_graph_start_loc = torch.zeros(
            (max_bs,), dtype=torch.int32, device=self.device
        )
        self.cuda_graph_attn_logits = torch.zeros(
            (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
            dtype=torch.float32,
            device=self.device,
        )
        if kv_indices_buf is None:
            self.cuda_graph_kv_indices = torch.zeros(
                (max_bs * self.max_context_len),
                dtype=torch.int32,
                device=self.device,
            )
        else:
            self.cuda_graph_kv_indices = kv_indices_buf

    def _fill_cuda_graph_kv_indices(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        spec_info: Optional[SpecInfo],
    ):
        """Write decode KV pointers/indices into the persistent CUDA-graph buffers.

        Returns ``(kv_indptr, kv_indices)``, which are views over
        ``self.kv_indptr`` and ``self.cuda_graph_kv_indices``. Because the
        underlying storage is fixed, a graph captured with these views sees
        whatever a later replay writes into the same buffers.
        """
        kv_indptr = self.kv_indptr
        kv_indices = self.cuda_graph_kv_indices
        if spec_info is None:
            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
            kv_indptr = kv_indptr[: bs + 1]
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices[:bs],
                seq_lens[:bs],
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )
        else:
            # The caller precomputed the indices. Copy them into our buffers;
            # this is a no-op when they already alias them (as with
            # TritonMultiStepDraftBackend).
            num_ptrs = spec_info.kv_indptr.shape[0]
            kv_indptr[:num_ptrs] = spec_info.kv_indptr
            kv_indptr = kv_indptr[:num_ptrs]
            kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
        return kv_indptr, kv_indices

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
    ):
        """Set decode metadata that points at the CUDA-graph buffers.

        Only decode is supported. When ``spec_info`` is given (multi-step draft
        decoding), its ``kv_indptr`` / ``kv_indices`` are used instead of being
        recomputed from ``seq_lens``. ``num_tokens`` is not used here.
        """
        assert encoder_lens is None, "Not supported"
        assert forward_mode.is_decode(), "Not supported"

        kv_indptr, kv_indices = self._fill_cuda_graph_kv_indices(
            bs, req_pool_indices, seq_lens, spec_info
        )

        self.forward_metadata = (
            self.cuda_graph_attn_logits,
            None,
            kv_indptr,
            kv_indices,
            None,
            None,
            None,
        )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
    ):
        """Refresh the captured buffers in place before a CUDA-graph replay.

        ``seq_lens_sum``, ``encoder_lens`` and ``forward_mode`` are accepted
        for interface compatibility and are not read here.
        """
        # NOTE: encoder_lens expected to be zeros or None
        self.cuda_graph_start_loc.zero_()
        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)

        self._fill_cuda_graph_kv_indices(bs, req_pool_indices, seq_lens, spec_info)

    def get_cuda_graph_seq_len_fill_value(self):
        """Value used for seq_lens of padded slots in a captured CUDA graph."""
        return 1

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        """Run extend (prefill / verify / draft-extend) attention for one layer.

        Optionally writes ``k`` / ``v`` into the KV pool at
        ``forward_batch.out_cache_loc``, then calls the Triton extend kernel
        using the metadata prepared by ``init_forward_metadata``.
        Returns the attention output ``o``.
        """
        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        (
            _,
            max_extend_len,
            kv_indptr,
            kv_indices,
            qo_indptr,
            custom_mask,
            mask_indptr,
        ) = self.forward_metadata
        self.extend_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            k.contiguous(),
            v.contiguous(),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            qo_indptr,
            kv_indptr,
            kv_indices,
            custom_mask,
            mask_indptr,
            max_extend_len,
            layer.scaling,
            layer.logit_cap,
        )
        return o

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        """Run decode attention for one layer.

        Optionally writes ``k`` / ``v`` into the KV pool, then calls the Triton
        split-KV decode kernel with the prepared ``kv_indptr`` / ``kv_indices``
        and ``attn_logits`` scratch buffer. Returns the attention output ``o``.
        """
        # During torch.compile, there is a bug in rotary_emb that causes the
        # output value to have a 3D tensor shape. This reshapes the output correctly.
        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        attn_logits, _, kv_indptr, kv_indices, _, _, _ = self.forward_metadata

        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        self.decode_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            kv_indptr,
            kv_indices,
            attn_logits,
            self.num_kv_splits,
            layer.scaling,
            layer.logit_cap,
        )
        return o


class TritonMultiStepDraftBackend:
    """
    Wrap multiple triton attention backends as one for multiple consecutive
    draft decoding steps.

    One ``TritonAttnBackend`` is created per draft step. Each one uses a row
    of a shared 2-D ``kv_indptr`` buffer (and, in CUDA-graph mode, a row of a
    shared 2-D ``cuda_graph_kv_indices`` buffer).
    ``generate_draft_decode_kv_indices`` fills those rows for every step in a
    single kernel launch.
    """

    def __init__(
        self,
        model_runner: ModelRunner,
        topk: int,
        speculative_num_steps: int,
    ):
        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices

        self.topk = topk
        self.speculative_num_steps = speculative_num_steps
        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
        # Allocate on the model runner's device (as TritonAttnBackend does)
        # rather than assuming "cuda".
        self.device = model_runner.device
        max_bs = model_runner.req_to_token_pool.size
        self.kv_indptr = torch.zeros(
            (self.speculative_num_steps, max_bs + 1),
            dtype=torch.int32,
            device=self.device,
        )
        # Backend i reads/writes row i of self.kv_indptr.
        self.attn_backends = [
            TritonAttnBackend(
                model_runner,
                skip_prefill=True,
                kv_indptr_buf=self.kv_indptr[i],
            )
            for i in range(self.speculative_num_steps)
        ]
        self.max_context_len = self.attn_backends[0].max_context_len
        # Cached variables for generate_draft_decode_kv_indices
        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]

    def common_template(
        self,
        forward_batch: ForwardBatch,
        kv_indices_buffer: torch.Tensor,
        call_fn: Callable[[int, ForwardBatch], None],
    ):
        """Fill per-step KV indices, then call ``call_fn(step, forward_batch)``.

        Args:
            forward_batch: The draft batch. Its ``spec_info.kv_indptr`` and
                ``spec_info.kv_indices`` are overwritten with step ``i``'s
                slices right before ``call_fn(i, ...)`` runs.
            kv_indices_buffer: 2-D buffer with one row per draft step.
            call_fn: Per-step callback taking the step index and the batch.
        """
        num_seqs = forward_batch.batch_size
        bs = self.topk * num_seqs
        seq_lens_sum = forward_batch.seq_lens_sum

        self.generate_draft_decode_kv_indices[
            (self.speculative_num_steps, num_seqs, self.topk)
        ](
            forward_batch.req_pool_indices,
            forward_batch.req_to_token_pool.req_to_token,
            forward_batch.seq_lens,
            kv_indices_buffer,
            self.kv_indptr,
            forward_batch.positions,
            num_seqs,
            self.topk,
            self.pool_len,
            kv_indices_buffer.shape[1],
            self.kv_indptr.shape[1],
            triton.next_power_of_2(num_seqs),
            triton.next_power_of_2(self.speculative_num_steps),
            triton.next_power_of_2(bs),
        )

        for i in range(self.speculative_num_steps):
            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
            # Step i uses seq_lens_sum * topk entries plus (i + 1) more for
            # each of the bs = topk * num_seqs rows.
            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                : seq_lens_sum * self.topk + bs * (i + 1)
            ]
            call_fn(i, forward_batch)

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Prepare eager-mode attention metadata for every draft step."""
        kv_indices = torch.zeros(
            (
                self.speculative_num_steps,
                forward_batch.batch_size * self.topk * self.max_context_len,
            ),
            dtype=torch.int32,
            device=self.device,
        )

        def call_fn(i, forward_batch):
            # Presumably cloned so each step's stored metadata does not alias
            # buffers that later calls overwrite — TODO confirm.
            forward_batch.spec_info.kv_indptr = (
                forward_batch.spec_info.kv_indptr.clone()
            )
            forward_batch.spec_info.kv_indices = (
                forward_batch.spec_info.kv_indices.clone()
            )
            self.attn_backends[i].init_forward_metadata(forward_batch)

        self.common_template(forward_batch, kv_indices, call_fn)

    def init_cuda_graph_state(self, max_bs: int):
        """Allocate one shared CUDA-graph KV-index buffer; give each step a row."""
        self.cuda_graph_kv_indices = torch.zeros(
            (self.speculative_num_steps, max_bs * self.max_context_len),
            dtype=torch.int32,
            device=self.device,
        )
        for i, backend in enumerate(self.attn_backends):
            backend.init_cuda_graph_state(
                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
            )

    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
        """Prepare per-step decode metadata before capturing CUDA graphs."""

        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                forward_batch.batch_size,
                forward_batch.batch_size * self.topk,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
            )

        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)

    def init_forward_metadata_replay_cuda_graph(self, forward_batch: ForwardBatch):
        """Refresh per-step decode metadata in place before a CUDA-graph replay."""

        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                forward_batch.batch_size,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                seq_lens_sum=-1,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
            )

        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)