sglang · Commits · 5a33c3aa

Commit 5a33c3aa (unverified), authored Oct 14, 2025 by Liangsheng Yin, committed by GitHub on Oct 14, 2025

Optimize Triton Draft Backend (#11556)

Parent: 9767a1e4
Showing 1 changed file with 51 additions and 28 deletions (+51 −28):

python/sglang/srt/layers/attention/triton_backend.py
```diff
 from __future__ import annotations
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, List, Optional
 import torch
 import triton
```
```diff
@@ -12,6 +12,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
 from sglang.srt.utils import (
     get_bool_env_var,
     get_device_core_count,
```
```diff
@@ -423,6 +424,7 @@ class TritonAttnBackend(AttentionBackend):
         max_bs: int,
         max_num_tokens: int,
         kv_indices_buf: Optional[torch.Tensor] = None,
+        cuda_graph_num_kv_splits_buf: Optional[torch.Tensor] = None,
     ):
         self.cuda_graph_attn_logits = torch.zeros(
             (max_num_tokens, self.num_head, self.max_kv_splits, self.v_head_dim),
```
```diff
@@ -434,9 +436,17 @@ class TritonAttnBackend(AttentionBackend):
             dtype=torch.float32,
             device=self.device,
         )
-        self.cuda_graph_num_kv_splits = torch.full(
-            (max_num_tokens,), self.max_kv_splits, dtype=torch.int32, device=self.device
-        )
+        if cuda_graph_num_kv_splits_buf is None:
+            self.cuda_graph_num_kv_splits = torch.full(
+                (max_num_tokens,),
+                self.max_kv_splits,
+                dtype=torch.int32,
+                device=self.device,
+            )
+        else:
+            self.cuda_graph_num_kv_splits = cuda_graph_num_kv_splits_buf
+
         if kv_indices_buf is None:
             self.cuda_graph_kv_indices = torch.zeros(
                 (max_num_tokens * self.max_context_len),
```
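The pattern in this hunk is allocate-or-alias: `init_cuda_graph_state` creates a private `num_kv_splits` buffer only when the caller does not supply one, and otherwise stores a reference to the shared tensor, so every writer and the captured CUDA graph address the same storage. A minimal standalone sketch of that pattern (the names `make_splits_buf` and `BufferOwner` are illustrative, not from the codebase):

```python
from typing import Optional

import torch


def make_splits_buf(max_num_tokens: int, max_kv_splits: int, device: str) -> torch.Tensor:
    # One int32 entry per token slot, initialized to the maximum split count.
    return torch.full((max_num_tokens,), max_kv_splits, dtype=torch.int32, device=device)


class BufferOwner:
    """Illustrative stand-in for a backend's init_cuda_graph_state."""

    def __init__(
        self,
        max_num_tokens: int,
        max_kv_splits: int,
        device: str,
        shared_buf: Optional[torch.Tensor] = None,
    ):
        if shared_buf is None:
            self.num_kv_splits = make_splits_buf(max_num_tokens, max_kv_splits, device)
        else:
            # Alias, don't copy: a CUDA graph replays against fixed addresses,
            # so sharing one tensor lets a single update serve every owner.
            self.num_kv_splits = shared_buf


shared = make_splits_buf(8, 4, "cpu")
a = BufferOwner(8, 4, "cpu", shared_buf=shared)
b = BufferOwner(8, 4, "cpu", shared_buf=shared)
assert a.num_kv_splits.data_ptr() == b.num_kv_splits.data_ptr()
```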
```diff
@@ -683,9 +693,7 @@ class TritonAttnBackend(AttentionBackend):
                 )
             else:
-                kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
-                kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
-                num_token = spec_info.kv_indptr.shape[0] - 1
+                assert False, "Multi-step cuda graph init is not done here."

             self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs])
         elif forward_mode.is_target_verify():
```
```diff
@@ -898,11 +906,8 @@ class TritonMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
-
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
-        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
         max_bs = model_runner.req_to_token_pool.size * self.topk
         self.kv_indptr = torch.zeros(
             (
```
```diff
@@ -912,7 +917,7 @@ class TritonMultiStepDraftBackend:
             dtype=torch.int32,
             device=model_runner.device,
         )
-        self.attn_backends = []
+        self.attn_backends: List[TritonAttnBackend] = []
         for i in range(self.speculative_num_steps):
             self.attn_backends.append(
                 TritonAttnBackend(
```
```diff
@@ -931,13 +936,19 @@ class TritonMultiStepDraftBackend:
         self.page_size = model_runner.server_args.page_size

     def common_template(
-        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
+        self,
+        forward_batch: ForwardBatch,
+        kv_indices_buffer: Optional[torch.Tensor],
+        call_fn: int,
     ):
+        if kv_indices_buffer is None:
+            kv_indices_buffer = self.cuda_graph_kv_indices
+
         num_seqs = forward_batch.batch_size
         bs = self.topk * num_seqs
         seq_lens_sum = forward_batch.seq_lens_sum
-        self.generate_draft_decode_kv_indices[
+        generate_draft_decode_kv_indices[
             (self.speculative_num_steps, num_seqs, self.topk)
         ](
             forward_batch.req_pool_indices,
```
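For context, `generate_draft_decode_kv_indices` is a Triton kernel launched over a 3D grid `(speculative_num_steps, num_seqs, topk)`, so each program instance fills the KV indices for one (step, sequence, branch) cell. A toy sketch of that launch pattern (illustrative kernel, requires a GPU):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def fill_cells(out_ptr, n_seqs, topk):
    # One program per (step, seq, branch) cell of the 3D launch grid.
    step = tl.program_id(0)
    seq = tl.program_id(1)
    branch = tl.program_id(2)
    idx = step * n_seqs * topk + seq * topk + branch
    tl.store(out_ptr + idx, idx)


steps, n_seqs, topk = 2, 3, 4
out = torch.empty(steps * n_seqs * topk, dtype=torch.int32, device="cuda")
fill_cells[(steps, n_seqs, topk)](out, n_seqs, topk)
assert torch.equal(out.cpu(), torch.arange(steps * n_seqs * topk, dtype=torch.int32))
```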
```diff
@@ -955,6 +966,9 @@ class TritonMultiStepDraftBackend:
             self.page_size,
         )

+        if call_fn is None:
+            return
+
         for i in range(self.speculative_num_steps):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
             forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
```
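With this early return, `common_template` serves two callers: the replay path passes `call_fn=None` to refill the index buffers and stop, while other paths pass a callback that runs once per draft step. A toy version of that optional-callback shape (names illustrative):

```python
from typing import Callable, Optional


def common_template(steps: int, call_fn: Optional[Callable[[int], None]] = None) -> None:
    for i in range(steps):
        pass  # stand-in for filling kv_indptr/kv_indices for step i

    if call_fn is None:
        return  # replay path: buffers refilled, per-step work already captured

    for i in range(steps):
        call_fn(i)  # capture path: per-step backend work


common_template(3)                              # indices only (replay)
common_template(3, lambda i: print("step", i))  # indices + per-step init (capture)
```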
```diff
@@ -989,9 +1003,18 @@ class TritonMultiStepDraftBackend:
             dtype=torch.int64,
             device=self.device,
         )
+        self.cuda_graph_num_kv_splits = torch.full(
+            (max_num_tokens,),
+            self.attn_backends[0].max_kv_splits,
+            dtype=torch.int32,
+            device=self.device,
+        )
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
-                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs,
+                max_num_tokens,
+                kv_indices_buf=self.cuda_graph_kv_indices[i],
+                cuda_graph_num_kv_splits_buf=self.cuda_graph_num_kv_splits,
             )

     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
```
@@ -1006,24 +1029,24 @@ class TritonMultiStepDraftBackend:
...
@@ -1006,24 +1029,24 @@ class TritonMultiStepDraftBackend:
spec_info
=
forward_batch
.
spec_info
,
spec_info
=
forward_batch
.
spec_info
,
)
)
self
.
common_template
(
forward_batch
,
self
.
cuda_graph_kv_indices
,
call_fn
)
self
.
common_template
(
forward_batch
,
None
,
call_fn
)
def
init_forward_metadata_replay_cuda_graph
(
def
init_forward_metadata_replay_cuda_graph
(
self
,
forward_batch
:
ForwardBatch
,
bs
:
int
self
,
forward_batch
:
ForwardBatch
,
bs
:
int
):
):
def
call_fn
(
i
,
forward_batch
):
self
.
common_template
(
forward_batch
,
None
,
None
)
self
.
attn_backends
[
i
].
init_forward_metadata_replay_cuda_graph
(
bs
,
# NOTE: Multi-step's attention backends use the slice of
forward_batch
.
req_pool_indices
,
# - kv_indptr buffer (cuda graph and non-cuda graph)
forward_batch
.
seq_lens
,
# - kv_indices buffer (cuda graph only)
seq_lens_sum
=-
1
,
# So we don't need to assign the KV indices inside the attention backend.
encoder_lens
=
None
,
forward_mode
=
ForwardMode
.
DECODE
,
spec_info
=
forward_batch
.
spec_info
,
seq_lens_cpu
=
None
,
)
self
.
common_template
(
forward_batch
,
self
.
cuda_graph_kv_indices
,
call_fn
)
# Compute num_kv_splits only once
num_token
=
forward_batch
.
batch_size
*
self
.
topk
self
.
attn_backends
[
-
1
].
get_num_kv_splits
(
self
.
attn_backends
[
-
1
].
cuda_graph_num_kv_splits
[:
num_token
],
forward_batch
.
seq_lens
[:
bs
],
)
@
triton
.
jit
@
triton
.
jit
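The NOTE in the new replay path leans on PyTorch slicing semantics: slices like `self.kv_indptr[i, : bs + 1]` handed to each step's backend are views of the parent buffers, so refilling the parents inside `common_template` is enough and no per-backend reassignment is needed. A quick sketch (hypothetical shapes):

```python
import torch

num_steps, max_bs = 3, 8
kv_indptr = torch.zeros((num_steps, max_bs + 1), dtype=torch.int32)

# What each step's backend holds after capture: views, not copies.
step_views = [kv_indptr[i, : max_bs + 1] for i in range(num_steps)]

# Refill the parent buffer before replay ...
kv_indptr[1, :4] = torch.tensor([0, 2, 5, 9], dtype=torch.int32)
# ... and the captured view observes the new values without reassignment.
assert torch.equal(step_views[1][:4], torch.tensor([0, 2, 5, 9], dtype=torch.int32))
```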