Unverified Commit 666da3d5 authored by Hank Han, committed by GitHub

[fix] Enable flashmla when using a draft model in P/D attention backend selection (#11012)

parent d01b9214
@@ -201,9 +201,10 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
             self.req_to_token.stride(0),
             self.cuda_graph_kv_indices.stride(0),
         )
+        num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
         mla_metadata, num_splits = get_mla_metadata(
             seq_lens.to(torch.int32),
-            self.num_q_heads,
+            num_q_heads,
             1,
         )
         self.cuda_graph_mla_metadata.copy_(mla_metadata)
@@ -275,9 +276,10 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
             self.req_to_token.stride(0),
             self.cuda_graph_kv_indices.stride(0),
         )
+        num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
         mla_metadata, num_splits = get_mla_metadata(
             seq_lens.to(torch.int32),
-            self.num_q_heads,
+            num_q_heads,
             1,
         )
         self.cuda_graph_mla_metadata.copy_(mla_metadata)
...
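The FlashMLABackend hunks above scale the query-head count passed to FlashMLA's scheduler metadata by the number of speculative draft tokens. Below is a minimal, self-contained sketch of that sizing rule; the helper name and the example values are illustrative, not SGLang code. The idea: with a draft model, one decode step carries num_draft_tokens query tokens per request, so the effective per-KV-head query count grows by that factor.

from typing import Optional


def effective_num_q_heads(num_q_heads: int, num_draft_tokens: Optional[int]) -> int:
    """Query-head count to report to FlashMLA's metadata planner.

    With speculative (MTP/EAGLE) decoding, each decode step verifies
    num_draft_tokens query tokens per request, so FlashMLA effectively sees
    num_q_heads * num_draft_tokens query rows per KV head; without a draft
    model the multiplier is 1.
    """
    return num_q_heads * (num_draft_tokens or 1)


# Plain decode: metadata is planned for the raw head count.
assert effective_num_q_heads(16, None) == 16
# MTP with 3 draft tokens per step: plan for 3x as many query rows.
assert effective_num_q_heads(16, 3) == 48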
@@ -244,6 +244,7 @@ class EAGLEWorker(TpModelWorker):
                 if not is_blackwell()
                 else self._create_triton_prefill_backend
             ),
+            "flashmla": self._create_flashmla_prefill_backend,
             "trtllm_mha": self._create_trtllm_mha_prefill_backend,
             "trtllm_mla": self._create_trtllm_mla_prefill_backend,
         }
@@ -383,6 +384,12 @@ class EAGLEWorker(TpModelWorker):
         return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)
 
+    def _create_flashmla_prefill_backend(self):
+        logger.warning(
+            "flashmla prefill backend is not yet supported for draft extend."
+        )
+        return None
+
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
         self.cuda_graph_runner = None
...
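For context on the EAGLEWorker change above, here is a minimal sketch of the factory-dict dispatch it extends; the names below (PREFILL_BACKEND_FACTORIES, make_prefill_backend, _Backend) are illustrative stand-ins, not the SGLang API. The point is that a registered factory may return None, which the new flashmla entry uses to keep the backend selectable while signaling that draft-extend prefill is not supported yet.

import logging
from typing import Callable, Dict, Optional

logger = logging.getLogger(__name__)


class _Backend:
    """Stand-in for a concrete prefill attention backend object."""


def _create_trtllm_mla_prefill_backend() -> Optional[_Backend]:
    return _Backend()


def _create_flashmla_prefill_backend() -> Optional[_Backend]:
    # FlashMLA cannot serve draft-extend prefill yet, so warn and hand back
    # None; the caller treats None as "no dedicated prefill backend".
    logger.warning("flashmla prefill backend is not yet supported for draft extend.")
    return None


PREFILL_BACKEND_FACTORIES: Dict[str, Callable[[], Optional[_Backend]]] = {
    "trtllm_mla": _create_trtllm_mla_prefill_backend,
    "flashmla": _create_flashmla_prefill_backend,
}


def make_prefill_backend(name: str) -> Optional[_Backend]:
    # Unknown backend names raise KeyError, mirroring a plain dict lookup.
    return PREFILL_BACKEND_FACTORIES[name]()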
@@ -103,11 +103,11 @@ class TestFlashMLAMTP(CustomTestCase):
             "--speculative-draft-model-path",
             "lmsys/sglang-ci-dsv3-test-NextN",
             "--speculative-num-steps",
-            "1",
+            "2",
             "--speculative-eagle-topk",
             "1",
             "--speculative-num-draft-tokens",
-            "2",
+            "3",
             "--attention-backend",
             "flashmla",
         ]
@@ -146,7 +146,7 @@ class TestFlashMLAMTP(CustomTestCase):
                 "avg_spec_accept_length"
             ]
             print(f"{avg_spec_accept_length=}")
-            self.assertGreater(avg_spec_accept_length, 1.8)
+            self.assertGreater(avg_spec_accept_length, 2.4)
 
 if __name__ == "__main__":
...