"git@developer.sourcefind.cn:change/sglang.git" did not exist on "a7000a765041cf870bb9964ee533dd0fb7cebcdf"
Unverified Commit df397a72 authored by Ximingwang-09, committed by GitHub

[feat] Add P/D attention select for draft model (#9755)
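This change replaces the single if/elif chain over server_args.attention_backend in EAGLEWorker with two factory entry points, _create_decode_backend() and _create_draft_extend_backend(), so the draft model can select its decode and prefill (draft-extend) attention backends independently. Each entry point consults the corresponding server arg (decode_attention_backend / prefill_attention_backend), falls back to attention_backend when the override is unset, and dispatches through a per-backend factory map; unsupported combinations raise a ValueError.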


Co-authored-by: 纬杭 <ximing.wxm@antgroup.com>
parent 5dfcd6c2
@@ -187,137 +187,183 @@ class EAGLEWorker(TpModelWorker):
         self.has_prefill_wrapper_verify = False
         self.draft_extend_attn_backend = None
-        if self.server_args.attention_backend == "flashinfer":
-            if not global_server_args_dict["use_mla_backend"]:
-                from sglang.srt.layers.attention.flashinfer_backend import (
-                    FlashInferAttnBackend,
-                    FlashInferMultiStepDraftBackend,
-                )
-
-                self.draft_attn_backend = FlashInferMultiStepDraftBackend(
-                    self.draft_model_runner,
-                    self.topk,
-                    self.speculative_num_steps,
-                )
-                self.draft_extend_attn_backend = FlashInferAttnBackend(
-                    self.draft_model_runner,
-                    skip_prefill=False,
-                )
-            else:
-                from sglang.srt.layers.attention.flashinfer_mla_backend import (
-                    FlashInferMLAAttnBackend,
-                    FlashInferMLAMultiStepDraftBackend,
-                )
-
-                self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
-                    self.draft_model_runner,
-                    self.topk,
-                    self.speculative_num_steps,
-                )
-                self.draft_extend_attn_backend = FlashInferMLAAttnBackend(
-                    self.draft_model_runner,
-                    skip_prefill=False,
-                )
-            self.has_prefill_wrapper_verify = True
-        elif self.server_args.attention_backend == "triton":
-            from sglang.srt.layers.attention.triton_backend import (
-                TritonAttnBackend,
-                TritonMultiStepDraftBackend,
-            )
-
-            self.draft_attn_backend = TritonMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-            self.draft_extend_attn_backend = TritonAttnBackend(
-                self.draft_model_runner,
-                skip_prefill=False,
-            )
-        elif self.server_args.attention_backend == "aiter":
-            from sglang.srt.layers.attention.aiter_backend import (
-                AiterAttnBackend,
-                AiterMultiStepDraftBackend,
-            )
-
-            self.draft_attn_backend = AiterMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-            self.draft_extend_attn_backend = AiterAttnBackend(
-                self.draft_model_runner,
-                skip_prefill=False,
-            )
-            self.has_prefill_wrapper_verify = False
-        elif self.server_args.attention_backend == "fa3":
-            from sglang.srt.layers.attention.flashattention_backend import (
-                FlashAttentionBackend,
-                FlashAttentionMultiStepBackend,
-            )
-
-            self.draft_attn_backend = FlashAttentionMultiStepBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-            self.draft_extend_attn_backend = FlashAttentionBackend(
-                self.draft_model_runner,
-                skip_prefill=False,
-            )
-        elif self.server_args.attention_backend == "flashmla":
-            from sglang.srt.layers.attention.flashmla_backend import (
-                FlashMLAMultiStepDraftBackend,
-            )
-
-            self.draft_attn_backend = FlashMLAMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-        elif self.server_args.attention_backend == "trtllm_mha":
-            from sglang.srt.layers.attention.trtllm_mha_backend import (
-                TRTLLMHAAttnBackend,
-                TRTLLMHAAttnMultiStepDraftBackend,
-            )
-
-            self.draft_attn_backend = TRTLLMHAAttnMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-            self.draft_extend_attn_backend = TRTLLMHAAttnBackend(
-                self.draft_model_runner,
-                skip_prefill=False,
-            )
-            self.has_prefill_wrapper_verify = True
-        elif self.server_args.attention_backend == "trtllm_mla":
-            if not global_server_args_dict["use_mla_backend"]:
-                raise ValueError(
-                    "trtllm_mla backend requires MLA model (use_mla_backend=True)."
-                )
-            from sglang.srt.layers.attention.trtllm_mla_backend import (
-                TRTLLMMLABackend,
-                TRTLLMMLAMultiStepDraftBackend,
-            )
-
-            self.draft_attn_backend = TRTLLMMLAMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-            self.draft_extend_attn_backend = TRTLLMMLABackend(
-                self.draft_model_runner,
-                skip_prefill=False,
-            )
-            self.has_prefill_wrapper_verify = True
-        else:
-            raise ValueError(
-                f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
-            )
-
-        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
+        # Initialize decode attention backend
+        self.draft_attn_backend = self._create_decode_backend()
+
+        # Initialize prefill attention backend
+        self.draft_extend_attn_backend = self._create_draft_extend_backend()
+
+        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
+
+    def _create_backend(
+        self, backend_name: str, backend_map: dict, error_template: str
+    ):
+        backend_type = getattr(self.server_args, backend_name)
+        if backend_type is None:
+            backend_type = self.server_args.attention_backend
+
+        if backend_type not in backend_map:
+            raise ValueError(error_template.format(backend_type=backend_type))
+
+        return backend_map[backend_type]()
+
+    def _create_decode_backend(self):
+        backend_map = {
+            "flashinfer": self._create_flashinfer_decode_backend,
+            "triton": self._create_triton_decode_backend,
+            "aiter": self._create_aiter_decode_backend,
+            "fa3": self._create_fa3_decode_backend,
+            "flashmla": self._create_flashmla_decode_backend,
+            "trtllm_mha": self._create_trtllm_mha_decode_backend,
+            "trtllm_mla": self._create_trtllm_mla_decode_backend,
+        }
+        return self._create_backend(
+            "decode_attention_backend",
+            backend_map,
+            "EAGLE is not supported in decode attention backend {backend_type}",
+        )
+
+    def _create_draft_extend_backend(self):
+        backend_map = {
+            "flashinfer": self._create_flashinfer_prefill_backend,
+            "triton": self._create_triton_prefill_backend,
+            "aiter": self._create_aiter_prefill_backend,
+            "fa3": self._create_fa3_prefill_backend,
+            "trtllm_mha": self._create_trtllm_mha_prefill_backend,
+            "trtllm_mla": self._create_trtllm_mla_prefill_backend,
+        }
+        return self._create_backend(
+            "prefill_attention_backend",
+            backend_map,
+            "EAGLE is not supported in prefill attention backend {backend_type}",
+        )
+
+    def _create_flashinfer_decode_backend(self):
+        if not global_server_args_dict["use_mla_backend"]:
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferMultiStepDraftBackend,
+            )
+
+            self.has_prefill_wrapper_verify = True
+            return FlashInferMultiStepDraftBackend(
+                self.draft_model_runner, self.topk, self.speculative_num_steps
+            )
+        else:
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAMultiStepDraftBackend,
+            )
+
+            self.has_prefill_wrapper_verify = True
+            return FlashInferMLAMultiStepDraftBackend(
+                self.draft_model_runner, self.topk, self.speculative_num_steps
+            )
+
+    def _create_triton_decode_backend(self):
+        from sglang.srt.layers.attention.triton_backend import (
+            TritonMultiStepDraftBackend,
+        )
+
+        return TritonMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_aiter_decode_backend(self):
+        from sglang.srt.layers.attention.aiter_backend import AiterMultiStepDraftBackend
+
+        return AiterMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_fa3_decode_backend(self):
+        from sglang.srt.layers.attention.flashattention_backend import (
+            FlashAttentionMultiStepBackend,
+        )
+
+        return FlashAttentionMultiStepBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_flashmla_decode_backend(self):
+        from sglang.srt.layers.attention.flashmla_backend import (
+            FlashMLAMultiStepDraftBackend,
+        )
+
+        return FlashMLAMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_trtllm_mha_decode_backend(self):
+        from sglang.srt.layers.attention.trtllm_mha_backend import (
+            TRTLLMHAAttnMultiStepDraftBackend,
+        )
+
+        self.has_prefill_wrapper_verify = True
+        return TRTLLMHAAttnMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_trtllm_mla_decode_backend(self):
+        if not global_server_args_dict["use_mla_backend"]:
+            raise ValueError(
+                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
+            )
+
+        from sglang.srt.layers.attention.trtllm_mla_backend import (
+            TRTLLMMLAMultiStepDraftBackend,
+        )
+
+        self.has_prefill_wrapper_verify = True
+        return TRTLLMMLAMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_flashinfer_prefill_backend(self):
+        if not global_server_args_dict["use_mla_backend"]:
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferAttnBackend,
+            )
+
+            self.has_prefill_wrapper_verify = True
+            return FlashInferAttnBackend(self.draft_model_runner, skip_prefill=False)
+        else:
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAAttnBackend,
+            )
+
+            return FlashInferMLAAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_triton_prefill_backend(self):
+        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
+        return TritonAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_aiter_prefill_backend(self):
+        from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
+
+        return AiterAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_fa3_prefill_backend(self):
+        from sglang.srt.layers.attention.flashattention_backend import (
+            FlashAttentionBackend,
+        )
+
+        return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_trtllm_mha_prefill_backend(self):
+        from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend
+
+        return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_trtllm_mla_prefill_backend(self):
+        if not global_server_args_dict["use_mla_backend"]:
+            raise ValueError(
+                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
+            )
+
+        from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
+
+        return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)
 
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
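The heart of the refactor is the fallback-plus-dispatch pattern in _create_backend: each phase reads its own server arg (decode_attention_backend or prefill_attention_backend), falls back to the global attention_backend when the override is unset, and dispatches through a per-phase map of factory callables, raising a ValueError for unsupported combinations. Below is a minimal, self-contained sketch of that pattern; the ServerArgs dataclass and the string-returning factories are illustrative stand-ins for sglang's real types, not the actual implementation.

    from dataclasses import dataclass
    from typing import Callable, Optional


    @dataclass
    class ServerArgs:
        # Global default, mirroring sglang's attention_backend server arg.
        attention_backend: str = "flashinfer"
        # Per-phase overrides; None means "fall back to the global default".
        prefill_attention_backend: Optional[str] = None
        decode_attention_backend: Optional[str] = None


    def create_backend(
        args: ServerArgs,
        backend_arg_name: str,
        backend_map: dict,
        error_template: str,
    ):
        # Phase-specific choice wins; otherwise use the global default.
        backend_type = getattr(args, backend_arg_name)
        if backend_type is None:
            backend_type = args.attention_backend

        if backend_type not in backend_map:
            raise ValueError(error_template.format(backend_type=backend_type))

        # Factories are called lazily, so a backend is only constructed
        # (and its module only imported) when actually selected.
        return backend_map[backend_type]()


    # Toy decode map; the real map binds sglang factory methods instead.
    decode_map = {
        "flashinfer": lambda: "FlashInferMultiStepDraftBackend",
        "triton": lambda: "TritonMultiStepDraftBackend",
    }

    args = ServerArgs(decode_attention_backend="triton")
    print(
        create_backend(
            args,
            "decode_attention_backend",
            decode_map,
            "EAGLE is not supported in decode attention backend {backend_type}",
        )
    )  # -> TritonMultiStepDraftBackend

Because the map values are callables, backend modules stay behind function-local imports and are loaded only for the configuration in use. Assuming sglang's usual argparse naming, the two server args would surface as flags like --decode-attention-backend and --prefill-attention-backend, letting a prefill/decode-disaggregated deployment pick different attention kernels for the draft model's decode and draft-extend passes.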