Commit 5a33c3aa (Unverified)
Optimize Triton Draft Backend (#11556)
Authored Oct 14, 2025 by Liangsheng Yin; committed by GitHub on Oct 14, 2025
Parent: 9767a1e4
Showing 1 changed file with 51 additions and 28 deletions.

python/sglang/srt/layers/attention/triton_backend.py (+51, -28)
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
 import triton
@@ -12,6 +12,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
 from sglang.srt.utils import (
     get_bool_env_var,
     get_device_core_count,
@@ -423,6 +424,7 @@ class TritonAttnBackend(AttentionBackend):
         max_bs: int,
         max_num_tokens: int,
         kv_indices_buf: Optional[torch.Tensor] = None,
+        cuda_graph_num_kv_splits_buf: Optional[torch.Tensor] = None,
     ):
         self.cuda_graph_attn_logits = torch.zeros(
             (max_num_tokens, self.num_head, self.max_kv_splits, self.v_head_dim),
@@ -434,9 +436,17 @@ class TritonAttnBackend(AttentionBackend):
             dtype=torch.float32,
             device=self.device,
         )
-        self.cuda_graph_num_kv_splits = torch.full(
-            (max_num_tokens,), self.max_kv_splits, dtype=torch.int32, device=self.device
-        )
+        if cuda_graph_num_kv_splits_buf is None:
+            self.cuda_graph_num_kv_splits = torch.full(
+                (max_num_tokens,),
+                self.max_kv_splits,
+                dtype=torch.int32,
+                device=self.device,
+            )
+        else:
+            self.cuda_graph_num_kv_splits = cuda_graph_num_kv_splits_buf
+
         if kv_indices_buf is None:
             self.cuda_graph_kv_indices = torch.zeros(
                 (max_num_tokens * self.max_context_len),
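The new cuda_graph_num_kv_splits_buf parameter turns the num-kv-splits buffer into an allocate-or-reuse resource: a backend only allocates its own tensor when the caller does not hand one in; otherwise it aliases the caller's storage, so one write is visible to every backend sharing it. A minimal sketch of the pattern (the Backend class and names below are illustrative stand-ins, not sglang's API):

import torch
from typing import Optional

class Backend:  # hypothetical stand-in for TritonAttnBackend
    def init_cuda_graph_state(self, max_num_tokens: int, max_kv_splits: int,
                              num_kv_splits_buf: Optional[torch.Tensor] = None):
        if num_kv_splits_buf is None:
            # No buffer supplied: this backend owns a private allocation.
            self.cuda_graph_num_kv_splits = torch.full(
                (max_num_tokens,), max_kv_splits, dtype=torch.int32
            )
        else:
            # Buffer supplied: alias the caller's tensor -- no copy, no new memory.
            self.cuda_graph_num_kv_splits = num_kv_splits_buf

shared = torch.full((8,), 16, dtype=torch.int32)
a, b = Backend(), Backend()
a.init_cuda_graph_state(8, 16, shared)
b.init_cuda_graph_state(8, 16, shared)
shared[:3] = 2  # one write to the shared buffer...
assert int(a.cuda_graph_num_kv_splits[0]) == 2  # ...is seen by backend a
assert int(b.cuda_graph_num_kv_splits[0]) == 2  # ...and by backend b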
@@ -683,9 +693,7 @@ class TritonAttnBackend(AttentionBackend):
                 )
             else:
                 kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
                 kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
-                num_token = spec_info.kv_indptr.shape[0] - 1
-
-            self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs])
+                assert False, "Multi-step cuda graph init is not done here."
 
         elif forward_mode.is_target_verify():
@@ -898,11 +906,8 @@ class TritonMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
-
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
-        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
         max_bs = model_runner.req_to_token_pool.size * self.topk
         self.kv_indptr = torch.zeros(
             (
@@ -912,7 +917,7 @@ class TritonMultiStepDraftBackend:
             dtype=torch.int32,
             device=model_runner.device,
         )
-        self.attn_backends = []
+        self.attn_backends: List[TritonAttnBackend] = []
         for i in range(self.speculative_num_steps):
             self.attn_backends.append(
                 TritonAttnBackend(
@@ -931,13 +936,19 @@ class TritonMultiStepDraftBackend:
         self.page_size = model_runner.server_args.page_size
 
     def common_template(
-        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
+        self,
+        forward_batch: ForwardBatch,
+        kv_indices_buffer: Optional[torch.Tensor],
+        call_fn: int,
     ):
+        if kv_indices_buffer is None:
+            kv_indices_buffer = self.cuda_graph_kv_indices
+
         num_seqs = forward_batch.batch_size
         bs = self.topk * num_seqs
         seq_lens_sum = forward_batch.seq_lens_sum
 
-        self.generate_draft_decode_kv_indices[
+        generate_draft_decode_kv_indices[
            (self.speculative_num_steps, num_seqs, self.topk)
        ](
            forward_batch.req_pool_indices,
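common_template now invokes the module-level generate_draft_decode_kv_indices kernel directly, using Triton's kernel[grid](...) launch syntax with a 3-D grid of (speculative_num_steps, num_seqs, topk): one program instance runs per grid point. A self-contained sketch of that launch syntax (the kernel below is illustrative, not the sglang kernel, and assumes a CUDA-capable GPU with Triton installed):

import torch
import triton
import triton.language as tl

@triton.jit
def _grid_demo_kernel(out_ptr, dim1, dim2):
    # One program instance per grid point; recover its 3-D coordinates.
    i = tl.program_id(0)  # draft-step axis in the real launch
    j = tl.program_id(1)  # sequence axis
    k = tl.program_id(2)  # top-k branch axis
    tl.store(out_ptr + (i * dim1 + j) * dim2 + k, i * 100 + j * 10 + k)

out = torch.zeros((2, 3, 4), dtype=torch.int32, device="cuda")
_grid_demo_kernel[(2, 3, 4)](out, 3, 4)  # grid = (steps, num_seqs, topk)
print(out.cpu())  # out[i, j, k] == 100*i + 10*j + k

Each program recovers its (step, sequence, branch) coordinates from tl.program_id, which is how the real kernel addresses its slice of the kv-indices buffers.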
@@ -955,6 +966,9 @@ class TritonMultiStepDraftBackend:
             self.page_size,
         )
 
+        if call_fn is None:
+            return
+
         for i in range(self.speculative_num_steps):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
             forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
@@ -989,9 +1003,18 @@ class TritonMultiStepDraftBackend:
             dtype=torch.int64,
             device=self.device,
         )
+        self.cuda_graph_num_kv_splits = torch.full(
+            (max_num_tokens,),
+            self.attn_backends[0].max_kv_splits,
+            dtype=torch.int32,
+            device=self.device,
+        )
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
-                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs,
+                max_num_tokens,
+                kv_indices_buf=self.cuda_graph_kv_indices[i],
+                cuda_graph_num_kv_splits_buf=self.cuda_graph_num_kv_splits,
             )
 
     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
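This hunk gives the multi-step backend a single num-kv-splits buffer shared by every per-step TritonAttnBackend (through the new cuda_graph_num_kv_splits_buf argument), alongside the existing per-step rows of cuda_graph_kv_indices. A sketch of the resulting buffer layout, with illustrative sizes rather than the real defaults:

import torch

steps, max_num_tokens, max_context_len, max_kv_splits = 3, 8, 16, 4

# One 2-D kv-indices buffer; row i is handed to the step-i backend as
# kv_indices_buf, so each step writes its own region of the same allocation.
cuda_graph_kv_indices = torch.zeros(
    (steps, max_num_tokens * max_context_len), dtype=torch.int64
)
# A single 1-D num-kv-splits buffer passed to *every* step's backend.
cuda_graph_num_kv_splits = torch.full(
    (max_num_tokens,), max_kv_splits, dtype=torch.int32
)

per_step_rows = [cuda_graph_kv_indices[i] for i in range(steps)]  # views, no copies
assert per_step_rows[1].data_ptr() == cuda_graph_kv_indices[1].data_ptr()
# Because all steps alias cuda_graph_num_kv_splits, computing the split counts
# once (as the replay path below now does) is enough for every draft step.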
@@ -1006,24 +1029,24 @@ class TritonMultiStepDraftBackend:
                 spec_info=forward_batch.spec_info,
             )
 
-        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+        self.common_template(forward_batch, None, call_fn)
 
     def init_forward_metadata_replay_cuda_graph(
         self, forward_batch: ForwardBatch, bs: int
     ):
-        def call_fn(i, forward_batch):
-            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
-                bs,
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                seq_lens_sum=-1,
-                encoder_lens=None,
-                forward_mode=ForwardMode.DECODE,
-                spec_info=forward_batch.spec_info,
-                seq_lens_cpu=None,
-            )
-
-        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+        self.common_template(forward_batch, None, None)
+
+        # NOTE: Multi-step's attention backends use the slice of
+        # - kv_indptr buffer (cuda graph and non-cuda graph)
+        # - kv_indices buffer (cuda graph only)
+        # So we don't need to assign the KV indices inside the attention backend.
+
+        # Compute num_kv_splits only once
+        num_token = forward_batch.batch_size * self.topk
+        self.attn_backends[-1].get_num_kv_splits(
+            self.attn_backends[-1].cuda_graph_num_kv_splits[:num_token],
+            forward_batch.seq_lens[:bs],
+        )
 
 @triton.jit
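Read together with the hunks above, the replay path now does no per-step work at all: common_template(forward_batch, None, None) only refreshes the shared index buffers, and get_num_kv_splits runs once into the buffer every step aliases. A condensed sketch of that control flow (the class and the split heuristic below are hypothetical stand-ins, not sglang code):

import torch

class MultiStepDraft:  # hypothetical condensation of TritonMultiStepDraftBackend
    def __init__(self, steps: int, max_tokens: int, max_kv_splits: int):
        self.steps = steps
        self.num_kv_splits = torch.full((max_tokens,), max_kv_splits, dtype=torch.int32)

    def common_template(self, call_fn=None):
        # ...the generate_draft_decode_kv_indices kernel launch goes here...
        if call_fn is None:
            return              # replay: refreshing the index buffers suffices
        for i in range(self.steps):
            call_fn(i)          # capture: per-step backend metadata init

    def replay(self, seq_lens: torch.Tensor, bs: int):
        self.common_template(call_fn=None)
        num_token = bs          # batch_size * topk in the real code
        # One pass stands in for get_num_kv_splits on the shared buffer.
        self.num_kv_splits[:num_token] = torch.clamp(seq_lens[:bs] // 2, min=1).to(torch.int32)

backend = MultiStepDraft(steps=4, max_tokens=8, max_kv_splits=16)
backend.replay(torch.tensor([5, 9, 3]), bs=3)  # one split computation serves all 4 steps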