Unverified Commit 9ec314c6 authored by Qiaolin Yu, committed by GitHub
Browse files

Support speculative decoding in the trtllm_mha attention backend (#9331)


Co-authored-by: ispobock <ispobaoke@gmail.com>
parent fedfe91c
...@@ -500,11 +500,6 @@ class ServerArgs: ...@@ -500,11 +500,6 @@ class ServerArgs:
) )
self.page_size = 64 self.page_size = 64
if self.speculative_algorithm is not None:
raise ValueError(
"trtllm_mha backend does not support speculative decoding yet."
)
if self.attention_backend == "dual_chunk_flash_attn": if self.attention_backend == "dual_chunk_flash_attn":
logger.warning( logger.warning(
"Mixed chunk, radix cache, and cuda graphs are disabled because of using dual chunk flash attention backend" "Mixed chunk, radix cache, and cuda graphs are disabled because of using dual chunk flash attention backend"
...@@ -653,6 +648,16 @@ class ServerArgs: ...@@ -653,6 +648,16 @@ class ServerArgs:
self.speculative_num_draft_tokens, self.speculative_num_draft_tokens,
) = auto_choose_speculative_params(self) ) = auto_choose_speculative_params(self)
if (
self.attention_backend == "trtllm_mha"
or self.decode_attention_backend == "trtllm_mha"
or self.prefill_attention_backend == "trtllm_mha"
):
if self.speculative_eagle_topk > 1:
raise ValueError(
"trtllm_mha backend only supports topk = 1 for speculative decoding."
)
if ( if (
self.speculative_eagle_topk == 1 self.speculative_eagle_topk == 1
and self.speculative_num_draft_tokens != self.speculative_num_steps + 1 and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
......
...@@ -266,6 +266,22 @@ class EAGLEWorker(TpModelWorker): ...@@ -266,6 +266,22 @@ class EAGLEWorker(TpModelWorker):
self.topk, self.topk,
self.speculative_num_steps, self.speculative_num_steps,
) )
elif self.server_args.attention_backend == "trtllm_mha":
from sglang.srt.layers.attention.trtllm_mha_backend import (
TRTLLMHAAttnBackend,
TRTLLMHAAttnMultiStepDraftBackend,
)
self.draft_attn_backend = TRTLLMHAAttnMultiStepDraftBackend(
self.draft_model_runner,
self.topk,
self.speculative_num_steps,
)
self.draft_extend_attn_backend = TRTLLMHAAttnBackend(
self.draft_model_runner,
skip_prefill=False,
)
self.has_prefill_wrapper_verify = True
elif self.server_args.attention_backend == "trtllm_mla": elif self.server_args.attention_backend == "trtllm_mla":
if not global_server_args_dict["use_mla_backend"]: if not global_server_args_dict["use_mla_backend"]:
raise ValueError( raise ValueError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment