Unverified Commit 5ea96ac7 authored by Liangsheng Yin, committed by GitHub

Reduce one step decode for draft model. (#11561)

parent 56222658
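
Context, as a hedged reading of the diff rather than an authoritative summary: the draft model appears to need only speculative_num_steps - 1 per-step decode attention backends, presumably because the tokens produced by the final draft step are only consumed by the target model's verification pass and never fed into another draft decode. The multi-step draft backends below therefore allocate and initialize one fewer per-step backend, and the new DraftBackendFactory returns a no-op backend when speculative_num_steps == 1. A minimal illustrative sketch of the pattern (class and variable names here are placeholders, not the real SGLang APIs):

class _PerStepAttnBackend:
    """Placeholder for AiterAttnBackend, TritonAttnBackend, etc."""

    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int) -> None:
        pass  # real backends allocate CUDA-graph buffers here


class _MultiStepDraftBackend:
    """Sketch of the change: one per-step backend per chained draft step."""

    def __init__(self, speculative_num_steps: int) -> None:
        self.speculative_num_steps = speculative_num_steps
        # Previously range(speculative_num_steps); the last step's decode
        # output is never reused by the draft model, so its backend is dropped.
        self.attn_backends = [
            _PerStepAttnBackend() for _ in range(speculative_num_steps - 1)
        ]

    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int) -> None:
        # Every per-step loop shrinks by one step as well.
        for backend in self.attn_backends:
            backend.init_cuda_graph_state(max_bs, max_num_tokens)
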
@@ -1064,7 +1064,7 @@ class AiterMultiStepDraftBackend:
             device=model_runner.device,
         )
         self.attn_backends = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 AiterAttnBackend(
                     model_runner,
@@ -1107,7 +1107,7 @@ class AiterMultiStepDraftBackend:
             self.page_size,
         )
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
             forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                 : seq_lens_sum * self.topk + bs * (i + 1)
@@ -1141,7 +1141,7 @@ class AiterMultiStepDraftBackend:
             dtype=torch.int32,
             device=self.device,
         )
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )
...
@@ -2320,7 +2320,7 @@ class FlashAttentionMultiStepBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
         self.attn_backends = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashAttentionBackend(
                     model_runner,
@@ -2335,7 +2335,7 @@ class FlashAttentionMultiStepBackend:
             self.attn_backends[i].init_forward_metadata(forward_batch)

     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)

     def init_forward_metadata_capture_cuda_graph(
...
@@ -1405,7 +1405,7 @@ class FlashInferMultiStepDraftBackend:
             (max_bs,), dtype=torch.int32, device=model_runner.device
         )
         self.attn_backends: List[FlashInferAttnBackend] = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashInferAttnBackend(
                     model_runner,
@@ -1493,7 +1493,7 @@ class FlashInferMultiStepDraftBackend:
             device="cuda",
         )
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )
...
@@ -916,7 +916,7 @@ class FlashInferMLAMultiStepDraftBackend:
         )
         self.attn_backends = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashInferMLAAttnBackend(
                     model_runner,
@@ -998,7 +998,7 @@ class FlashInferMLAMultiStepDraftBackend:
             device="cuda",
         )
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )
...
@@ -478,7 +478,7 @@ class FlashMLAMultiStepDraftBackend:
         )
         self.attn_backends = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashMLABackend(
                     model_runner,
@@ -506,7 +506,7 @@ class FlashMLAMultiStepDraftBackend:
         self.common_template(forward_batch, call_fn)

     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, max_num_tokens, block_kv_indices=None
             )
...
@@ -918,7 +918,7 @@ class TritonMultiStepDraftBackend:
             device=model_runner.device,
         )
         self.attn_backends: List[TritonAttnBackend] = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 TritonAttnBackend(
                     model_runner,
@@ -969,7 +969,7 @@ class TritonMultiStepDraftBackend:
         if call_fn is None:
             return
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
             forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                 : seq_lens_sum * self.topk + bs * (i + 1)
@@ -1009,7 +1009,8 @@ class TritonMultiStepDraftBackend:
             dtype=torch.int32,
             device=self.device,
         )
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs,
                 max_num_tokens,
...
@@ -637,7 +637,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
         self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
     ):
         super().__init__(model_runner, topk, speculative_num_steps)
-        for i in range(speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i] = TRTLLMHAAttnBackend(
                 model_runner,
                 skip_prefill=True,
@@ -651,7 +651,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
             self.attn_backends[i].init_forward_metadata(forward_batch)

     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)

     def init_forward_metadata_capture_cuda_graph(
...
@@ -735,7 +735,7 @@ class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
     ):
         super().__init__(model_runner, topk, speculative_num_steps)
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i] = TRTLLMMLABackend(
                 model_runner,
                 skip_prefill=True,
...
import logging

from sglang.srt.server_args import ServerArgs, get_global_server_args
from sglang.srt.utils.common import is_blackwell

logger = logging.getLogger(__name__)


class DraftBackendFactory:
    def __init__(
        self,
        server_args: ServerArgs,
        draft_model_runner,
        topk: int,
        speculative_num_steps: int,
    ):
        self.server_args = server_args
        self.draft_model_runner = draft_model_runner
        self.topk = topk
        self.speculative_num_steps = speculative_num_steps

    def _create_backend(
        self, backend_name: str, backend_map: dict, error_template: str
    ):
        backend_type = getattr(self.server_args, backend_name)
        if backend_type is None:
            backend_type = self.server_args.attention_backend

        if backend_type not in backend_map:
            raise ValueError(error_template.format(backend_type=backend_type))

        return backend_map[backend_type]()

    def create_decode_backend(self):
        if self.speculative_num_steps == 1:

            class DummyAttnBackend:
                def __init__(self):
                    pass

                def init_forward_metadata(*args, **kwargs):
                    pass

            return DummyAttnBackend()

        backend_map = {
            "flashinfer": self._create_flashinfer_decode_backend,
            "triton": self._create_triton_decode_backend,
            "aiter": self._create_aiter_decode_backend,
            "fa3": self._create_fa3_decode_backend,
            "hybrid_linear_attn": (
                self._create_fa3_decode_backend
                if not is_blackwell()
                else self._create_triton_decode_backend
            ),
            "flashmla": self._create_flashmla_decode_backend,
            "trtllm_mha": self._create_trtllm_mha_decode_backend,
            "trtllm_mla": self._create_trtllm_mla_decode_backend,
        }

        return self._create_backend(
            "decode_attention_backend",
            backend_map,
            "EAGLE is not supported in decode attention backend {backend_type}",
        )

    def create_draft_extend_backend(self):
        backend_map = {
            "flashinfer": self._create_flashinfer_prefill_backend,
            "triton": self._create_triton_prefill_backend,
            "aiter": self._create_aiter_prefill_backend,
            "fa3": self._create_fa3_prefill_backend,
            "hybrid_linear_attn": (
                self._create_fa3_prefill_backend
                if not is_blackwell()
                else self._create_triton_prefill_backend
            ),
            "flashmla": self._create_flashmla_prefill_backend,
            "trtllm_mha": self._create_trtllm_mha_prefill_backend,
            "trtllm_mla": self._create_trtllm_mla_prefill_backend,
        }
        backend_name = (
            "decode_attention_backend"
            if self.server_args.speculative_attention_mode == "decode"
            else "prefill_attention_backend"
        )
        return self._create_backend(
            backend_name,
            backend_map,
            "EAGLE is not supported in attention backend {backend_type}",
        )

    def _create_flashinfer_decode_backend(self):
        if not get_global_server_args().use_mla_backend:
            from sglang.srt.layers.attention.flashinfer_backend import (
                FlashInferMultiStepDraftBackend,
            )

            self.has_prefill_wrapper_verify = True
            return FlashInferMultiStepDraftBackend(
                self.draft_model_runner, self.topk, self.speculative_num_steps
            )
        else:
            from sglang.srt.layers.attention.flashinfer_mla_backend import (
                FlashInferMLAMultiStepDraftBackend,
            )

            self.has_prefill_wrapper_verify = True
            return FlashInferMLAMultiStepDraftBackend(
                self.draft_model_runner, self.topk, self.speculative_num_steps
            )

    def _create_triton_decode_backend(self):
        from sglang.srt.layers.attention.triton_backend import (
            TritonMultiStepDraftBackend,
        )

        return TritonMultiStepDraftBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_aiter_decode_backend(self):
        from sglang.srt.layers.attention.aiter_backend import AiterMultiStepDraftBackend

        return AiterMultiStepDraftBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_fa3_decode_backend(self):
        from sglang.srt.layers.attention.flashattention_backend import (
            FlashAttentionMultiStepBackend,
        )

        return FlashAttentionMultiStepBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_flashmla_decode_backend(self):
        from sglang.srt.layers.attention.flashmla_backend import (
            FlashMLAMultiStepDraftBackend,
        )

        return FlashMLAMultiStepDraftBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_trtllm_mha_decode_backend(self):
        from sglang.srt.layers.attention.trtllm_mha_backend import (
            TRTLLMHAAttnMultiStepDraftBackend,
        )

        self.has_prefill_wrapper_verify = True
        return TRTLLMHAAttnMultiStepDraftBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_trtllm_mla_decode_backend(self):
        if not get_global_server_args().use_mla_backend:
            raise ValueError(
                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
            )

        from sglang.srt.layers.attention.trtllm_mla_backend import (
            TRTLLMMLAMultiStepDraftBackend,
        )

        self.has_prefill_wrapper_verify = True
        return TRTLLMMLAMultiStepDraftBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_flashinfer_prefill_backend(self):
        if not get_global_server_args().use_mla_backend:
            from sglang.srt.layers.attention.flashinfer_backend import (
                FlashInferAttnBackend,
            )

            return FlashInferAttnBackend(self.draft_model_runner, skip_prefill=False)
        else:
            from sglang.srt.layers.attention.flashinfer_mla_backend import (
                FlashInferMLAAttnBackend,
            )

            return FlashInferMLAAttnBackend(self.draft_model_runner, skip_prefill=False)

    def _create_triton_prefill_backend(self):
        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

        return TritonAttnBackend(self.draft_model_runner, skip_prefill=False)

    def _create_aiter_prefill_backend(self):
        from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend

        return AiterAttnBackend(self.draft_model_runner, skip_prefill=False)

    def _create_fa3_prefill_backend(self):
        from sglang.srt.layers.attention.flashattention_backend import (
            FlashAttentionBackend,
        )

        return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)

    def _create_trtllm_mha_prefill_backend(self):
        from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend

        return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False)

    def _create_trtllm_mla_prefill_backend(self):
        if not get_global_server_args().use_mla_backend:
            raise ValueError(
                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
            )

        from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend

        return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)

    def _create_flashmla_prefill_backend(self):
        logger.warning(
            "flashmla prefill backend is not yet supported for draft extend."
        )
        return None
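
The new DraftBackendFactory above (apparently the sglang.srt.speculative.draft_utils module imported in the EAGLEWorker diff below) centralizes the backend-selection logic that the worker previously carried inline. A hedged usage sketch mirroring that call site; the surrounding variable names here are illustrative assumptions, not part of the factory's API:

# Hypothetical wiring, mirroring the EAGLEWorker change below; the concrete
# server_args, draft_model_runner, topk, and speculative_num_steps objects are
# assumed to come from the worker's existing setup.
factory = DraftBackendFactory(
    server_args,
    draft_model_runner,
    topk,
    speculative_num_steps,
)

draft_attn_backend = factory.create_decode_backend()              # multi-step draft decode
draft_extend_attn_backend = factory.create_draft_extend_backend() # draft extend
draft_model_runner.draft_attn_backend = draft_attn_backend
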
@@ -27,7 +27,8 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.server_args import ServerArgs, get_global_server_args
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.speculative.draft_utils import DraftBackendFactory
 from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
     EAGLEDraftCudaGraphRunner,
 )
@@ -195,204 +196,22 @@ class EAGLEWorker(TpModelWorker):
         self.has_prefill_wrapper_verify = False
         self.draft_extend_attn_backend = None

-        # Initialize decode attention backend
-        self.draft_attn_backend = self._create_decode_backend()
-
-        # Initialize draft extend attention backend (respects speculative_attention_mode setting)
-        self.draft_extend_attn_backend = self._create_draft_extend_backend()
-
-        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
-
-    def _create_backend(
-        self, backend_name: str, backend_map: dict, error_template: str
-    ):
-        backend_type = getattr(self.server_args, backend_name)
-        if backend_type is None:
-            backend_type = self.server_args.attention_backend
-
-        if backend_type not in backend_map:
-            raise ValueError(error_template.format(backend_type=backend_type))
-
-        return backend_map[backend_type]()
-
-    def _create_decode_backend(self):
-        backend_map = {
-            "flashinfer": self._create_flashinfer_decode_backend,
-            "triton": self._create_triton_decode_backend,
-            "aiter": self._create_aiter_decode_backend,
-            "fa3": self._create_fa3_decode_backend,
-            "hybrid_linear_attn": (
-                self._create_fa3_decode_backend
-                if not is_blackwell()
-                else self._create_triton_decode_backend
-            ),
-            "flashmla": self._create_flashmla_decode_backend,
-            "trtllm_mha": self._create_trtllm_mha_decode_backend,
-            "trtllm_mla": self._create_trtllm_mla_decode_backend,
-        }
-
-        return self._create_backend(
-            "decode_attention_backend",
-            backend_map,
-            "EAGLE is not supported in decode attention backend {backend_type}",
-        )
-
-    def _create_draft_extend_backend(self):
-        backend_map = {
-            "flashinfer": self._create_flashinfer_prefill_backend,
-            "triton": self._create_triton_prefill_backend,
-            "aiter": self._create_aiter_prefill_backend,
-            "fa3": self._create_fa3_prefill_backend,
-            "hybrid_linear_attn": (
-                self._create_fa3_prefill_backend
-                if not is_blackwell()
-                else self._create_triton_prefill_backend
-            ),
-            "flashmla": self._create_flashmla_prefill_backend,
-            "trtllm_mha": self._create_trtllm_mha_prefill_backend,
-            "trtllm_mla": self._create_trtllm_mla_prefill_backend,
-        }
-        backend_name = (
-            "decode_attention_backend"
-            if self.server_args.speculative_attention_mode == "decode"
-            else "prefill_attention_backend"
-        )
-        return self._create_backend(
-            backend_name,
-            backend_map,
-            "EAGLE is not supported in attention backend {backend_type}",
-        )
-
-    def _create_flashinfer_decode_backend(self):
-        if not get_global_server_args().use_mla_backend:
-            from sglang.srt.layers.attention.flashinfer_backend import (
-                FlashInferMultiStepDraftBackend,
-            )
-
-            self.has_prefill_wrapper_verify = True
-            return FlashInferMultiStepDraftBackend(
-                self.draft_model_runner, self.topk, self.speculative_num_steps
-            )
-        else:
-            from sglang.srt.layers.attention.flashinfer_mla_backend import (
-                FlashInferMLAMultiStepDraftBackend,
-            )
-
-            self.has_prefill_wrapper_verify = True
-            return FlashInferMLAMultiStepDraftBackend(
-                self.draft_model_runner, self.topk, self.speculative_num_steps
-            )
-
-    def _create_triton_decode_backend(self):
-        from sglang.srt.layers.attention.triton_backend import (
-            TritonMultiStepDraftBackend,
-        )
-
-        return TritonMultiStepDraftBackend(
-            self.draft_model_runner, self.topk, self.speculative_num_steps
-        )
-
-    def _create_aiter_decode_backend(self):
-        from sglang.srt.layers.attention.aiter_backend import AiterMultiStepDraftBackend
-
-        return AiterMultiStepDraftBackend(
-            self.draft_model_runner, self.topk, self.speculative_num_steps
-        )
-
-    def _create_fa3_decode_backend(self):
-        from sglang.srt.layers.attention.flashattention_backend import (
-            FlashAttentionMultiStepBackend,
-        )
-
-        return FlashAttentionMultiStepBackend(
-            self.draft_model_runner, self.topk, self.speculative_num_steps
-        )
-
-    def _create_flashmla_decode_backend(self):
-        from sglang.srt.layers.attention.flashmla_backend import (
-            FlashMLAMultiStepDraftBackend,
-        )
-
-        return FlashMLAMultiStepDraftBackend(
-            self.draft_model_runner, self.topk, self.speculative_num_steps
-        )
-
-    def _create_trtllm_mha_decode_backend(self):
-        from sglang.srt.layers.attention.trtllm_mha_backend import (
-            TRTLLMHAAttnMultiStepDraftBackend,
-        )
-
-        self.has_prefill_wrapper_verify = True
-        return TRTLLMHAAttnMultiStepDraftBackend(
-            self.draft_model_runner, self.topk, self.speculative_num_steps
-        )
-
-    def _create_trtllm_mla_decode_backend(self):
-        if not get_global_server_args().use_mla_backend:
-            raise ValueError(
-                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
-            )
-
-        from sglang.srt.layers.attention.trtllm_mla_backend import (
-            TRTLLMMLAMultiStepDraftBackend,
-        )
-
-        self.has_prefill_wrapper_verify = True
-        return TRTLLMMLAMultiStepDraftBackend(
-            self.draft_model_runner, self.topk, self.speculative_num_steps
-        )
-
-    def _create_flashinfer_prefill_backend(self):
-        if not get_global_server_args().use_mla_backend:
-            from sglang.srt.layers.attention.flashinfer_backend import (
-                FlashInferAttnBackend,
-            )
-
-            return FlashInferAttnBackend(self.draft_model_runner, skip_prefill=False)
-        else:
-            from sglang.srt.layers.attention.flashinfer_mla_backend import (
-                FlashInferMLAAttnBackend,
-            )
-
-            return FlashInferMLAAttnBackend(self.draft_model_runner, skip_prefill=False)
-
-    def _create_triton_prefill_backend(self):
-        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
-
-        return TritonAttnBackend(self.draft_model_runner, skip_prefill=False)
-
-    def _create_aiter_prefill_backend(self):
-        from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
-
-        return AiterAttnBackend(self.draft_model_runner, skip_prefill=False)
-
-    def _create_fa3_prefill_backend(self):
-        from sglang.srt.layers.attention.flashattention_backend import (
-            FlashAttentionBackend,
-        )
-
-        return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)
-
-    def _create_trtllm_mha_prefill_backend(self):
-        from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend
-
-        return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False)
-
-    def _create_trtllm_mla_prefill_backend(self):
-        if not get_global_server_args().use_mla_backend:
-            raise ValueError(
-                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
-            )
-
-        from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
-
-        return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)
-
-    def _create_flashmla_prefill_backend(self):
-        logger.warning(
-            "flashmla prefill backend is not yet supported for draft extend."
-        )
-        return None
+        draft_backend_factory = DraftBackendFactory(
+            self.server_args,
+            self.draft_model_runner,
+            self.topk,
+            self.speculative_num_steps,
+        )
+
+        # Initialize decode attention backend
+        self.draft_attn_backend = draft_backend_factory.create_decode_backend()
+
+        # Initialize draft extend attention backend (respects speculative_attention_mode setting)
+        self.draft_extend_attn_backend = (
+            draft_backend_factory.create_draft_extend_backend()
+        )
+
+        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend

     def init_cuda_graphs(self):
         """Capture cuda graphs."""
...