Unverified Commit 64574ef8 authored by pranavm-nvidia's avatar pranavm-nvidia Committed by GitHub
Browse files

Enables speculative decoding for the trtllm_mla attention backend (#9238)

parent 18da2c96
......@@ -11,7 +11,10 @@ from typing import TYPE_CHECKING, Optional, Union
import torch
import triton
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
from sglang.srt.layers.attention.flashinfer_mla_backend import (
FlashInferMLAAttnBackend,
FlashInferMLAMultiStepDraftBackend,
)
from sglang.srt.layers.attention.utils import (
TRITON_PAD_NUM_PAGE_PER_BLOCK,
create_flashmla_kv_indices_triton,
......@@ -96,7 +99,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
# CUDA graph state
self.decode_cuda_graph_metadata = {}
self.cuda_graph_kv_indices = None
self.decode_cuda_graph_kv_indices = None
self.forward_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
def _calc_padded_blocks(self, max_seq_len: int) -> int:
......@@ -167,15 +170,18 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
kv_indices_buf: Optional[torch.Tensor] = None,
):
"""Initialize CUDA graph state for TRTLLM MLA."""
max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
self.cuda_graph_kv_indices = torch.full(
self.decode_cuda_graph_kv_indices = torch.full(
(max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
)
self.cuda_graph_workspace = torch.empty(
self.decode_cuda_graph_workspace = torch.empty(
self.workspace_size, dtype=torch.int8, device=self.device
)
super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
......@@ -187,8 +193,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
spec_info: Optional[SpecInfo],
):
"""Initialize metadata for CUDA graph capture."""
# Delegate to parent for non-decode modes or when speculative execution is used.
if not (forward_mode.is_decode_or_idle() and spec_info is None):
# Delegate to parent for non-decode modes.
if not forward_mode.is_decode_or_idle():
return super().init_forward_metadata_capture_cuda_graph(
bs,
num_tokens,
......@@ -199,9 +206,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
spec_info,
)
# Custom fast-path for decode/idle without speculative execution.
# Custom fast-path for decode/idle.
max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item())
block_kv_indices = self.cuda_graph_kv_indices[:bs, :max_seqlen_pad]
block_kv_indices = self.decode_cuda_graph_kv_indices[:bs, :max_seqlen_pad]
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
......@@ -215,7 +222,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
PAGED_SIZE=self.page_size,
)
metadata = TRTLLMMLADecodeMetadata(self.cuda_graph_workspace, block_kv_indices)
metadata = TRTLLMMLADecodeMetadata(
self.decode_cuda_graph_workspace, block_kv_indices
)
self.decode_cuda_graph_metadata[bs] = metadata
self.forward_metadata = metadata
......@@ -231,8 +240,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
seq_lens_cpu: Optional[torch.Tensor],
):
"""Replay CUDA graph with new inputs."""
# Delegate to parent for non-decode modes or when speculative execution is used.
if not (forward_mode.is_decode_or_idle() and spec_info is None):
# Delegate to parent for non-decode modes.
if not forward_mode.is_decode_or_idle():
return super().init_forward_metadata_replay_cuda_graph(
bs,
req_pool_indices,
......@@ -265,11 +274,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Initialize the metadata for a forward pass."""
# Delegate to parent for non-decode modes or when speculative execution is used.
if not (
forward_batch.forward_mode.is_decode_or_idle()
and forward_batch.spec_info is None
):
# Delegate to parent for non-decode modes.
if not forward_batch.forward_mode.is_decode_or_idle():
return super().init_forward_metadata(forward_batch)
bs = forward_batch.batch_size
......@@ -474,3 +480,20 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)
return output
class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
"""Multi-step draft backend for TRT-LLM MLA used by EAGLE."""
def __init__(
self, model_runner: "ModelRunner", topk: int, speculative_num_steps: int
):
super().__init__(model_runner, topk, speculative_num_steps)
for i in range(self.speculative_num_steps):
self.attn_backends[i] = TRTLLMMLABackend(
model_runner,
skip_prefill=True,
kv_indptr_buf=self.kv_indptr[i],
q_indptr_decode_buf=self.q_indptr_decode,
)
......@@ -479,11 +479,6 @@ class ServerArgs:
)
self.page_size = 64
if self.speculative_algorithm is not None:
raise ValueError(
"trtllm_mla backend does not support speculative decoding yet."
)
if self.kv_cache_dtype not in ["fp8_e4m3", "auto"]:
raise ValueError(
"TensorRT-LLM MLA backend only supports kv-cache-dtype of fp8_e4m3 or auto."
......
......@@ -266,6 +266,27 @@ class EAGLEWorker(TpModelWorker):
self.topk,
self.speculative_num_steps,
)
elif self.server_args.attention_backend == "trtllm_mla":
if not global_server_args_dict["use_mla_backend"]:
raise ValueError(
"trtllm_mla backend requires MLA model (use_mla_backend=True)."
)
from sglang.srt.layers.attention.trtllm_mla_backend import (
TRTLLMMLABackend,
TRTLLMMLAMultiStepDraftBackend,
)
self.draft_attn_backend = TRTLLMMLAMultiStepDraftBackend(
self.draft_model_runner,
self.topk,
self.speculative_num_steps,
)
self.draft_extend_attn_backend = TRTLLMMLABackend(
self.draft_model_runner,
skip_prefill=False,
)
self.has_prefill_wrapper_verify = True
else:
raise ValueError(
f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment