change / sglang · Commits · f35def86

Unverified commit f35def86, authored Oct 02, 2025 by fzyzcjy, committed by GitHub Oct 02, 2025.

Fuse quantize and rope in trtllm_mla MTP (#10779)

Parent: d61615fe

Showing 2 changed files with 37 additions and 5 deletions (+37, -5):

  python/sglang/srt/layers/attention/trtllm_mla_backend.py  (+33, -4)
  python/sglang/srt/models/deepseek_v2.py                    (+4, -1)
python/sglang/srt/layers/attention/trtllm_mla_backend.py (view file @ f35def86)
...
...
@@ -568,12 +568,35 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         save_kv_cache: bool = True,
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        cos_sin_cache: Optional[torch.Tensor] = None,
+        is_neox: Optional[bool] = False,
     ) -> torch.Tensor:
+        if forward_batch.forward_mode.is_draft_extend():
+            return super().forward_extend(
+                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+            )
+        # TODO refactor to avoid code duplication
+        merge_query = q_rope is not None
+        if (
+            self.data_type == torch.float8_e4m3fn
+        ) and forward_batch.forward_mode.is_target_verify():
+            # For FP8 path, we quantize the query and rope parts and merge them into a single tensor
+            # Note: rope application in deepseek_v2.py:forward_absorb_prepare is skipped for FP8 decode path of this trtllm_mla backend
+            assert all(
+                x is not None for x in [q_rope, k_rope, cos_sin_cache]
+            ), "For FP8 path and using flashinfer.rope.mla_rope_quantize we need all of q_rope, k_rope and cos_sin_cache to be not None."
+            q, k, k_rope = self.quantize_and_rope_for_fp8(
+                q,
+                q_rope,
+                k.squeeze(1),
+                k_rope.squeeze(1),
+                forward_batch,
+                cos_sin_cache,
+                is_neox,
+            )
+            merge_query = False
+
         # Save KV cache if requested
         if save_kv_cache:
             assert (
...
...
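For orientation, here is a minimal PyTorch sketch of what the fused quantize-and-rope step computes conceptually: rotate the rope sub-dimensions of q and k, cast everything to FP8, and merge the two query halves into one tensor. This is not the actual flashinfer.rope.mla_rope_quantize kernel; the helper name, the interleaved (non-neox) rotation layout, and the scale-free FP8 cast are all assumptions made for illustration.

import torch

def rope_quantize_sketch(q_nope, q_rope, k_nope, k_rope, cos, sin):
    # Hypothetical helper, not the flashinfer kernel. cos/sin are assumed
    # already gathered per token position: shape [tokens, rope_dim // 2].
    def rotate(x, c, s):
        # Interleaved (non-neox) rotary embedding on the rope sub-dimension.
        x1, x2 = x[..., 0::2], x[..., 1::2]
        out = torch.empty_like(x)
        out[..., 0::2] = x1 * c - x2 * s
        out[..., 1::2] = x1 * s + x2 * c
        return out

    q_rope = rotate(q_rope, cos.unsqueeze(1), sin.unsqueeze(1))  # broadcast over heads
    k_rope = rotate(k_rope, cos, sin)

    # Quantize to FP8 (real kernels apply scales; a plain cast keeps the sketch short)
    # and merge the query halves into a single [tokens, heads, head_dim] tensor,
    # so the attention kernel consumes one input instead of two.
    q = torch.cat([q_nope, q_rope], dim=-1).to(torch.float8_e4m3fn)
    return q, k_nope.to(torch.float8_e4m3fn), k_rope.to(torch.float8_e4m3fn)

Fusing these steps avoids a separate rope pass plus a separate quantization pass over the same tensors, which is the point of routing the FP8 target-verify path through quantize_and_rope_for_fp8 above.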
@@ -583,12 +606,18 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 layer, forward_batch.out_cache_loc, k, k_rope
             )
 
-        if q_rope is not None:
-            q = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
-            q_rope = q_rope.view(
-                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
-            )
-            q = _concat_mla_absorb_q_general(q, q_rope)
-
-        q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+        # TODO refactor to avoid code duplication
+        # Prepare query tensor inline
+        if merge_query:
+            # For FP16 path, we merge the query and rope parts into a single tensor
+            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope_reshaped = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
+            q = _concat_mla_absorb_q_general(q_nope, q_rope_reshaped)
+        else:
+            # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
+            q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
...
...
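To make the merge_query branch concrete, here is a toy shape check of the FP16 path, assuming _concat_mla_absorb_q_general reduces to a plain last-dimension concat (the real helper may dispatch to a fused kernel). The dimensions are illustrative examples, not values taken from the diff.

import torch

# Illustrative MLA shapes (e.g. v_head_dim=512, rope_dim=64, head_dim=576).
tokens, heads, v_head_dim, rope_dim = 4, 16, 512, 64
head_dim = v_head_dim + rope_dim

q = torch.randn(tokens, heads * v_head_dim)      # "nope" part, flattened
q_rope = torch.randn(tokens, heads * rope_dim)   # rope part, flattened

q_nope = q.view(-1, heads, v_head_dim)
q_rope_reshaped = q_rope.view(-1, heads, rope_dim)

# The plain-case equivalent of _concat_mla_absorb_q_general:
merged = torch.cat([q_nope, q_rope_reshaped], dim=-1)
assert merged.shape == (tokens, heads, head_dim)

On the FP8 branch, the query already arrives merged (and quantized) from quantize_and_rope_for_fp8, so only the final view to [tokens, heads, head_dim] remains.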
python/sglang/srt/models/deepseek_v2.py (view file @ f35def86)
...
...
@@ -1399,7 +1399,10 @@ class DeepseekV2AttentionMLA(nn.Module):
         """
         return (
             self.current_attention_backend == "trtllm_mla"
-            and forward_batch.forward_mode.is_decode_or_idle()
+            and (
+                forward_batch.forward_mode.is_decode_or_idle()
+                or forward_batch.forward_mode.is_target_verify()
+            )
             and forward_batch.attn_backend.data_type == torch.float8_e4m3fn
         )
...
...
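The predicate change reads more clearly restated as a free function; the function name and flattened parameters below are hypothetical, while the boolean logic mirrors the new return expression.

import torch

def defer_rope_and_quant_to_backend(backend_name, forward_mode, attn_dtype):
    # After this commit, rope application (and FP8 quantization) is deferred
    # to the trtllm_mla backend for target-verify (MTP verify) batches as
    # well, not only for decode/idle batches.
    return (
        backend_name == "trtllm_mla"
        and (forward_mode.is_decode_or_idle() or forward_mode.is_target_verify())
        and attn_dtype == torch.float8_e4m3fn
    )

Widening this gate is what lets the fused quantize_and_rope_for_fp8 path in trtllm_mla_backend.py take over for target-verify, instead of applying rope eagerly in forward_absorb_prepare.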