"tests/vscode:/vscode.git/clone" did not exist on "6620eda357132bcd034c8b5c239fa4527e150c35"
Unverified Commit 64574ef8 authored by pranavm-nvidia, committed by GitHub

Enables speculative decoding for the trtllm_mla attention backend (#9238)

parent 18da2c96
@@ -11,7 +11,10 @@ from typing import TYPE_CHECKING, Optional, Union
 import torch
 import triton
 
-from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
+from sglang.srt.layers.attention.flashinfer_mla_backend import (
+    FlashInferMLAAttnBackend,
+    FlashInferMLAMultiStepDraftBackend,
+)
 from sglang.srt.layers.attention.utils import (
     TRITON_PAD_NUM_PAGE_PER_BLOCK,
     create_flashmla_kv_indices_triton,
@@ -96,7 +99,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
-        self.cuda_graph_kv_indices = None
+        self.decode_cuda_graph_kv_indices = None
         self.forward_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
 
     def _calc_padded_blocks(self, max_seq_len: int) -> int:
@@ -167,15 +170,18 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         """Initialize CUDA graph state for TRTLLM MLA."""
         max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
-        self.cuda_graph_kv_indices = torch.full(
+        self.decode_cuda_graph_kv_indices = torch.full(
             (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
         )
-        self.cuda_graph_workspace = torch.empty(
+        self.decode_cuda_graph_workspace = torch.empty(
             self.workspace_size, dtype=torch.int8, device=self.device
         )
 
+        super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)
+
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
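Note on the hunk above: the capture-time buffers get a `decode_` prefix and, after they are allocated, `super().init_cuda_graph_state(...)` is now called so the parent FlashInfer MLA backend can allocate its own buffers for the modes that still delegate to it. A minimal sketch of that pattern, using hypothetical `ParentBackend`/`ChildBackend` classes rather than the real sglang types:

```python
import torch


class ParentBackend:
    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int) -> None:
        # Parent-owned buffers, used by whatever paths delegate back to it.
        self.parent_kv_indptr = torch.zeros(max_bs + 1, dtype=torch.int32)


class ChildBackend(ParentBackend):
    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int) -> None:
        # Child-owned buffers for its own decode fast path (hence the
        # `decode_` prefix, so they cannot shadow the parent's attributes).
        self.decode_cuda_graph_kv_indices = torch.full(
            (max_bs, 8), -1, dtype=torch.int32
        )
        # Still initialize the parent so delegated modes can capture too.
        super().init_cuda_graph_state(max_bs, max_num_tokens)


backend = ChildBackend()
backend.init_cuda_graph_state(max_bs=4, max_num_tokens=16)
print(backend.decode_cuda_graph_kv_indices.shape, backend.parent_kv_indptr.shape)
```

Keeping the child's buffers under distinct `decode_*` names is what lets both sets of CUDA-graph state coexist on the same object.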
@@ -187,8 +193,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         spec_info: Optional[SpecInfo],
     ):
         """Initialize metadata for CUDA graph capture."""
-        # Delegate to parent for non-decode modes or when speculative execution is used.
-        if not (forward_mode.is_decode_or_idle() and spec_info is None):
+        # Delegate to parent for non-decode modes.
+        if not forward_mode.is_decode_or_idle():
             return super().init_forward_metadata_capture_cuda_graph(
                 bs,
                 num_tokens,
@@ -199,9 +206,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 spec_info,
             )
 
-        # Custom fast-path for decode/idle without speculative execution.
+        # Custom fast-path for decode/idle.
         max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item())
-        block_kv_indices = self.cuda_graph_kv_indices[:bs, :max_seqlen_pad]
+        block_kv_indices = self.decode_cuda_graph_kv_indices[:bs, :max_seqlen_pad]
 
         create_flashmla_kv_indices_triton[(bs,)](
             self.req_to_token,
@@ -215,7 +222,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             PAGED_SIZE=self.page_size,
         )
 
-        metadata = TRTLLMMLADecodeMetadata(self.cuda_graph_workspace, block_kv_indices)
+        metadata = TRTLLMMLADecodeMetadata(
+            self.decode_cuda_graph_workspace, block_kv_indices
+        )
         self.decode_cuda_graph_metadata[bs] = metadata
         self.forward_metadata = metadata
@@ -231,8 +240,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         """Replay CUDA graph with new inputs."""
-        # Delegate to parent for non-decode modes or when speculative execution is used.
-        if not (forward_mode.is_decode_or_idle() and spec_info is None):
+        # Delegate to parent for non-decode modes.
+        if not forward_mode.is_decode_or_idle():
             return super().init_forward_metadata_replay_cuda_graph(
                 bs,
                 req_pool_indices,
@@ -265,11 +274,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize the metadata for a forward pass."""
-        # Delegate to parent for non-decode modes or when speculative execution is used.
-        if not (
-            forward_batch.forward_mode.is_decode_or_idle()
-            and forward_batch.spec_info is None
-        ):
+        # Delegate to parent for non-decode modes.
+        if not forward_batch.forward_mode.is_decode_or_idle():
             return super().init_forward_metadata(forward_batch)
 
         bs = forward_batch.batch_size
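Across the capture, replay, and eager metadata paths above, the delegation predicate drops its `spec_info is None` clause: decode/idle batches now take the TRTLLM MLA fast path even when they carry speculative info, and only non-decode modes fall back to the FlashInfer MLA parent. A throwaway comparison of the old and new predicates (the `delegates_*` helpers are illustrative only, not part of the codebase):

```python
def delegates_old(is_decode_or_idle: bool, spec_info) -> bool:
    # Old rule: delegate unless this is decode/idle *without* speculative info.
    return not (is_decode_or_idle and spec_info is None)


def delegates_new(is_decode_or_idle: bool, spec_info) -> bool:
    # New rule: delegate only for non-decode modes; spec_info no longer matters.
    return not is_decode_or_idle


for mode, spec in [(True, None), (True, object()), (False, None), (False, object())]:
    print(
        f"decode={mode} spec={spec is not None} "
        f"old_delegates={delegates_old(mode, spec)} new_delegates={delegates_new(mode, spec)}"
    )
```

The only row that changes is decode with speculative info present, which previously delegated to the parent and now stays on the fast path.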
@@ -474,3 +480,20 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
         return output
+
+
+class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
+    """Multi-step draft backend for TRT-LLM MLA used by EAGLE."""
+
+    def __init__(
+        self, model_runner: "ModelRunner", topk: int, speculative_num_steps: int
+    ):
+        super().__init__(model_runner, topk, speculative_num_steps)
+
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i] = TRTLLMMLABackend(
+                model_runner,
+                skip_prefill=True,
+                kv_indptr_buf=self.kv_indptr[i],
+                q_indptr_decode_buf=self.q_indptr_decode,
+            )
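The new `TRTLLMMLAMultiStepDraftBackend` reuses the FlashInfer MLA multi-step machinery and only swaps in a `TRTLLMMLABackend` per draft step, each bound to its own row of the shared `kv_indptr` buffer. A rough, self-contained sketch of that per-step layout, with hypothetical `StepBackend`/`MultiStepDraftBackend` classes standing in for the real parent implementation:

```python
from typing import List

import torch


class StepBackend:
    """Stand-in for a single-step attention backend (e.g. TRTLLMMLABackend)."""

    def __init__(self, kv_indptr_buf: torch.Tensor):
        self.kv_indptr_buf = kv_indptr_buf  # per-step view of a shared buffer

    def prepare(self, step: int) -> None:
        self.kv_indptr_buf.fill_(step)  # placeholder for real metadata setup


class MultiStepDraftBackend:
    """One single-step backend instance per speculative draft step."""

    def __init__(self, speculative_num_steps: int, max_bs: int):
        self.kv_indptr = torch.zeros(
            (speculative_num_steps, max_bs + 1), dtype=torch.int32
        )
        self.attn_backends: List[StepBackend] = [
            StepBackend(self.kv_indptr[i]) for i in range(speculative_num_steps)
        ]

    def run_draft_steps(self) -> None:
        # Each draft decode step uses its own backend (and, in sglang, its own
        # CUDA graph metadata), so steps never clobber each other's buffers.
        for step, backend in enumerate(self.attn_backends):
            backend.prepare(step)


draft = MultiStepDraftBackend(speculative_num_steps=3, max_bs=4)
draft.run_draft_steps()
print(draft.kv_indptr)
```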
@@ -479,11 +479,6 @@ class ServerArgs:
             )
             self.page_size = 64
 
-            if self.speculative_algorithm is not None:
-                raise ValueError(
-                    "trtllm_mla backend does not support speculative decoding yet."
-                )
-
             if self.kv_cache_dtype not in ["fp8_e4m3", "auto"]:
                 raise ValueError(
                     "TensorRT-LLM MLA backend only supports kv-cache-dtype of fp8_e4m3 or auto."
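With the ServerArgs guard above removed, combining the trtllm_mla attention backend with speculative decoding is no longer rejected at startup. A hedged launch example: the flag names follow sglang's usual server arguments but should be checked against `python -m sglang.launch_server --help` for your version, and the model and draft paths are placeholders:

```python
import subprocess

cmd = [
    "python", "-m", "sglang.launch_server",
    "--model-path", "deepseek-ai/DeepSeek-V3",  # an MLA model is required
    "--attention-backend", "trtllm_mla",
    "--kv-cache-dtype", "fp8_e4m3",
    "--speculative-algorithm", "EAGLE",
    "--speculative-draft-model-path", "<path-to-eagle-draft-model>",  # placeholder
    "--speculative-num-steps", "3",
]
subprocess.run(cmd, check=True)
```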
@@ -266,6 +266,27 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.speculative_num_steps,
             )
+        elif self.server_args.attention_backend == "trtllm_mla":
+            if not global_server_args_dict["use_mla_backend"]:
+                raise ValueError(
+                    "trtllm_mla backend requires MLA model (use_mla_backend=True)."
+                )
+
+            from sglang.srt.layers.attention.trtllm_mla_backend import (
+                TRTLLMMLABackend,
+                TRTLLMMLAMultiStepDraftBackend,
+            )
+
+            self.draft_attn_backend = TRTLLMMLAMultiStepDraftBackend(
+                self.draft_model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
+            self.draft_extend_attn_backend = TRTLLMMLABackend(
+                self.draft_model_runner,
+                skip_prefill=False,
+            )
+            self.has_prefill_wrapper_verify = True
         else:
             raise ValueError(
                 f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
             )