change / sglang / Commits / d5b6e50f

perf: trtllm mla performance minor improvements (#12435)

Unverified commit d5b6e50f, authored Oct 31, 2025 by yinghui, committed by GitHub on Oct 31, 2025.
Parent: 9632e48f

Showing 1 changed file with 70 additions and 65 deletions:

python/sglang/srt/layers/attention/trtllm_mla_backend.py (+70, -65)
@@ -219,6 +219,7 @@ class TRTLLMMLADecodeMetadata:
    sum_seq_lens_q: Optional[int] = None
    cu_seqlens_q: Optional[torch.Tensor] = None
    seq_lens_q: Optional[torch.Tensor] = None
    seq_lens_k: Optional[torch.Tensor] = None


class TRTLLMMLABackend(FlashInferMLAAttnBackend):
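The hunk above extends the decode-metadata container with seq_lens_k alongside the query-side fields, all optional so each forward mode fills in only what it needs. Below is a minimal standalone sketch of that dataclass-style pattern; the class name and the exact field set are assumptions for illustration, not the real TRTLLMMLADecodeMetadata definition.

    from dataclasses import dataclass
    from typing import Optional

    import torch


    @dataclass
    class DecodeMetadataSketch:
        # Filled at capture/init time (names mirror attributes seen in the diff).
        block_kv_indices: Optional[torch.Tensor] = None
        max_seq_len_k: Optional[int] = None
        # Query-side bookkeeping, only used by target-verify / draft-extend modes.
        max_seq_len_q: Optional[int] = None
        sum_seq_lens_q: Optional[int] = None
        cu_seqlens_q: Optional[torch.Tensor] = None
        seq_lens_q: Optional[torch.Tensor] = None
        seq_lens_k: Optional[torch.Tensor] = None


    meta = DecodeMetadataSketch()
    meta.seq_lens_k = torch.zeros((4,), dtype=torch.int32)  # populated lazily per mode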
@@ -404,8 +405,38 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                spec_info,
            )
            metadata = TRTLLMMLADecodeMetadata()
            if forward_mode.is_target_verify():
                seq_lens = seq_lens + self.num_draft_tokens
                metadata.seq_lens_k = torch.zeros(
                    (bs,), dtype=torch.int32, device=seq_lens.device
                )
                metadata.seq_lens_k.copy_(seq_lens.to(dtype=torch.int32))
            elif forward_mode.is_draft_extend(include_v2=True):
                num_tokens_per_bs = num_tokens // bs
                metadata.max_seq_len_q = num_tokens_per_bs
                metadata.sum_seq_lens_q = num_tokens_per_bs * bs
                metadata.cu_seqlens_q = torch.arange(
                    0,
                    bs * num_tokens_per_bs + 1,
                    num_tokens_per_bs,
                    dtype=torch.int32,
                    device=seq_lens.device,
                )
                metadata.seq_lens_q = torch.full(
                    (bs,), num_tokens_per_bs, dtype=torch.int32, device=seq_lens.device
                )
                # NOTE(draft_extend seq_len handling):
                # forward_batch.seq_lens is the seq_lens of the prev_context + verified tokens.
                # To account for pad_draft_extend_query, we need seq_lens = prev_context + max_draft_tokens.
                # This will ensure queries align with kvs correctly when calling
                # flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla.
                seq_lens = seq_lens - metadata.seq_lens_q + metadata.max_seq_len_q
                metadata.seq_lens_k = torch.zeros(
                    (bs,), dtype=torch.int32, device=seq_lens.device
                )
                metadata.seq_lens_k.copy_(seq_lens.to(dtype=torch.int32))

            # Custom fast-path for decode/idle.
            # Capture with full width so future longer sequences are safe during replay
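A hedged toy walk-through of the NOTE above; all numbers are invented. At capture time seq_lens_q is padded to num_tokens_per_bs, so the rebasing is a no-op there, but the same formula is reused at replay (later hunks) where seq_lens_q holds the real per-request query lengths. The torch.arange call shows the resulting cu_seqlens_q layout for a uniformly padded batch.

    import torch

    bs, num_tokens_per_bs = 3, 4                                # toy values, not from the commit
    seq_lens = torch.tensor([10, 15, 12], dtype=torch.int32)    # prev_context + verified tokens
    seq_lens_q = torch.tensor([2, 4, 3], dtype=torch.int32)     # real query tokens per request
    max_seq_len_q = num_tokens_per_bs                           # padded query length

    # Rebase KV lengths to prev_context + padded query length.
    seq_lens_k = seq_lens - seq_lens_q + max_seq_len_q
    print(seq_lens_k)      # tensor([12, 15, 13], dtype=torch.int32)

    # Cumulative query offsets for a uniformly padded batch.
    cu_seqlens_q = torch.arange(
        0, bs * num_tokens_per_bs + 1, num_tokens_per_bs, dtype=torch.int32
    )
    print(cu_seqlens_q)    # tensor([ 0,  4,  8, 12], dtype=torch.int32)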
@@ -423,24 +454,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                PAGED_SIZE=self.page_size,
            )
            metadata = TRTLLMMLADecodeMetadata(
                block_kv_indices,
                self.max_context_len,
            )
            if forward_mode.is_draft_extend(include_v2=True):
                num_tokens_per_bs = num_tokens // bs
                metadata.max_seq_len_q = num_tokens_per_bs + 1
                metadata.sum_seq_lens_q = num_tokens_per_bs * bs
                metadata.cu_seqlens_q = torch.arange(
                    0,
                    bs * num_tokens_per_bs + 1,
                    num_tokens_per_bs,
                    dtype=torch.int32,
                    device=seq_lens.device,
                )
                metadata.seq_lens_q = torch.full(
                    (bs,), num_tokens_per_bs, dtype=torch.int32, device=seq_lens.device
                )
            metadata.block_kv_indices = block_kv_indices
            metadata.max_seq_len_k = self.max_context_len
            self.decode_cuda_graph_metadata[bs] = metadata
            self.forward_decode_metadata = metadata
@@ -473,17 +489,17 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
            seq_lens_cpu,
        )
        if forward_mode.is_target_verify():
            seq_lens = seq_lens + self.num_draft_tokens
        del seq_lens_sum  # not handle "num_draft_tokens" but we do not need it
        metadata = self.decode_cuda_graph_metadata[bs]
        if forward_mode.is_draft_extend(include_v2=True):
        if forward_mode.is_target_verify():
            seq_lens = seq_lens[:bs] + self.num_draft_tokens
            metadata.seq_lens_k.copy_(seq_lens.to(dtype=torch.int32))
        del seq_lens_sum  # not handle "num_draft_tokens" but we do not need it
        elif forward_mode.is_draft_extend(include_v2=True):
            accept_length = spec_info.accept_length[:bs]
            if spec_info.accept_length_cpu:
                metadata.max_seq_len_q = max(spec_info.accept_length_cpu[:bs])
                metadata.sum_seq_lens_q = sum(spec_info.accept_length_cpu[:bs])
                metadata.max_seq_len_q = max(spec_info.accept_length_cpu[:bs]) + 1
                metadata.sum_seq_lens_q = sum(spec_info.accept_length_cpu[:bs]) + bs
            else:
                metadata.max_seq_len_q = 1
                metadata.sum_seq_lens_q = bs
@@ -491,12 +507,15 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                torch.cumsum(accept_length, dim=0, dtype=torch.int32)
            )
            metadata.seq_lens_q.copy_(accept_length)
            # see NOTE(draft_extend seq_len handling)
            seq_lens = seq_lens[:bs] - metadata.seq_lens_q + metadata.max_seq_len_q
            metadata.seq_lens_k.copy_(seq_lens.to(torch.int32))

        # Update block indices for new sequences.
        create_flashmla_kv_indices_triton[(bs,)](
            self.req_to_token,
            req_pool_indices[:bs],
            seq_lens[:bs],
            seq_lens,
            None,
            metadata.block_kv_indices,
            self.req_to_token.stride(0),
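Both this replay path and the target-verify branch above refresh the capture-time buffers with copy_ rather than reassignment. Here is a small, self-contained sketch (not backend code) of why in-place copy_ matters when a CUDA graph replays reads from fixed device addresses:

    import torch

    captured = torch.zeros((4,), dtype=torch.int32)   # buffer allocated at capture time
    ptr_before = captured.data_ptr()

    new_vals = torch.tensor([12, 15, 13, 9], dtype=torch.int32)
    captured.copy_(new_vals)                          # same storage, new contents
    assert captured.data_ptr() == ptr_before

    captured = new_vals                               # plain rebinding points at new storage;
    assert captured.data_ptr() != ptr_before          # a replayed graph would keep reading the old buffer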
@@ -538,7 +557,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
            or forward_batch.forward_mode.is_draft_extend(include_v2=True)
        ):
            bs = forward_batch.batch_size
            self.forward_decode_metadata = TRTLLMMLADecodeMetadata()

            # Get maximum sequence length.
            if getattr(forward_batch, "seq_lens_cpu", None) is not None:
                max_seq = forward_batch.seq_lens_cpu.max().item()
@@ -550,21 +569,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
            if forward_batch.forward_mode.is_target_verify():
                max_seq = max_seq + self.num_draft_tokens
                seq_lens = seq_lens + self.num_draft_tokens

            max_seqlen_pad = self._calc_padded_blocks(max_seq)
            block_kv_indices = self._create_block_kv_indices(
                bs,
                max_seqlen_pad,
                forward_batch.req_pool_indices,
                seq_lens,
                seq_lens.device,
            )

            max_seq_len_val = int(max_seq)
            self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
                block_kv_indices, max_seq_len_val
            )
            if forward_batch.forward_mode.is_draft_extend(include_v2=True):
                self.forward_decode_metadata.seq_lens_k = seq_lens
        elif forward_batch.forward_mode.is_draft_extend(include_v2=True):
            max_seq = forward_batch.seq_lens_cpu.max().item()
            sum_seq_lens_q = sum(forward_batch.extend_seq_lens_cpu)
@@ -575,11 +581,26 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                ),
                (1, 0),
            )
            # see NOTE(draft_extend seq_len handling)
            seq_lens = seq_lens - forward_batch.extend_seq_lens + max_seq_len_q
            self.forward_decode_metadata.max_seq_len_q = max_seq_len_q
            self.forward_decode_metadata.sum_seq_lens_q = sum_seq_lens_q
            self.forward_decode_metadata.cu_seqlens_q = cu_seqlens_q
            self.forward_decode_metadata.seq_lens_q = forward_batch.extend_seq_lens
            self.forward_decode_metadata.seq_lens_k = seq_lens
            max_seqlen_pad = self._calc_padded_blocks(max_seq)
            block_kv_indices = self._create_block_kv_indices(
                bs,
                max_seqlen_pad,
                forward_batch.req_pool_indices,
                seq_lens,
                seq_lens.device,
            )
            self.forward_decode_metadata.block_kv_indices = block_kv_indices
            self.forward_decode_metadata.max_seq_len_k = int(max_seq)

            forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
        else:
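The cu_seqlens_q construction that ends in the (1, 0) pad above is the usual left-pad-a-cumsum idiom for cumulative query offsets. A hedged toy example with made-up extend lengths:

    import torch
    import torch.nn.functional as F

    extend_seq_lens = torch.tensor([2, 4, 3], dtype=torch.int32)   # made-up per-request query lengths
    cu_seqlens_q = F.pad(torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0))
    print(cu_seqlens_q)   # tensor([0, 2, 6, 9], dtype=torch.int32)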
@@ -899,18 +920,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
            )
            q = _concat_mla_absorb_q_general(q_nope, q_rope_reshaped)
        else:
            # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
            q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
        q = q.view(-1, layer.tp_q_head_num, layer.head_dim)

        if k_rope is not None:
            k = torch.cat([k, k_rope], dim=-1)
        k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
        v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)

        if (
            forward_batch.forward_mode.is_target_verify()
            or forward_batch.forward_mode.is_draft_extend(include_v2=True)
@@ -936,23 +948,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
            bmm1_scale = q_scale * k_scale * layer.scaling
            if forward_batch.forward_mode.is_target_verify():
                seq_lens = (
                    forward_batch.seq_lens.to(torch.int32)
                    + forward_batch.spec_info.draft_token_num
                )
                max_seq_len = (
                    metadata.max_seq_len_k + forward_batch.spec_info.draft_token_num
                )
            else:
                # forward_batch.seq_lens is the seq_lens of the prev_context + verified tokens.
                # To account for pad_draft_extend_query, we need seq_lens = prev_context + max_draft_tokens.
                # This will ensure queries align with kvs correctly when calling
                # flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla.
                seq_lens = (
                    forward_batch.seq_lens - metadata.seq_lens_q + metadata.max_seq_len_q
                ).to(torch.int32)
                max_seq_len = metadata.max_seq_len_k + metadata.max_seq_len_q

            # Check if we're in CUDA graph mode (buffers are pre-allocated)
            if self.padded_q_buffer is not None:
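A hedged toy walk-through of the two branches above (numbers invented; max_seq_len_k is assumed to be 20 here): target-verify grows every KV length by the full draft_token_num, while the draft-extend path rebases each length to prev_context plus the padded query length.

    import torch

    seq_lens_prev = torch.tensor([10, 15], dtype=torch.int32)    # prev_context + verified tokens
    max_seq_len_k = 20                                           # assumed stand-in for metadata.max_seq_len_k

    # target_verify: every request's KV side grows by draft_token_num.
    draft_token_num = 3
    seq_lens_tv = seq_lens_prev + draft_token_num                # tensor([13, 18], dtype=torch.int32)
    max_seq_len_tv = max_seq_len_k + draft_token_num             # 23

    # draft_extend: rebase to prev_context + padded query length (max_seq_len_q).
    seq_lens_q = torch.tensor([2, 4], dtype=torch.int32)         # accepted tokens per request
    max_seq_len_q = 5
    seq_lens_de = (seq_lens_prev - seq_lens_q + max_seq_len_q).to(torch.int32)  # tensor([13, 16])
    max_seq_len_de = max_seq_len_k + max_seq_len_q               # 25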
@@ -986,7 +985,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                kv_lora_rank=self.kv_lora_rank,
                qk_rope_head_dim=self.qk_rope_head_dim,
                block_tables=metadata.block_kv_indices,
                seq_lens=seq_lens,
                seq_lens=metadata.seq_lens_k,
                max_seq_len=max_seq_len,
                bmm1_scale=bmm1_scale,
            )
@@ -1003,6 +1002,12 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
            output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
            return output

        if k_rope is not None:
            k = torch.cat([k, k_rope], dim=-1)
        k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
        v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)

        if forward_batch.attn_attend_prefix_cache:
            # MHA for chunked prefix kv cache when running model with MLA
            assert forward_batch.prefix_chunk_idx is not None
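A small shape sketch of the reshapes above, using toy sizes (the head count and dims here are assumptions for illustration, not the model's real values):

    import torch

    num_tokens, tp_k_head_num, head_dim, v_head_dim = 6, 2, 576, 512   # assumed toy sizes
    k = torch.randn(num_tokens, tp_k_head_num * head_dim)
    v = torch.randn(num_tokens, tp_k_head_num * v_head_dim)
    k = k.view(-1, tp_k_head_num, head_dim)     # -> torch.Size([6, 2, 576])
    v = v.view(-1, tp_k_head_num, v_head_dim)   # -> torch.Size([6, 2, 512])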