Unverified Commit 9ec314c6 authored by Qiaolin Yu, committed by GitHub

Support speculative decoding in the trtllm_mha attention backend (#9331)


Co-authored-by: ispobock <ispobaoke@gmail.com>
parent fedfe91c
@@ -500,11 +500,6 @@ class ServerArgs:
             )
             self.page_size = 64
-            if self.speculative_algorithm is not None:
-                raise ValueError(
-                    "trtllm_mha backend does not support speculative decoding yet."
-                )
 
         if self.attention_backend == "dual_chunk_flash_attn":
             logger.warning(
                 "Mixed chunk, radix cache, and cuda graphs are disabled because of using dual chunk flash attention backend"
@@ -653,6 +648,16 @@ class ServerArgs:
                 self.speculative_num_draft_tokens,
             ) = auto_choose_speculative_params(self)
 
+            if (
+                self.attention_backend == "trtllm_mha"
+                or self.decode_attention_backend == "trtllm_mha"
+                or self.prefill_attention_backend == "trtllm_mha"
+            ):
+                if self.speculative_eagle_topk > 1:
+                    raise ValueError(
+                        "trtllm_mha backend only supports topk = 1 for speculative decoding."
+                    )
+
             if (
                 self.speculative_eagle_topk == 1
                 and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
...
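For readers skimming the hunk: the new block rejects topk > 1 whenever any of the three attention-backend knobs selects trtllm_mha, and the pre-existing check below it (body collapsed here) ties speculative_num_draft_tokens to speculative_num_steps + 1 when topk == 1. Both constraints can be restated as a standalone hypothetical helper; since the second check's original body is truncated above, raising there is an assumption.

```python
# Hypothetical restatement of the speculative-decoding constraints enforced by
# ServerArgs for trtllm_mha; not part of the diff. The second check's original
# handling is collapsed above, so raising here is an assumption.
def check_trtllm_mha_spec_args(
    attention_backend: str,
    speculative_eagle_topk: int,
    speculative_num_steps: int,
    speculative_num_draft_tokens: int,
) -> None:
    if attention_backend == "trtllm_mha" and speculative_eagle_topk > 1:
        raise ValueError(
            "trtllm_mha backend only supports topk = 1 for speculative decoding."
        )
    # With topk == 1, the verify pass sees one token per draft step plus the root.
    if (
        speculative_eagle_topk == 1
        and speculative_num_draft_tokens != speculative_num_steps + 1
    ):
        raise ValueError(
            "speculative_num_draft_tokens must equal speculative_num_steps + 1."
        )

# Example: topk = 1 with 3 draft steps and 4 draft tokens passes both checks.
check_trtllm_mha_spec_args("trtllm_mha", 1, 3, 4)
```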
@@ -266,6 +266,22 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.speculative_num_steps,
             )
+        elif self.server_args.attention_backend == "trtllm_mha":
+            from sglang.srt.layers.attention.trtllm_mha_backend import (
+                TRTLLMHAAttnBackend,
+                TRTLLMHAAttnMultiStepDraftBackend,
+            )
+
+            self.draft_attn_backend = TRTLLMHAAttnMultiStepDraftBackend(
+                self.draft_model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
+            self.draft_extend_attn_backend = TRTLLMHAAttnBackend(
+                self.draft_model_runner,
+                skip_prefill=False,
+            )
+            self.has_prefill_wrapper_verify = True
         elif self.server_args.attention_backend == "trtllm_mla":
             if not global_server_args_dict["use_mla_backend"]:
                 raise ValueError(
...
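The new branch mirrors the wiring EAGLEWorker already uses for the other attention backends: a multi-step backend holding one per-step attention backend for the draft loop, plus a single skip_prefill=False backend for the draft-extend pass. A simplified, hypothetical sketch of that pattern (class and method names are illustrative, not the real trtllm_mha_backend API):

```python
# Simplified, hypothetical sketch of the multi-step draft-backend pattern used
# above; names are illustrative, not the real trtllm_mha_backend API.
class MultiStepDraftBackendSketch:
    def __init__(self, model_runner, topk, speculative_num_steps, backend_cls):
        self.topk = topk
        # One backend instance per draft step, so each decode step of the
        # EAGLE draft loop keeps its own attention metadata.
        self.attn_backends = [
            backend_cls(model_runner, skip_prefill=True)
            for _ in range(speculative_num_steps)
        ]

    def init_forward_metadata(self, forward_batch):
        # Prepare per-step metadata before running the draft loop.
        for backend in self.attn_backends:
            backend.init_forward_metadata(forward_batch)
```

The separate draft_extend_attn_backend instance (skip_prefill=False) serves the draft-extend pass, and setting has_prefill_wrapper_verify appears to signal that the verify path has a prefill-style wrapper available.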