change / sglang · Commits · bebd0576

Commit bebd0576 (Unverified)
Authored Sep 05, 2025 by Elfie Guo; committed by GitHub on Sep 05, 2025
Parent: f9836660

Integrate trtllm ragged attention for prefill self-attention (#9801)

Showing 4 changed files with 298 additions and 42 deletions (+298, -42):

  python/sglang/srt/layers/attention/flashinfer_mla_backend.py   +16  -12
  python/sglang/srt/layers/attention/trtllm_mla_backend.py      +104  -24
  python/sglang/srt/models/deepseek_v2.py                          +9   -1
  python/sglang/test/attention/test_trtllm_mla_backend.py        +169   -5
python/sglang/srt/layers/attention/flashinfer_mla_backend.py (+16, -12)

@@ -96,6 +96,7 @@ class FlashInferMhaChunkKVRunner:
     def update_wrapper(
         self,
         forward_batch: ForwardBatch,
+        disable_flashinfer_ragged: bool = False,
     ):
         assert forward_batch.num_prefix_chunks is not None
         num_prefix_chunks = forward_batch.num_prefix_chunks

@@ -128,16 +129,17 @@ class FlashInferMhaChunkKVRunner:
                 causal=False,
             )
         # ragged prefill
-        self.ragged_wrapper.begin_forward(
-            qo_indptr=qo_indptr,
-            kv_indptr=qo_indptr,
-            num_qo_heads=self.num_local_heads,
-            num_kv_heads=self.num_local_heads,
-            head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
-            head_dim_vo=self.v_head_dim,
-            q_data_type=self.q_data_type,
-            causal=True,
-        )
+        if not disable_flashinfer_ragged:
+            self.ragged_wrapper.begin_forward(
+                qo_indptr=qo_indptr,
+                kv_indptr=qo_indptr,
+                num_qo_heads=self.num_local_heads,
+                num_kv_heads=self.num_local_heads,
+                head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
+                head_dim_vo=self.v_head_dim,
+                q_data_type=self.q_data_type,
+                causal=True,
+            )

     def forward(
         self,

@@ -491,9 +493,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
     def get_cuda_graph_seq_len_fill_value(self):
         return 1

-    def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
+    def init_mha_chunk_metadata(
+        self, forward_batch: ForwardBatch, disable_flashinfer_ragged: bool = False
+    ):
         """Init the metadata for a forward pass."""
-        self.mha_chunk_kv_cache.update_wrapper(forward_batch)
+        self.mha_chunk_kv_cache.update_wrapper(forward_batch, disable_flashinfer_ragged)

     def forward_extend(
         self,
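A quick way to see what the new disable_flashinfer_ragged flag does: it only skips planning FlashInfer's ragged prefill wrapper, so a subclass that brings its own ragged kernel (the TRT-LLM MLA backend below) can still reuse the chunked-KV planning. The toy classes below are illustrative stand-ins, not the real sglang types:

    class _ToyRaggedWrapper:
        def begin_forward(self, **kwargs):
            print("planned FlashInfer ragged prefill")

    class _ToyChunkKVRunner:
        """Stand-in for how FlashInferMhaChunkKVRunner handles the flag."""

        def __init__(self):
            self.ragged_wrapper = _ToyRaggedWrapper()

        def update_wrapper(self, forward_batch, disable_flashinfer_ragged=False):
            # ... the chunked prefix-KV wrappers would still be planned here ...
            if not disable_flashinfer_ragged:
                self.ragged_wrapper.begin_forward(causal=True)

    runner = _ToyChunkKVRunner()
    runner.update_wrapper(forward_batch=None)                                  # plans ragged prefill
    runner.update_wrapper(forward_batch=None, disable_flashinfer_ragged=True)  # skips it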
python/sglang/srt/layers/attention/trtllm_mla_backend.py (+104, -24)

@@ -45,6 +45,15 @@ TRTLLM_BLOCK_CONSTRAINT = 128
 global_zero_init_workspace_buffer = None


+@dataclass
+class TRTLLMMLAPrefillMetadata:
+    """Metadata for TRTLLM MLA prefill operations."""
+
+    max_seq_len: int
+    cum_seq_lens: torch.Tensor
+    seq_lens: torch.Tensor
+
+
 @dataclass
 class TRTLLMMLADecodeMetadata:
     """Metadata for TRTLLM MLA decode operations."""
@@ -101,7 +110,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
         self.decode_cuda_graph_kv_indices = None
-        self.forward_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
+        self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
+        self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None

     def _calc_padded_blocks(self, max_seq_len: int) -> int:
         """

@@ -235,7 +245,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             max_seq_len_val,
         )
         self.decode_cuda_graph_metadata[bs] = metadata
-        self.forward_metadata = metadata
+        self.forward_decode_metadata = metadata

     def init_forward_metadata_replay_cuda_graph(
         self,
@@ -291,31 +301,52 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize the metadata for a forward pass."""
-        # Delegate to parent for non-decode modes.
-        if not forward_batch.forward_mode.is_decode_or_idle():
-            return super().init_forward_metadata(forward_batch)
-
-        bs = forward_batch.batch_size
-
-        # Get maximum sequence length.
-        if getattr(forward_batch, "seq_lens_cpu", None) is not None:
-            max_seq = forward_batch.seq_lens_cpu.max().item()
-        else:
-            max_seq = forward_batch.seq_lens.max().item()
-
-        max_seqlen_pad = self._calc_padded_blocks(max_seq)
-        block_kv_indices = self._create_block_kv_indices(
-            bs,
-            max_seqlen_pad,
-            forward_batch.req_pool_indices,
-            forward_batch.seq_lens,
-            forward_batch.seq_lens.device,
-        )
-
-        max_seq_len_val = int(max_seq)
-        self.forward_metadata = TRTLLMMLADecodeMetadata(
-            self.workspace_buffer, block_kv_indices, max_seq_len_val
-        )
-        forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
+        if (
+            forward_batch.forward_mode.is_extend()
+            and not forward_batch.forward_mode.is_target_verify()
+            and not forward_batch.forward_mode.is_draft_extend()
+        ):
+            seq_lens = forward_batch.seq_lens - forward_batch.extend_prefix_lens
+            cum_seq_lens_q = torch.cat(
+                (
+                    torch.tensor([0], device=forward_batch.seq_lens.device),
+                    torch.cumsum(seq_lens, dim=0),
+                )
+            ).int()
+            max_seq_len = max(forward_batch.extend_seq_lens_cpu)
+            self.forward_prefill_metadata = TRTLLMMLAPrefillMetadata(
+                max_seq_len,
+                cum_seq_lens_q,
+                seq_lens,
+            )
+        elif forward_batch.forward_mode.is_decode_or_idle():
+            bs = forward_batch.batch_size
+
+            # Get maximum sequence length.
+            if getattr(forward_batch, "seq_lens_cpu", None) is not None:
+                max_seq = forward_batch.seq_lens_cpu.max().item()
+            else:
+                max_seq = forward_batch.seq_lens.max().item()
+
+            max_seqlen_pad = self._calc_padded_blocks(max_seq)
+            block_kv_indices = self._create_block_kv_indices(
+                bs,
+                max_seqlen_pad,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens.device,
+            )
+
+            max_seq_len_val = int(max_seq)
+            self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
+                self.workspace_buffer, block_kv_indices, max_seq_len_val
+            )
+            forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
+        else:
+            return super().init_forward_metadata(forward_batch)
+
+    def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
+        super().init_mha_chunk_metadata(forward_batch, disable_flashinfer_ragged=True)

     def quantize_and_rope_for_fp8(
         self,
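For reference, the prefill metadata built above boils down to a prefix sum over the per-request extend lengths. A standalone sketch with made-up sequence lengths (values are illustrative only):

    import torch

    # Three extend requests with no cached prefix (illustrative numbers).
    seq_lens = torch.tensor([5, 3, 7])
    extend_prefix_lens = torch.tensor([0, 0, 0])

    new_tokens = seq_lens - extend_prefix_lens
    cum_seq_lens_q = torch.cat(
        (torch.tensor([0]), torch.cumsum(new_tokens, dim=0))
    ).int()

    print(cum_seq_lens_q)         # tensor([ 0,  5,  8, 15], dtype=torch.int32)
    print(int(new_tokens.max()))  # 7 -> the max_seq_len stored in TRTLLMMLAPrefillMetadata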
@@ -459,7 +490,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         # Get metadata
         metadata = (
             getattr(forward_batch, "decode_trtllm_mla_metadata", None)
-            or self.forward_metadata
+            or self.forward_decode_metadata
         )

         # Scale computation for TRTLLM MLA kernel BMM1 operation:
@@ -496,6 +527,55 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
         return output

+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if (
+            forward_batch.forward_mode.is_target_verify()
+            or forward_batch.forward_mode.is_draft_extend()
+        ):
+            return super().forward_extend(
+                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+            )
+
+        if not forward_batch.attn_attend_prefix_cache:
+            q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
+            v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
+            output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
+                query=q,
+                key=k,
+                value=v,
+                workspace_buffer=self.workspace_buffer,
+                seq_lens=self.forward_prefill_metadata.seq_lens,
+                max_q_len=self.forward_prefill_metadata.max_seq_len,
+                max_kv_len=self.forward_prefill_metadata.max_seq_len,
+                bmm1_scale=layer.scaling,
+                bmm2_scale=1.0,
+                o_sf_scale=1.0,
+                batch_size=forward_batch.batch_size,
+                window_left=-1,
+                cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
+                cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
+                enable_pdl=False,
+                is_causal=True,
+                return_lse=forward_batch.mha_return_lse,
+            )
+        else:
+            # replace with trtllm ragged attention once accuracy is resolved.
+            output = super().forward_extend(
+                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+            )
+        return output
+

 class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
     """Multi-step draft backend for TRT-LLM MLA used by EAGLE."""
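The ragged path above is only taken for the self-attention part of a chunked prefill (attn_attend_prefix_cache is False); the prefix-cache portion still falls back to the parent FlashInfer path, per the in-code comment. A rough shape sketch of what that kernel consumes, using the head dimensions from the test config added below (illustrative sizes; the kernel call itself is omitted):

    import torch

    batch_size, max_seq_len = 2, 32          # illustrative
    total_tokens = batch_size * max_seq_len
    tp_q_head_num, tp_k_head_num = 128, 128  # from DEFAULT_CONFIG below
    head_dim, v_head_dim = 192, 128          # prefill_head_dim / prefill_v_head_dim below

    q = torch.randn(total_tokens, tp_q_head_num, head_dim)
    k = torch.randn(total_tokens, tp_k_head_num, head_dim)
    v = torch.randn(total_tokens, tp_k_head_num, v_head_dim)
    # trtllm_ragged_attention_deepseek takes these together with the
    # seq_lens / cum_seq_lens from TRTLLMMLAPrefillMetadata; the caller then
    # reshapes the result to (total_tokens, tp_q_head_num * v_head_dim).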
python/sglang/srt/models/deepseek_v2.py (+9, -1)

@@ -1050,7 +1050,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             attention_backend == "flashinfer"
             or attention_backend == "fa3"
             or attention_backend == "flashmla"
-            or attention_backend == "trtllm_mla"
             or attention_backend == "cutlass_mla"
         ):
             # Use MHA with chunked KV cache when prefilling on long sequences.

@@ -1079,6 +1078,15 @@ class DeepseekV2AttentionMLA(nn.Module):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
                 return _dispatch_mla_subtype()
+        elif attention_backend == "trtllm_mla":
+            if (
+                forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+            ):
+                return AttnForwardMethod.MHA_CHUNKED_KV
+            else:
+                return _dispatch_mla_subtype()
         elif attention_backend == "aiter":
             if (
                 forward_batch.forward_mode.is_extend()
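Put differently, with the trtllm_mla backend a plain extend batch is now routed to MHA_CHUNKED_KV (which reaches the ragged prefill kernel above), while target-verify and draft-extend batches keep the existing MLA subtype dispatch. A condensed, illustrative restatement of that branch (not the real function):

    def _trtllm_mla_forward_method(forward_mode):
        # Illustrative stand-in for the dispatch added above.
        if (
            forward_mode.is_extend()
            and not forward_mode.is_target_verify()
            and not forward_mode.is_draft_extend()
        ):
            return "MHA_CHUNKED_KV"
        return "dispatch MLA subtype"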
python/sglang/test/attention/test_trtllm_mla_backend.py (+169, -5)

@@ -41,6 +41,10 @@ DEFAULT_CONFIG = {
     "v_head_dim": 512,
     "num_kv_heads": 1,
     "layer_id": 0,
+    "tp_q_head_num": 128,
+    "tp_k_head_num": 128,
+    "prefill_head_dim": 192,
+    "prefill_v_head_dim": 128,
 }

 ROPE_BASE = 10000
@@ -92,7 +96,7 @@ TEST_CASES = {
             "description": "Medium-scale batch",
         },
     ],
-    "decode_output_match": [
+    "output_match": [
         {
             "name": "single_fp16",
             "batch_size": 1,
@@ -322,7 +326,7 @@ class TestTRTLLMMLA(CustomTestCase):
         config.update(test_case)
         return config

-    def _create_model_components(self, config):
+    def _create_model_components(self, config, is_prefill=False):
         """Create model runners, backends, and layer for testing."""
         # Create model runners
         model_runner_trtllm = MockModelRunner(config)
@@ -332,14 +336,23 @@ class TestTRTLLMMLA(CustomTestCase):
         trtllm_backend = TRTLLMMLABackend(model_runner_trtllm)
         reference_backend = FlashInferMLAAttnBackend(model_runner_reference)

+        head_dim = (
+            config["kv_lora_rank"] + config["qk_rope_head_dim"]
+            if not is_prefill
+            else config["prefill_head_dim"]
+        )
+        v_head_dim = (
+            config["v_head_dim"] if not is_prefill else config["prefill_v_head_dim"]
+        )
+
         # Create RadixAttention layer
         layer = RadixAttention(
             num_heads=config["num_attention_heads"],
-            head_dim=config["kv_lora_rank"] + config["qk_rope_head_dim"],
+            head_dim=head_dim,
             scaling=model_runner_trtllm.model_config.scaling,
             num_kv_heads=config["num_kv_heads"],
             layer_id=config["layer_id"],
-            v_head_dim=config["v_head_dim"],
+            v_head_dim=v_head_dim,
             prefix="attn_mqa",
         )
@@ -524,7 +537,7 @@ class TestTRTLLMMLA(CustomTestCase):
         """Test that TRTLLM and FlashInfer MLA backends produce matching outputs."""
         print(f"\nRunning decode output matching tests...")

-        for test_case in TEST_CASES["decode_output_match"]:
+        for test_case in TEST_CASES["output_match"]:
             with self.subTest(test_case=test_case["name"]):
                 print(f"  Testing {test_case['name']}: {test_case['description']}")
@@ -1099,6 +1112,157 @@ class TestTRTLLMMLA(CustomTestCase):
         self.assertIsNotNone(metadata_3.block_kv_indices)
         self.assertEqual(metadata_3.block_kv_indices.shape[0], config["batch_size"])

+    def test_prefill_output_match_self_attention(self):
+        """Test prefill (forward) behavior of TRTLLM MLA backend vs reference."""
+        print(f"\nRunning prefill output tests...")
+        for test_case in TEST_CASES["output_match"][:2]:  # Just a subset for speed
+            with self.subTest(test_case=test_case["name"]):
+                print(f"Prefill Testing {test_case['name']}: {test_case['description']}")
+
+                config = self._merge_config(test_case)
+                batch_size = config["batch_size"]
+                max_seq_len = config["max_seq_len"]
+
+                # Create components
+                (
+                    model_runner_trtllm,
+                    model_runner_reference,
+                    trtllm_backend,
+                    reference_backend,
+                    layer,
+                ) = self._create_model_components(config, is_prefill=True)
+
+                # Prefill uses full sequences
+                seq_lens = torch.full(
+                    (batch_size,), max_seq_len, device=config["device"]
+                )
+
+                def _create_forward_batch_prefill(
+                    batch_size,
+                    seq_lens,
+                    extend_prefix_lens,
+                    backend,
+                    model_runner,
+                    config,
+                ):
+                    """Create a forward batch for the given backend."""
+                    fb = ForwardBatch(
+                        batch_size=batch_size,
+                        input_ids=torch.randint(
+                            0, 100, (batch_size, 1), device=config["device"]
+                        ),
+                        out_cache_loc=torch.arange(batch_size, device=config["device"]),
+                        seq_lens_sum=int(seq_lens.sum().item()),
+                        extend_prefix_lens=extend_prefix_lens,
+                        extend_prefix_lens_cpu=extend_prefix_lens.cpu().int().tolist(),
+                        extend_seq_lens_cpu=(seq_lens - extend_prefix_lens)
+                        .cpu()
+                        .int()
+                        .tolist(),
+                        forward_mode=ForwardMode.EXTEND,
+                        req_pool_indices=torch.arange(
+                            batch_size, device=config["device"]
+                        ),
+                        seq_lens=seq_lens,
+                        seq_lens_cpu=seq_lens.cpu(),
+                        attn_attend_prefix_cache=False,
+                        mha_return_lse=False,
+                        attn_backend=backend,
+                    )
+                    fb.req_to_token_pool = model_runner.req_to_token_pool
+                    fb.token_to_kv_pool = model_runner.token_to_kv_pool
+
+                    # Add position information for RoPE
+                    fb.positions = torch.arange(batch_size, device=config["device"])
+                    return fb
+
+                # Create forward batches
+                fb_trtllm = _create_forward_batch_prefill(
+                    batch_size,
+                    seq_lens.clone(),
+                    torch.zeros(batch_size, device=config["device"], dtype=torch.int32),
+                    trtllm_backend,
+                    model_runner_trtllm,
+                    config,
+                )
+                fb_reference = _create_forward_batch_prefill(
+                    batch_size,
+                    seq_lens.clone(),
+                    torch.zeros(batch_size, device=config["device"], dtype=torch.int32),
+                    reference_backend,
+                    model_runner_reference,
+                    config,
+                )
+
+                # Initialize metadata for both backends
+                trtllm_backend.init_forward_metadata(fb_trtllm)
+                reference_backend.init_forward_metadata(fb_reference)
+
+                # Create Q, K, V tensors for prefill
+                torch.manual_seed(config["seed_qkv"])
+
+                def _create_qkv_tensors_prefill(
+                    batch_size, seq_len, config, dtype_override=None
+                ):
+                    """Create Q, K, V tensors for prefill, using config for head_num and head_dim."""
+                    device = config["device"]
+                    dtype = dtype_override or config["dtype"]
+
+                    total_tokens = batch_size * seq_len
+                    tp_q_head_num = config["tp_q_head_num"]
+                    tp_k_head_num = config["tp_k_head_num"]
+                    head_dim = config["prefill_head_dim"]
+                    v_head_dim = config["prefill_v_head_dim"]
+
+                    q = torch.randn(
+                        (total_tokens, tp_q_head_num * head_dim),
+                        dtype=dtype,
+                        device=device,
+                    )
+                    k = torch.randn(
+                        (total_tokens, tp_k_head_num * head_dim),
+                        dtype=dtype,
+                        device=device,
+                    )
+                    v = torch.randn(
+                        (total_tokens, tp_k_head_num * v_head_dim),
+                        dtype=dtype,
+                        device=device,
+                    )
+
+                    # Reshape as requested
+                    q = q.view(-1, tp_q_head_num, head_dim)
+                    k = k.view(-1, tp_k_head_num, head_dim)
+                    v = v.view(-1, tp_k_head_num, v_head_dim)
+                    return q, k, v
+
+                q, k, v = _create_qkv_tensors_prefill(batch_size, max_seq_len, config)
+
+                # Run prefill on both backends
+                out_trtllm = trtllm_backend.forward_extend(
+                    q, k, v, layer, fb_trtllm, False
+                ).view(-1, layer.tp_q_head_num * layer.v_head_dim)
+                out_reference = reference_backend.forward_extend(
+                    q, k, v, layer, fb_reference, False
+                )
+
+                tolerance = config.get("tolerance", 1e-2)
+                comparison_passed = compare_outputs(
+                    out_trtllm, out_reference, tolerance=tolerance
+                )
+                self.assertTrue(
+                    comparison_passed,
+                    f"TRTLLM and Reference prefill outputs differ beyond tolerance. "
+                    f"Config: {test_case['name']}, "
+                    f"Max diff: {(out_trtllm - out_reference).abs().max().item()}",
+                )
+

 if __name__ == "__main__":
     unittest.main()
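Since the test module keeps its unittest.main() entry point, the new case can be run on its own, e.g. (requires a GPU build with the TRT-LLM MLA kernels available; the exact invocation may differ in your setup):

    python python/sglang/test/attention/test_trtllm_mla_backend.py TestTRTLLMMLA.test_prefill_output_match_self_attention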