change / sglang · Commits · ff9b5618

Commit ff9b5618 (Unverified)
Authored Aug 29, 2025 by Faraz; committed by GitHub on Aug 29, 2025
Parent: fcd72bd1

Fix TRTLLM MLA Cuda KV Blocks Causing accuracy drop (#9675)

Showing 2 changed files with 37 additions and 13 deletions (+37, -13):

  +25  -10  python/sglang/srt/layers/attention/trtllm_mla_backend.py
  +12   -3  python/sglang/test/attention/test_trtllm_mla_backend.py
python/sglang/srt/layers/attention/trtllm_mla_backend.py  (view file @ ff9b5618)
@@ -51,6 +51,7 @@ class TRTLLMMLADecodeMetadata:
     workspace: Optional[torch.Tensor] = None
     block_kv_indices: Optional[torch.Tensor] = None
+    max_seq_len: Optional[int] = None


 class TRTLLMMLABackend(FlashInferMLAAttnBackend):
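For orientation, the decode metadata implied by this hunk is sketched below. This is a minimal reconstruction from the fields visible above only; the real TRTLLMMLADecodeMetadata in sglang may carry additional fields, decorators, or defaults.

    # Minimal sketch of the metadata shape implied by the hunk above; names of the
    # fields come from the diff, everything else is illustrative.
    from dataclasses import dataclass
    from typing import Optional

    import torch


    @dataclass
    class DecodeMetadataSketch:
        workspace: Optional[torch.Tensor] = None         # scratch buffer for the TRT-LLM MLA kernel
        block_kv_indices: Optional[torch.Tensor] = None  # per-request paged KV block table
        max_seq_len: Optional[int] = None                # new: true max sequence length as a host int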
@@ -207,8 +208,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         )

         # Custom fast-path for decode/idle.
-        max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item())
-        block_kv_indices = self.decode_cuda_graph_kv_indices[:bs, :max_seqlen_pad]
+        # Capture with full width so future longer sequences are safe during replay
+        max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
+        block_kv_indices = self.decode_cuda_graph_kv_indices[:bs, :max_blocks_per_seq]
         create_flashmla_kv_indices_triton[(bs,)](
             self.req_to_token,
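The capture path above now pads the block table to the full context width rather than to the current batch's longest sequence. A rough sketch of the difference, assuming _calc_padded_blocks simply rounds a token count up to whole pages (the real helper may also round to TRITON_PAD_NUM_PAGE_PER_BLOCK multiples, which this sketch ignores; page_size and the lengths are made-up numbers):

    # Hypothetical stand-in for self._calc_padded_blocks: round a token count up
    # to whole KV pages.
    def calc_padded_blocks(num_tokens: int, page_size: int = 64) -> int:
        return -(-num_tokens // page_size)  # ceil division

    max_context_len = 8192
    capture_seq_lens_max = 128  # longest sequence present while capturing the graph

    narrow_width = calc_padded_blocks(capture_seq_lens_max)  # 2 blocks: only safe for <=128 tokens
    full_width = calc_padded_blocks(max_context_len)         # 128 blocks: safe for any replay batch

    # During replay the captured slice decode_cuda_graph_kv_indices[:bs, :width]
    # is reused as-is, so a sequence longer than anything seen at capture time
    # would not fit in the narrow slice; capturing with the full width avoids that.
    print(narrow_width, full_width)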
@@ -217,13 +219,20 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             None,
             block_kv_indices,
             self.req_to_token.stride(0),
-            max_seqlen_pad,
+            max_blocks_per_seq,
             NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
             PAGED_SIZE=self.page_size,
         )

+        # Record the true maximum sequence length for this capture batch so that
+        # the kernel launch path (which requires an int not a tensor) can reuse
+        # it safely during both capture and replay.
+        max_seq_len_val = int(seq_lens.max().item())
         metadata = TRTLLMMLADecodeMetadata(
-            self.decode_cuda_graph_workspace, block_kv_indices
+            self.decode_cuda_graph_workspace,
+            block_kv_indices,
+            max_seq_len_val,
         )
         self.decode_cuda_graph_metadata[bs] = metadata
         self.forward_metadata = metadata
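As the new comment notes, the kernel launch path needs a plain int rather than a tensor, so the capture path snapshots seq_lens.max() as a host-side Python int. A tiny, self-contained illustration of that conversion (the values are made up):

    import torch

    seq_lens = torch.tensor([37, 90, 12], dtype=torch.int32)

    max_as_tensor = seq_lens.max()           # 0-dim torch.Tensor, lives wherever seq_lens lives
    max_as_int = int(seq_lens.max().item())  # plain Python int, usable as a kernel launch argument

    print(type(max_as_tensor), type(max_as_int))  # <class 'torch.Tensor'> <class 'int'>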
@@ -268,6 +277,13 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             PAGED_SIZE=self.page_size,
         )

+        # Update stored max_seq_len so subsequent kernel calls use the correct value
+        # Prefer CPU tensor to avoid GPU synchronization when available.
+        if seq_lens_cpu is not None:
+            metadata.max_seq_len = int(seq_lens_cpu.max().item())
+        else:
+            metadata.max_seq_len = int(seq_lens.max().item())
+
     def get_cuda_graph_seq_len_fill_value(self) -> int:
         """Get the fill value for sequence lengths in CUDA graph."""
         return 1
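At replay time the stored max_seq_len is refreshed, preferring seq_lens_cpu because calling .item() on a CUDA tensor forces a device synchronization while a CPU copy does not. A minimal sketch of that preference; the helper name and the SimpleNamespace stand-in for the metadata object are illustrative only:

    import types
    from typing import Optional

    import torch


    def refresh_max_seq_len(metadata, seq_lens: torch.Tensor,
                            seq_lens_cpu: Optional[torch.Tensor]) -> None:
        """Mirror of the replay-time update above: prefer the host-side tensor so
        that .item() does not trigger a GPU -> CPU synchronization."""
        src = seq_lens_cpu if seq_lens_cpu is not None else seq_lens
        metadata.max_seq_len = int(src.max().item())


    meta = types.SimpleNamespace(max_seq_len=None)
    refresh_max_seq_len(meta, torch.tensor([5, 9]), seq_lens_cpu=torch.tensor([5, 9]))
    print(meta.max_seq_len)  # 9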
@@ -295,8 +311,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             forward_batch.seq_lens.device,
         )
+        max_seq_len_val = int(max_seq)
         self.forward_metadata = TRTLLMMLADecodeMetadata(
-            self.workspace_buffer, block_kv_indices
+            self.workspace_buffer, block_kv_indices, max_seq_len_val
         )
         forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
@@ -471,14 +488,12 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             qk_rope_head_dim=self.qk_rope_head_dim,
             block_tables=metadata.block_kv_indices,
             seq_lens=forward_batch.seq_lens.to(torch.int32),
-            max_seq_len=int(metadata.block_kv_indices.shape[1] * self.page_size),
+            max_seq_len=metadata.max_seq_len,
             bmm1_scale=bmm1_scale,
         )

-        # Extract value projection part and reshape
-        raw_out_v = raw_out[..., : layer.v_head_dim].contiguous()
-        output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+        # Reshape output directly without slicing
+        output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)

         return output
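The decode call now passes the recorded metadata.max_seq_len instead of deriving it from the padded block table. The numbers below are illustrative only, but they show why the two values differ: block_kv_indices.shape[1] * page_size is the padded capacity of the block table, not the longest sequence actually present, and after the full-width capture change above that capacity is sized for max_context_len.

    # Illustrative numbers only: padded block-table capacity vs. the true max length.
    page_size = 64
    true_max_seq_len = 100                              # longest sequence actually in the batch
    padded_blocks = -(-true_max_seq_len // page_size)   # 2 pages after rounding up

    old_value = padded_blocks * page_size               # 128: block_kv_indices.shape[1] * page_size
    new_value = true_max_seq_len                        # 100: what metadata.max_seq_len records

    print(old_value, new_value)

With the CUDA-graph capture now sizing the block table for max_context_len, the old formula would report the full context capacity rather than 128, which is presumably the mismatch behind the accuracy drop this commit fixes.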
python/sglang/test/attention/test_trtllm_mla_backend.py  (view file @ ff9b5618)
@@ -208,6 +208,15 @@ class MockModelRunner:
         self.kv_cache_dtype = config["kv_cache_dtype"]
         self.page_size = config["page_size"]

+        # Server args stub - needed by attention backends
+        self.server_args = type(
+            "ServerArgs",
+            (),
+            {
+                "enable_dp_attention": False,
+                # Default value for testing
+            },
+        )
         # Model-config stub with MLA attributes
         self.model_config = type(
             "ModelConfig",
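The mock gains a server_args stub built with the three-argument form of type(), which creates a throwaway class on the fly. A standalone illustration of that idiom (the attribute mirrors the test; nothing else here comes from sglang):

    # type(name, bases, namespace) builds a class at runtime.
    ServerArgsStub = type(
        "ServerArgs",
        (),                              # no base classes
        {"enable_dp_attention": False},  # class attribute the backend reads
    )

    server_args = ServerArgsStub()
    print(server_args.enable_dp_attention)  # False

Note that the test assigns the class object itself to self.server_args rather than an instance; attribute reads like server_args.enable_dp_attention resolve the same way in either case.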
@@ -833,7 +842,7 @@ class TestTRTLLMMLA(CustomTestCase):
         # Test workspace properties
         self.assertEqual(metadata.workspace.device.type, "cuda")
-        self.assertEqual(metadata.workspace.dtype, torch.int8)
+        self.assertEqual(metadata.workspace.dtype, torch.uint8)
         self.assertGreater(
             metadata.workspace.numel(), 0, "Workspace should have non-zero size"
         )
@@ -993,8 +1002,8 @@ class TestTRTLLMMLA(CustomTestCase):
         )

         # Verify CUDA graph buffers are allocated
-        self.assertIsNotNone(backend.cuda_graph_kv_indices)
-        self.assertIsNotNone(backend.cuda_graph_workspace)
+        self.assertIsNotNone(backend.decode_cuda_graph_kv_indices)
+        self.assertIsNotNone(backend.decode_cuda_graph_workspace)

         # Test capture metadata
         seq_lens = torch.full(