36acd2ff
Unverified
Commit
36acd2ff
authored
Sep 12, 2025
by
Shu Wang
Committed by
GitHub
Sep 12, 2025
Browse files
Fix chunked prefix cache for nvfp4 (#10180)

Co-authored-by: Elfie Guo <elfieg@nvidia.com>

Parent: fe6cdf89
Showing 3 changed files with 48 additions and 2 deletions (+48 -2):
python/sglang/srt/layers/attention/trtllm_mla_backend.py (+19 -1)
python/sglang/srt/model_executor/model_runner.py (+2 -1)
python/sglang/srt/models/deepseek_v2.py (+27 -0)
python/sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -20,6 +20,7 @@ from sglang.srt.layers.attention.utils import (
     create_flashmla_kv_indices_triton,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_flashinfer_available

@@ -72,7 +73,12 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         kv_indptr_buf: Optional[torch.Tensor] = None,
         q_indptr_decode_buf: Optional[torch.Tensor] = None,
     ):
-        super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf)
+        super().__init__(
+            model_runner,
+            skip_prefill,
+            kv_indptr_buf,
+            q_indptr_decode_buf,
+        )
         config = model_runner.model_config

@@ -112,6 +118,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
         self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None

+        self.disable_chunked_prefix_cache = global_server_args_dict[
+            "disable_chunked_prefix_cache"
+        ]
+
     def _calc_padded_blocks(self, max_seq_len: int) -> int:
         """
         Calculate padded block count that satisfies both TRT-LLM and Triton constraints.

@@ -301,6 +311,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             and not forward_batch.forward_mode.is_target_verify()
             and not forward_batch.forward_mode.is_draft_extend()
         ):
+            if self.disable_chunked_prefix_cache:
+                super().init_forward_metadata(forward_batch)
+
             seq_lens = forward_batch.seq_lens - forward_batch.extend_prefix_lens
             cum_seq_lens_q = torch.cat(
                 (

@@ -540,6 +553,11 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             return super().forward_extend(
                 q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
             )

+        # chunked prefix cache is not enabled, use Flashinfer MLA prefill kernel
+        if forward_batch.attn_attend_prefix_cache is None:
+            return super().forward_extend(
+                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+            )
         if not forward_batch.attn_attend_prefix_cache:
             q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
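To make the control flow of the new guards easier to follow, here is a minimal, self-contained sketch of the fallback pattern: the backend caches the disable flag at construction time and, in forward_extend, falls back to the parent FlashInfer MLA prefill path whenever attn_attend_prefix_cache was never populated for the batch. The classes below are simplified stand-ins rather than the real sglang types; only the branching mirrors the diff.

# Minimal sketch of the fallback pattern above, using stand-in classes.
# ForwardBatch, FlashInferMLAAttnBackend, and TRTLLMMLABackend here are
# simplified stubs, not the real sglang implementations.

from dataclasses import dataclass
from typing import Optional


@dataclass
class ForwardBatch:
    # None means the batch was built without chunked prefix-cache metadata.
    attn_attend_prefix_cache: Optional[bool] = None


class FlashInferMLAAttnBackend:
    def forward_extend(self, q, forward_batch):
        return f"flashinfer-prefill({q})"


class TRTLLMMLABackend(FlashInferMLAAttnBackend):
    def __init__(self, disable_chunked_prefix_cache: bool):
        # In the real backend this flag comes from global_server_args_dict.
        self.disable_chunked_prefix_cache = disable_chunked_prefix_cache

    def forward_extend(self, q, forward_batch):
        # Chunked prefix cache not enabled for this batch: fall back to the
        # parent FlashInfer MLA prefill path, as the added guard does.
        if forward_batch.attn_attend_prefix_cache is None:
            return super().forward_extend(q, forward_batch)
        # Otherwise the TRT-LLM chunked path would run here (elided).
        return f"trtllm-chunked({q}, attend={forward_batch.attn_attend_prefix_cache})"


if __name__ == "__main__":
    backend = TRTLLMMLABackend(disable_chunked_prefix_cache=True)
    print(backend.forward_extend("q0", ForwardBatch()))  # falls back to parent path
    print(backend.forward_extend("q1", ForwardBatch(attn_attend_prefix_cache=True)))  # chunked path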
python/sglang/srt/model_executor/model_runner.py
@@ -560,18 +560,19 @@ class ModelRunner:
         if not self.use_mla_backend:
             server_args.disable_chunked_prefix_cache = True
         # TODO(kaixih@nvidia): remove this once we have a better solution for DP attention.
         # For more details, see: https://github.com/sgl-project/sglang/issues/8616
         elif (
             self.dp_size > 1
             and is_sm100_supported()
-            and server_args.attention_backend != "triton"
+            and server_args.attention_backend == "trtllm_mla"
         ):
             logger.info(
                 "Disable chunked prefix cache when dp size > 1 and attention backend is not triton."
             )
             server_args.disable_chunked_prefix_cache = True

         if not server_args.disable_chunked_prefix_cache:
             logger.info("Chunked prefix cache is turned on.")
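The effect of this hunk is that the DP-attention workaround now disables the chunked prefix cache only when the attention backend is trtllm_mla, instead of for every backend other than triton. Below is a small, self-contained sketch of that decision under stated assumptions: ServerArgs, the dp_size/sm100 inputs, and the helper name resolve_chunked_prefix_cache are stand-ins introduced for illustration; only the branch logic mirrors the diff.

# Sketch of the configuration gate above (stand-in types, not the sglang API).

from dataclasses import dataclass


@dataclass
class ServerArgs:
    attention_backend: str
    disable_chunked_prefix_cache: bool = False


def resolve_chunked_prefix_cache(
    server_args: ServerArgs,
    use_mla_backend: bool,
    dp_size: int,
    sm100_supported: bool,
) -> ServerArgs:
    if not use_mla_backend:
        # Chunked prefix cache only applies to MLA backends.
        server_args.disable_chunked_prefix_cache = True
    elif (
        dp_size > 1
        and sm100_supported
        # After this commit the DP workaround targets only trtllm_mla,
        # instead of every backend other than triton.
        and server_args.attention_backend == "trtllm_mla"
    ):
        server_args.disable_chunked_prefix_cache = True
    return server_args


if __name__ == "__main__":
    args = resolve_chunked_prefix_cache(
        ServerArgs(attention_backend="flashinfer"),
        use_mla_backend=True,
        dp_size=2,
        sm100_supported=True,
    )
    # flashinfer with dp_size > 1 now keeps the chunked prefix cache enabled.
    print(args.disable_chunked_prefix_cache)  # False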
python/sglang/srt/models/deepseek_v2.py
@@ -1087,6 +1087,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             disable_ragged = (
                 attention_backend == "flashinfer" or attention_backend == "flashmla"
             ) and self.flashinfer_mla_disable_ragged

+            original_mode = getattr(forward_batch, "_original_forward_mode", None)
+
             if (
                 not disable_ragged
                 and forward_batch.forward_mode.is_extend()

@@ -1099,15 +1101,40 @@
                     )
                     or sum_extend_prefix_lens == 0
                 )
+                # TODO(shuw@nvidia.com) Flashinfer cutlass and trtllm_mla backend have accuracy issue on blackwell for
+                # dp case. Redirect to mla kernel as a workaround.
+                # Tracked by https://github.com/sgl-project/sglang/issues/9806.
+                and not (
+                    original_mode is not None
+                    and original_mode.is_decode()
+                    and is_sm100_supported()
+                    and self.current_attention_backend in ("cutlass_mla", "flashinfer")
+                )
             ):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
                 return _dispatch_mla_subtype()
         elif attention_backend == "trtllm_mla":
+            original_mode = getattr(forward_batch, "_original_forward_mode", None)
+            if (
+                original_mode is not None
+                and original_mode.is_decode()
+                and is_sm100_supported()
+            ):
+                return _dispatch_mla_subtype()
+
             sum_extend_prefix_lens = (
                 sum(forward_batch.extend_prefix_lens_cpu)
                 if forward_batch.extend_prefix_lens_cpu is not None
                 else 0
             )
             if (
                 forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
                 and (
                     not self.disable_chunked_prefix_cache
                     or sum_extend_prefix_lens == 0
                 )
             ):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
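The added branches redirect batches whose original forward mode was decode back to the MLA kernels on SM100 (Blackwell) for the affected backends, instead of sending them down the chunked MHA prefix path; this is the workaround for the accuracy issue tracked in https://github.com/sgl-project/sglang/issues/9806. The sketch below condenses that rule into one standalone function. AttnForwardMethod, ForwardMode, and the single MLA return value are simplified stand-ins; the real code chooses among MLA variants via _dispatch_mla_subtype() and keeps the two backend branches separate.

# Condensed sketch of the added dispatch rule (stand-in types, not sglang's).

import enum
from typing import Optional


class ForwardMode(enum.Enum):
    EXTEND = "extend"
    DECODE = "decode"

    def is_decode(self) -> bool:
        return self is ForwardMode.DECODE


class AttnForwardMethod(enum.Enum):
    MHA_CHUNKED_KV = "mha_chunked_kv"
    MLA = "mla"  # stand-in for whatever _dispatch_mla_subtype() would pick


def dispatch(
    attention_backend: str,
    original_mode: Optional[ForwardMode],
    sm100_supported: bool,
) -> AttnForwardMethod:
    # Workaround for issue #9806: batches that were originally decode stay on
    # the MLA kernels on Blackwell for the affected backends, rather than the
    # chunked MHA prefix path.
    redirect_to_mla = (
        original_mode is not None
        and original_mode.is_decode()
        and sm100_supported
        and attention_backend in ("cutlass_mla", "flashinfer", "trtllm_mla")
    )
    if redirect_to_mla:
        return AttnForwardMethod.MLA
    return AttnForwardMethod.MHA_CHUNKED_KV


if __name__ == "__main__":
    print(dispatch("trtllm_mla", ForwardMode.DECODE, sm100_supported=True))  # MLA
    print(dispatch("trtllm_mla", ForwardMode.EXTEND, sm100_supported=True))  # MHA_CHUNKED_KV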