change / sglang / Commits / 666da3d5

Commit 666da3d5 (unverified), parent d01b9214
Authored Oct 04, 2025 by Hank Han; committed via GitHub on Oct 04, 2025

[fix]enable flashmla when using draft model P/D attention select (#11012)
Showing 3 changed files, with 14 additions and 5 deletions (+14 −5):

- python/sglang/srt/layers/attention/flashmla_backend.py (+4 −2)
- python/sglang/srt/speculative/eagle_worker.py (+7 −0)
- test/srt/test_flashmla.py (+3 −3)
python/sglang/srt/layers/attention/flashmla_backend.py

```diff
@@ -201,9 +201,10 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                 self.req_to_token.stride(0),
                 self.cuda_graph_kv_indices.stride(0),
             )
+            num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
             mla_metadata, num_splits = get_mla_metadata(
                 seq_lens.to(torch.int32),
-                self.num_q_heads,
+                num_q_heads,
                 1,
             )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
@@ -275,9 +276,10 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                 self.req_to_token.stride(0),
                 self.cuda_graph_kv_indices.stride(0),
             )
+            num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
             mla_metadata, num_splits = get_mla_metadata(
                 seq_lens.to(torch.int32),
-                self.num_q_heads,
+                num_q_heads,
                 1,
             )
             self.cuda_graph_mla_metadata.copy_(mla_metadata)
```
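For context on this change: with MTP/EAGLE speculative decoding, each verify step carries several draft-token query positions per request, and FlashMLA appears to fold those positions into the query-head dimension when building its scheduling metadata. A minimal sketch of the scaling under that reading of the diff (the helper name below is mine; in the commit the expression is written inline inside FlashMLABackend):

```python
# Hypothetical helper illustrating the head-count scaling this commit adds.
def effective_num_q_heads(num_q_heads: int, num_draft_tokens: int | None) -> int:
    # Each verify step processes `num_draft_tokens` query positions per
    # request in one launch; FlashMLA's metadata is built as if those
    # positions were extra query heads, so the count must be scaled.
    # `or 1` leaves plain (non-speculative) decoding unchanged.
    return num_q_heads * (num_draft_tokens or 1)

assert effective_num_q_heads(16, None) == 16  # vanilla decode: no scaling
assert effective_num_q_heads(16, 2) == 32     # two draft tokens per request
```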
python/sglang/srt/speculative/eagle_worker.py

```diff
@@ -244,6 +244,7 @@ class EAGLEWorker(TpModelWorker):
                 if not is_blackwell()
                 else self._create_triton_prefill_backend
             ),
+            "flashmla": self._create_flashmla_prefill_backend,
             "trtllm_mha": self._create_trtllm_mha_prefill_backend,
             "trtllm_mla": self._create_trtllm_mla_prefill_backend,
         }
@@ -383,6 +384,12 @@ class EAGLEWorker(TpModelWorker):
         return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)

+    def _create_flashmla_prefill_backend(self):
+        logger.warning(
+            "flashmla prefill backend is not yet supported for draft extend."
+        )
+        return None
+
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
         self.cuda_graph_runner = None
```
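This change registers "flashmla" in EAGLEWorker's name-to-factory map for draft prefill backends and adds a factory that warns and returns None rather than failing outright. A sketch of that dispatch pattern as I read the diff (the selector class below is illustrative; the real map lives inline in EAGLEWorker):

```python
import logging

logger = logging.getLogger(__name__)

class PrefillBackendSelector:
    """Illustrative stand-in for EAGLEWorker's backend-selection map."""

    def __init__(self) -> None:
        # name -> factory; registering "flashmla" is what the commit adds.
        # Presumably, before this entry existed, selecting flashmla for the
        # draft model's prefill path had no handler at all.
        self._factories = {
            "flashmla": self._create_flashmla_prefill_backend,
        }

    def create(self, name: str):
        return self._factories[name]()

    def _create_flashmla_prefill_backend(self):
        # Mirrors the added factory: warn and return None so callers can
        # treat the backend as "unavailable" instead of crashing.
        logger.warning(
            "flashmla prefill backend is not yet supported for draft extend."
        )
        return None

selector = PrefillBackendSelector()
backend = selector.create("flashmla")  # logs a warning, returns None
```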
test/srt/test_flashmla.py

```diff
@@ -103,11 +103,11 @@ class TestFlashMLAMTP(CustomTestCase):
             "--speculative-draft-model-path",
             "lmsys/sglang-ci-dsv3-test-NextN",
             "--speculative-num-steps",
-            "1",
+            "2",
             "--speculative-eagle-topk",
             "1",
             "--speculative-num-draft-tokens",
-            "2",
+            "3",
             "--attention-backend",
             "flashmla",
         ]
@@ -146,7 +146,7 @@ class TestFlashMLAMTP(CustomTestCase):
             "avg_spec_accept_length"
         ]
         print(f"{avg_spec_accept_length=}")
-        self.assertGreater(avg_spec_accept_length, 1.8)
+        self.assertGreater(avg_spec_accept_length, 2.4)


 if __name__ == "__main__":
```
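The test update lengthens the draft chain (num-steps 1 → 2, num-draft-tokens 2 → 3) and raises the acceptance-length floor from 1.8 to 2.4 to match. A sketch of the arithmetic as I read these flags (the `num_steps * topk + 1` relation is my assumption about how the root token is counted, not something stated in the diff):

```python
def expected_draft_tokens(num_steps: int, eagle_topk: int) -> int:
    # Assumption: with a linear draft chain (topk=1), each speculative step
    # proposes one token, and the verified window also includes the root
    # token, giving num_steps * topk + 1 tokens per verify pass.
    return num_steps * eagle_topk + 1

assert expected_draft_tokens(1, 1) == 2  # old test settings
assert expected_draft_tokens(2, 1) == 3  # new test settings

# avg_spec_accept_length is bounded above by the draft-token count, so the
# longer chain supports the stricter assertGreater(..., 2.4) threshold.
```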