sglang · Commits · 36a4cad7

Commit 36a4cad7 (unverified)
Authored Oct 23, 2025 by Qiaolin Yu; committed via GitHub on Oct 23, 2025
Parent: 65d376b4

Support overlap-spec-v2 with trtllm_mla attention backend (#11821)

Showing 1 changed file with 12 additions and 9 deletions:

    python/sglang/srt/layers/attention/trtllm_mla_backend.py  (+12, -9)

@@ -24,6 +24,7 @@ from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import is_cuda, is_flashinfer_available
+from sglang.srt.utils.common import cached_triton_kernel

 if is_flashinfer_available():
     import flashinfer
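
For context on the new import: cached_triton_kernel appears to memoize a Triton kernel specialization by a key computed from the launch kwargs (here BLOCK_SIZE), so repeated launches with the same configuration skip re-dispatch. A minimal sketch of that caching pattern, assuming key-based build-once semantics (illustration only, not the actual sglang.srt.utils.common implementation):

    import functools

    def cached_specialization(key_fn):
        """Build an expensive per-configuration object once per key and reuse it."""
        def decorator(build_fn):
            cache = {}

            @functools.wraps(build_fn)
            def get(*args, **kwargs):
                key = key_fn(args, kwargs)
                if key not in cache:
                    cache[key] = build_fn(*args, **kwargs)  # e.g. JIT-compile a kernel
                return cache[key]

            return get
        return decorator

    # Mirrors the diff's key function: one cached entry per BLOCK_SIZE.
    @cached_specialization(lambda _args, kwargs: kwargs["BLOCK_SIZE"])
    def build_kernel(**kwargs):
        return object()  # stand-in for a compiled kernel

    assert build_kernel(BLOCK_SIZE=128) is build_kernel(BLOCK_SIZE=128)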

@@ -50,6 +51,7 @@ DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB
 TRTLLM_BLOCK_CONSTRAINT = 128


+@cached_triton_kernel(lambda _, kwargs: (kwargs["BLOCK_SIZE"]))
 @triton.jit
 def pad_draft_extend_query_kernel(
     q_ptr,  # Input query tensor [total_seq_len, num_heads, head_dim]

@@ -123,6 +125,7 @@ def pad_draft_extend_query_kernel(
 )


+@cached_triton_kernel(lambda _, kwargs: (kwargs["BLOCK_SIZE"]))
 @triton.jit
 def unpad_draft_extend_output_kernel(
     raw_out_ptr,  # Input raw output tensor (batch_size, token_per_batch, tp_q_head_num, v_head_dim)

@@ -389,7 +392,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         if (
             not forward_mode.is_decode_or_idle()
             and not forward_mode.is_target_verify()
-            and not forward_mode.is_draft_extend()
+            and not forward_mode.is_draft_extend(include_v2=True)
         ):
             return super().init_forward_metadata_capture_cuda_graph(
                 bs,

@@ -429,7 +432,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 block_kv_indices,
                 max_seq_len_val,
             )
-            if forward_mode.is_draft_extend():
+            if forward_mode.is_draft_extend(include_v2=True):
                 num_tokens_per_bs = num_tokens // bs
                 metadata.max_seq_len_q = num_tokens_per_bs + 1
                 metadata.sum_seq_lens_q = num_tokens_per_bs * bs
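
Note the capture-time sizing: CUDA graphs are captured with every request padded to the same draft length, so num_tokens_per_bs = num_tokens // bs. For example, capturing bs = 4 with num_tokens = 8 gives num_tokens_per_bs = 2, max_seq_len_q = 3, and sum_seq_lens_q = 8; the + 1 on max_seq_len_q presumably leaves headroom for the largest possible per-request query length.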

@@ -462,7 +465,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         if (
             not forward_mode.is_decode_or_idle()
             and not forward_mode.is_target_verify()
-            and not forward_mode.is_draft_extend()
+            and not forward_mode.is_draft_extend(include_v2=True)
         ):
             return super().init_forward_metadata_replay_cuda_graph(
                 bs,

@@ -481,7 +484,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         metadata = self.decode_cuda_graph_metadata[bs]
-        if forward_mode.is_draft_extend():
+        if forward_mode.is_draft_extend(include_v2=True):
             accept_length = spec_info.accept_length[:bs]
             if spec_info.accept_length_cpu:
                 metadata.max_seq_len_q = max(spec_info.accept_length_cpu[:bs])
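
On replay, the padded capture shapes are kept, but the query-length metadata is tightened to the batch's actual accepted draft lengths: accept_length is sliced to the live batch size, and max_seq_len_q is recomputed from accept_length_cpu when it is available.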

@@ -523,7 +526,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         if (
             forward_batch.forward_mode.is_extend()
             and not forward_batch.forward_mode.is_target_verify()
-            and not forward_batch.forward_mode.is_draft_extend()
+            and not forward_batch.forward_mode.is_draft_extend(include_v2=True)
         ):
             if self.disable_chunked_prefix_cache:
                 super().init_forward_metadata(forward_batch)

@@ -544,7 +547,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         elif (
             forward_batch.forward_mode.is_decode_or_idle()
             or forward_batch.forward_mode.is_target_verify()
-            or forward_batch.forward_mode.is_draft_extend()
+            or forward_batch.forward_mode.is_draft_extend(include_v2=True)
         ):
             bs = forward_batch.batch_size

@@ -573,7 +576,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
                 block_kv_indices, max_seq_len_val
             )
-            if forward_batch.forward_mode.is_draft_extend():
+            if forward_batch.forward_mode.is_draft_extend(include_v2=True):
                 max_seq = forward_batch.seq_lens_cpu.max().item()
                 sum_seq_lens_q = sum(forward_batch.extend_seq_lens_cpu)

@@ -922,7 +925,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         if (
             forward_batch.forward_mode.is_target_verify()
-            or forward_batch.forward_mode.is_draft_extend()
+            or forward_batch.forward_mode.is_draft_extend(include_v2=True)
         ):
             metadata = (
                 getattr(forward_batch, "decode_trtllm_mla_metadata", None)

@@ -994,7 +997,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             # Reshape output directly without slicing
-            if forward_batch.forward_mode.is_draft_extend():
+            if forward_batch.forward_mode.is_draft_extend(include_v2=True):
                 raw_out = self.unpad_draft_extend_output(
                     raw_out,
                     metadata.cu_seqlens_q,
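
All nine call-site changes reduce to the same widening: overlap-spec-v2 runs its draft-extend pass under a separate v2 forward mode, and include_v2=True makes every gate above treat it like classic draft extend. A sketch of what such a predicate looks like (the real enum is sglang's ForwardMode in sglang.srt.model_executor.forward_batch_info; the DRAFT_EXTEND_V2 member name is an assumption here):

    from enum import IntEnum, auto

    class ForwardModeSketch(IntEnum):
        # Subset of sglang's ForwardMode, for illustration only.
        EXTEND = auto()
        DECODE = auto()
        IDLE = auto()
        TARGET_VERIFY = auto()
        DRAFT_EXTEND = auto()
        DRAFT_EXTEND_V2 = auto()  # assumed: draft extend under overlap-spec-v2

        def is_draft_extend(self, include_v2: bool = False) -> bool:
            # Classic draft extend always matches; the v2 mode only when opted in.
            return self == ForwardModeSketch.DRAFT_EXTEND or (
                include_v2 and self == ForwardModeSketch.DRAFT_EXTEND_V2
            )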