Commit 5a33c3aa (Unverified)
Authored Oct 14, 2025 by Liangsheng Yin; committed by GitHub on Oct 14, 2025

Optimize Triton Draft Backend (#11556)

parent 9767a1e4
Showing 1 changed file with 51 additions and 28 deletions.

python/sglang/srt/layers/attention/triton_backend.py (+51 -28) @ 5a33c3aa
 from __future__ import annotations
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, List, Optional
 import torch
 import triton
...
@@ -12,6 +12,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
 from sglang.srt.utils import (
     get_bool_env_var,
     get_device_core_count,
...
@@ -423,6 +424,7 @@ class TritonAttnBackend(AttentionBackend):
         max_bs: int,
         max_num_tokens: int,
         kv_indices_buf: Optional[torch.Tensor] = None,
+        cuda_graph_num_kv_splits_buf: Optional[torch.Tensor] = None,
     ):
         self.cuda_graph_attn_logits = torch.zeros(
             (max_num_tokens, self.num_head, self.max_kv_splits, self.v_head_dim),
...
@@ -434,9 +436,17 @@ class TritonAttnBackend(AttentionBackend):
             dtype=torch.float32,
             device=self.device,
         )
-        self.cuda_graph_num_kv_splits = torch.full(
-            (max_num_tokens,),
-            self.max_kv_splits,
-            dtype=torch.int32,
-            device=self.device,
-        )
+        if cuda_graph_num_kv_splits_buf is None:
+            self.cuda_graph_num_kv_splits = torch.full(
+                (max_num_tokens,),
+                self.max_kv_splits,
+                dtype=torch.int32,
+                device=self.device,
+            )
+        else:
+            self.cuda_graph_num_kv_splits = cuda_graph_num_kv_splits_buf
         if kv_indices_buf is None:
             self.cuda_graph_kv_indices = torch.zeros(
                 (max_num_tokens * self.max_context_len),
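
Note on the hunk above: the new cuda_graph_num_kv_splits_buf parameter lets a caller hand init_cuda_graph_state a preallocated tensor instead of having each backend allocate its own. Below is a minimal, CPU-runnable sketch of this allocate-or-adopt pattern; the class and buffer names are illustrative stand-ins, not the sglang API.

from typing import Optional

import torch


class DraftBackendState:
    """Illustrative stand-in for a per-step attention backend (not sglang API)."""

    def __init__(self, max_kv_splits: int, device: str = "cpu"):
        self.max_kv_splits = max_kv_splits
        self.device = device

    def init_cuda_graph_state(
        self,
        max_num_tokens: int,
        num_kv_splits_buf: Optional[torch.Tensor] = None,
    ) -> None:
        if num_kv_splits_buf is None:
            # Standalone use: allocate a private buffer, defaulted to the max.
            self.num_kv_splits = torch.full(
                (max_num_tokens,),
                self.max_kv_splits,
                dtype=torch.int32,
                device=self.device,
            )
        else:
            # Multi-step use: adopt a buffer owned by the coordinating backend.
            self.num_kv_splits = num_kv_splits_buf


shared = torch.full((8,), 4, dtype=torch.int32)
steps = [DraftBackendState(max_kv_splits=4) for _ in range(3)]
for s in steps:
    s.init_cuda_graph_state(8, num_kv_splits_buf=shared)
# Every step aliases the same storage, so one write updates all of them.
assert all(s.num_kv_splits.data_ptr() == shared.data_ptr() for s in steps)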
...
@@ -683,9 +693,7 @@ class TritonAttnBackend(AttentionBackend):
                 )
             else:
-                kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
-                kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
-                num_token = spec_info.kv_indptr.shape[0] - 1
+                assert False, "Multi-step cuda graph init is not done here."
             self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs])
         elif forward_mode.is_target_verify():
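
Note on the hunk above: the branch that used to copy spec_info indices during decode metadata init is replaced by a hard assert, since TritonMultiStepDraftBackend now performs that initialization itself (see common_template further down); failing fast is preferable to silently repeating work. A tiny hypothetical sketch of the fail-fast pattern, with made-up names:

def init_decode_metadata(spec_info, handled_by_multi_step: bool):
    if not handled_by_multi_step:
        return prepare_indices(spec_info)  # hypothetical normal path
    # This path was taken over by another component; refuse to run stale logic.
    assert False, "Multi-step cuda graph init is not done here."


def prepare_indices(spec_info):
    return spec_info  # placeholder for the real index setup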
...
@@ -898,11 +906,8 @@ class TritonMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
-
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
-        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
         max_bs = model_runner.req_to_token_pool.size * self.topk
         self.kv_indptr = torch.zeros(
             (
...
@@ -912,7 +917,7 @@ class TritonMultiStepDraftBackend:
             dtype=torch.int32,
             device=model_runner.device,
         )
-        self.attn_backends = []
+        self.attn_backends: List[TritonAttnBackend] = []
         for i in range(self.speculative_num_steps):
             self.attn_backends.append(
                 TritonAttnBackend(
...
@@ -931,13 +936,19 @@ class TritonMultiStepDraftBackend:
         self.page_size = model_runner.server_args.page_size

     def common_template(
-        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
+        self,
+        forward_batch: ForwardBatch,
+        kv_indices_buffer: Optional[torch.Tensor],
+        call_fn: int,
     ):
+        if kv_indices_buffer is None:
+            kv_indices_buffer = self.cuda_graph_kv_indices
+
         num_seqs = forward_batch.batch_size
         bs = self.topk * num_seqs
         seq_lens_sum = forward_batch.seq_lens_sum
-        self.generate_draft_decode_kv_indices[
+
+        generate_draft_decode_kv_indices[
             (self.speculative_num_steps, num_seqs, self.topk)
         ](
             forward_batch.req_pool_indices,
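
Note on the hunk above: generate_draft_decode_kv_indices[(self.speculative_num_steps, num_seqs, self.topk)](...) is a Triton launch, where the subscript is the launch grid, so one kernel instance runs per (step, sequence, top-k candidate) cell. Below is a minimal sketch of the kernel[grid](args) convention, assuming a CUDA device is available; the kernel itself is illustrative, not the sglang one.

import torch
import triton
import triton.language as tl


@triton.jit
def fill_kernel(out_ptr, value, n_elements, BLOCK: tl.constexpr):
    # One program per grid cell; program_id(0) indexes the first grid axis.
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    vals = tl.zeros((BLOCK,), dtype=tl.int32) + value  # broadcast the scalar
    tl.store(out_ptr + offsets, vals, mask=mask)


x = torch.empty(1024, device="cuda", dtype=torch.int32)
# The subscript before the call is the launch grid: a 1-D grid of 4 programs.
fill_kernel[(4,)](x, 7, x.numel(), BLOCK=256)
assert bool((x == 7).all())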
...
@@ -955,6 +966,9 @@ class TritonMultiStepDraftBackend:
             self.page_size,
         )
+
+        if call_fn is None:
+            return

         for i in range(self.speculative_num_steps):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
             forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
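
Note on the hunk above: common_template can now be called with call_fn=None, in which case it returns right after launching the index-generation kernel; the replay path uses this to reuse the shared buffer work without paying for the per-step callback. A small illustrative sketch of the optional-callback pattern (names are made up):

from typing import Callable, List, Optional


def common_template(
    num_steps: int,
    call_fn: Optional[Callable[[int], None]] = None,
) -> List[str]:
    # Always do the shared work (stand-in for writing kv_indptr/kv_indices).
    buffers = [f"kv_indices[{i}]" for i in range(num_steps)]
    if call_fn is None:
        return buffers  # replay path: buffers are already wired up
    for i in range(num_steps):
        call_fn(i)  # capture path: per-step metadata initialization
    return buffers


common_template(3)                      # shared work only
common_template(3, lambda i: print(i))  # shared work + per-step callback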
...
@@ -989,9 +1003,18 @@ class TritonMultiStepDraftBackend:
             dtype=torch.int64,
             device=self.device,
         )
+        self.cuda_graph_num_kv_splits = torch.full(
+            (max_num_tokens,),
+            self.attn_backends[0].max_kv_splits,
+            dtype=torch.int32,
+            device=self.device,
+        )
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
-                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs,
+                max_num_tokens,
+                kv_indices_buf=self.cuda_graph_kv_indices[i],
+                cuda_graph_num_kv_splits_buf=self.cuda_graph_num_kv_splits,
             )

     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
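
Note on the hunk above: the multi-step backend now owns a single cuda_graph_num_kv_splits tensor (sized with the first backend's max_kv_splits) and passes it to every per-step TritonAttnBackend, which adopts it via the new cuda_graph_num_kv_splits_buf argument. A CPU-runnable sketch of why one shared buffer allows computing split counts once; the values below are made up:

import torch

# Owner side (multi-step backend): one buffer, defaulted to max_kv_splits.
shared_num_kv_splits = torch.full((8,), 16, dtype=torch.int32)
# Each per-step backend receives the same tensor via init_cuda_graph_state.
backend_buffers = [shared_num_kv_splits for _ in range(3)]

# Recomputing the splits once, through any one backend's view...
backend_buffers[-1][:4] = torch.tensor([2, 4, 4, 8], dtype=torch.int32)

# ...is immediately visible to every other step: no per-step recomputation.
assert all(int(buf[0]) == 2 for buf in backend_buffers)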
...
@@ -1006,24 +1029,24 @@ class TritonMultiStepDraftBackend:
                 spec_info=forward_batch.spec_info,
             )

-        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+        self.common_template(forward_batch, None, call_fn)

     def init_forward_metadata_replay_cuda_graph(
         self, forward_batch: ForwardBatch, bs: int
     ):
-        def call_fn(i, forward_batch):
-            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
-                bs,
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                seq_lens_sum=-1,
-                encoder_lens=None,
-                forward_mode=ForwardMode.DECODE,
-                spec_info=forward_batch.spec_info,
-                seq_lens_cpu=None,
-            )
-
-        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+        self.common_template(forward_batch, None, None)
+        # NOTE: Multi-step's attention backends use the slice of
+        # - kv_indptr buffer (cuda graph and non-cuda graph)
+        # - kv_indices buffer (cuda graph only)
+        # So we don't need to assign the KV indices inside the attention backend.
+
+        # Compute num_kv_splits only once
+        num_token = forward_batch.batch_size * self.topk
+        self.attn_backends[-1].get_num_kv_splits(
+            self.attn_backends[-1].cuda_graph_num_kv_splits[:num_token],
+            forward_batch.seq_lens[:bs],
+        )

 @triton.jit
...
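
Note on the final hunk: the replay path no longer builds a call_fn that re-initializes each step's metadata. common_template(forward_batch, None, None) fills the shared index buffers, and get_num_kv_splits runs once against the last backend's view of the shared buffer. A rough before/after sketch under assumed values (not measurements from the commit):

# Illustrative numbers only: 2 sequences, top-4 drafts, 3 speculative steps.
batch_size, topk, speculative_num_steps = 2, 4, 3

num_token = batch_size * topk  # 8 draft rows share one num_kv_splits slice

# Before: one metadata re-init (including split computation) per draft step.
split_computations_before = speculative_num_steps  # 3
# After: the shared buffer is filled once and aliased by every step.
split_computations_after = 1

print(num_token, split_computations_before, split_computations_after)  # 8 3 1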