triton_backend.py

from __future__ import annotations

from typing import TYPE_CHECKING, Optional

import torch

from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.attention.flashinfer_backend import (
    create_flashinfer_kv_indices_triton,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.speculative.spec_info import SpecInfo


class TritonAttnBackend(AttentionBackend):
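    """Attention backend that runs extend (prefill) and decode attention with Triton kernels.

    The KV cache is addressed through FlashInfer-style kv_indptr / kv_indices
    metadata built per forward batch.
    """
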
    def __init__(self, model_runner: ModelRunner):
        # Lazy import to avoid initializing the CUDA context
        from sglang.srt.layers.attention.triton_ops.decode_attention import (
            decode_attention_fwd,
        )
        from sglang.srt.layers.attention.triton_ops.extend_attention import (
            extend_attention_fwd,
        )

        super().__init__()

        self.decode_attention_fwd = decode_attention_fwd
        self.extend_attention_fwd = extend_attention_fwd

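        # Preallocate indptr buffers at the maximum batch size; per-batch metadata
        # is later built by slicing these buffers instead of reallocating.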
        max_bs = model_runner.req_to_token_pool.size
        self.kv_indptr = torch.zeros(
            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
        )
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.qo_indptr = torch.zeros(
            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
        )

        self.num_head = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )

        self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]

        self.forward_metadata = None

        self.cuda_graph_max_seq_len = model_runner.model_config.context_len

        self.device = model_runner.device

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init auxiliary variables for triton attention backend."""

        bs = forward_batch.batch_size
        kv_indptr = self.kv_indptr

        if forward_batch.forward_mode.is_decode():
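            # Decode: allocate the intermediate buffer for the split-KV decode kernel
            # (one partial result per request, head, and KV split; the +1 leaves room
            # for per-split softmax statistics), then build kv_indptr/kv_indices over
            # the full cached sequence of every request.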
            attn_logits = torch.empty(
                (
                    forward_batch.batch_size,
                    self.num_head,
                    self.num_kv_splits,
                    self.v_head_dim + 1,
                ),
                dtype=torch.float32,
                device=self.device,
            )

            max_extend_len = None

            kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = torch.empty(
                forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
            )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )

            qo_indptr = None
            custom_mask = None
            mask_offsets = None
        else:
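            # Extend (prefill with a cached prefix): kv_indptr/kv_indices cover only
            # the already-cached prefix tokens, while qo_indptr below delimits the
            # newly extended tokens of each request.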
            kv_indptr[1 : bs + 1] = torch.cumsum(
                forward_batch.extend_prefix_lens, dim=0
            )
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = torch.empty(
                forward_batch.extend_prefix_lens.sum().item(),
                dtype=torch.int32,
                device=self.device,
            )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                forward_batch.req_pool_indices,
                forward_batch.extend_prefix_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )

            qo_indptr = self.qo_indptr
            qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
            qo_indptr = qo_indptr[: bs + 1]
            custom_mask = None
            mask_offsets = None

            attn_logits = None
            max_extend_len = torch.max(forward_batch.extend_seq_lens).item()

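        # Pack everything the kernels need; forward_extend/forward_decode unpack
        # this tuple per layer instead of recomputing it.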
        self.forward_metadata = (
            attn_logits,
            max_extend_len,
            kv_indptr,
            kv_indices,
            qo_indptr,
            custom_mask,
            mask_offsets,
        )

    def init_cuda_graph_state(self, max_bs: int):
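        # Allocate worst-case-sized buffers up front so that CUDA graph capture and
        # replay always see the same tensor addresses.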
        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len

        self.cuda_graph_start_loc = torch.zeros(
            (max_bs,), dtype=torch.int32, device=self.device
        )
        self.cuda_graph_attn_logits = torch.empty(
            (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
            dtype=torch.float32,
            device=self.device,
        )
        self.cuda_graph_kv_indices = torch.zeros(
            (max_bs * self.cuda_graph_max_seq_len),
            dtype=torch.int32,
            device=self.device,
        )

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
    ):
        assert encoder_lens is None, "Not supported"
        assert forward_mode.is_decode(), "Not supported"
        assert spec_info is None, "Not supported"

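        # Same indptr/indices construction as the eager decode path, but written into
        # the preallocated cuda_graph_kv_indices buffer.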
        kv_indptr = self.kv_indptr
        kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
        kv_indptr = kv_indptr[: bs + 1]
        kv_indices = self.cuda_graph_kv_indices
        create_flashinfer_kv_indices_triton[(bs,)](
            self.req_to_token,
            req_pool_indices,
            seq_lens,
            kv_indptr,
            None,
            kv_indices,
            self.req_to_token.stride(0),
        )

        self.forward_metadata = (
            self.cuda_graph_attn_logits,
            None,
            kv_indptr,
            kv_indices,
            None,
            None,
            None,
        )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
    ):
        # NOTE: encoder_lens is expected to be zeros or None
        self.cuda_graph_start_loc.zero_()
        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)

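        # Rebuild the decode metadata in place for the current batch; only the first
        # bs entries of the (possibly padded) inputs are used.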
        kv_indptr = self.kv_indptr
        kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
        kv_indptr = kv_indptr[: bs + 1]
        kv_indices = self.cuda_graph_kv_indices
        create_flashinfer_kv_indices_triton[(bs,)](
            self.req_to_token,
            req_pool_indices[:bs],
            seq_lens[:bs],
            kv_indptr,
            None,
            kv_indices,
            self.req_to_token.stride(0),
        )

    def get_cuda_graph_seq_len_fill_value(self):
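        # Fill value used for seq_lens of padded requests when running under CUDA graphs.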
        return 1

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        (
            _,
            max_extend_len,
            kv_indptr,
            kv_indices,
            qo_indptr,
            custom_mask,
            mask_offsets,
        ) = self.forward_metadata
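        # Attend the newly extended tokens to both the cached prefix (via
        # kv_indptr/kv_indices) and the in-flight k/v of this extend step.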
        self.extend_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            k.contiguous(),
            v.contiguous(),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            qo_indptr,
            kv_indptr,
            kv_indices,
            custom_mask,
            mask_offsets,
            max_extend_len,
            layer.scaling,
            layer.logit_cap,
        )
        return o

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        # Under torch.compile, a bug in rotary_emb can leave q with a 3D shape.
        # Reshape it back to 2D before calling the attention kernel.
        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        attn_logits, _, kv_indptr, kv_indices, _, _, _ = self.forward_metadata

        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        # Split-KV decode over each request's cached KV; per-split partial results
        # are staged in attn_logits before the final reduction.
        self.decode_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            kv_indptr,
            kv_indices,
            attn_logits,
            self.num_kv_splits,
            layer.scaling,
            layer.logit_cap,
        )
        return o