Unverified commit 9ec314c6, authored by Qiaolin Yu and committed by GitHub

Support speculative decoding in the trtllm_mha attention backend (#9331)


Co-authored-by: ispobock <ispobaoke@gmail.com>
parent fedfe91c
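
All of the new speculative paths in this change (draft decode, target verify, draft extend) build the same kind of batch metadata for the paged MHA kernel: per-request KV lengths extended by the draft tokens, fixed-stride cumulative query offsets, and cumulative key offsets. As a minimal sketch of that arithmetic, assuming toy values and plain torch rather than the sglang classes (the function name and arguments below are illustrative only, not part of the sglang API):

import torch

def build_verify_metadata(seq_lens: torch.Tensor, num_draft_tokens: int):
    """Toy sketch of the cache_seqlens / cu_seqlens math used for target verify.

    seq_lens: per-request KV lengths already in the cache, shape [bs].
    Each request additionally scores num_draft_tokens draft tokens.
    """
    bs = seq_lens.numel()
    # Keys seen by the kernel = cached tokens + the draft tokens being verified.
    cache_seqlens = (seq_lens + num_draft_tokens).to(torch.int32)
    # Every request contributes exactly num_draft_tokens query tokens,
    # so query offsets are a fixed-stride arange.
    cu_seqlens_q = torch.arange(
        0, bs * num_draft_tokens + 1, num_draft_tokens, dtype=torch.int32
    )
    # Key offsets are a cumulative sum, padded with a leading zero.
    cu_seqlens_k = torch.nn.functional.pad(
        torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32), (1, 0)
    )
    return cache_seqlens, cu_seqlens_q, cu_seqlens_k

# Example: two requests with 5 and 9 cached tokens, 4 draft tokens each.
print(build_verify_metadata(torch.tensor([5, 9]), num_draft_tokens=4))
# cache_seqlens -> [9, 13], cu_seqlens_q -> [0, 4, 8], cu_seqlens_k -> [0, 9, 22]
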
@@ -10,13 +10,18 @@ from typing import TYPE_CHECKING, Optional

import torch

from sglang.srt.layers.attention.flashinfer_backend import (
    FlashInferAttnBackend,
    FlashInferMultiStepDraftBackend,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_flashinfer_available

if is_flashinfer_available():
    import flashinfer

from sglang.srt.speculative.eagle_utils import EagleDraftInput

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
@@ -55,9 +60,12 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
        kv_last_page_len_buf: Optional[torch.Tensor] = None,
        speculative_step_id: int = 0,
    ):
        super().__init__(
            model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf
        )

        config = model_runner.model_config
@@ -87,6 +95,16 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
        # CUDA graph state
        self.decode_cuda_graph_metadata = {}

        # Speculative decoding
        # Only support topk <= 1 for now.
        self.topk = model_runner.server_args.speculative_eagle_topk or 0
        self.speculative_step_id = speculative_step_id
        self.target_verify_metadata = {}
        self.speculative_num_draft_tokens = (
            model_runner.server_args.speculative_num_draft_tokens
        )

        # Forward metadata
        self.forward_metadata: Optional[TRTLLMMHAMetadata] = None
@@ -97,11 +115,76 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
        kv_indices_buf: Optional[torch.Tensor] = None,
    ):
        """Initialize CUDA graph state for TRTLLM MHA."""
        max_num_pages = (self.max_context_len + self.page_size - 1) // self.page_size
        self.decode_cuda_graph_metadata = {
            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
            "page_table": torch.zeros(
                max_bs,
                max_num_pages,
                dtype=torch.int32,
                device=self.device,
            ),
            "strided_indices": torch.arange(
                0, self.max_context_len, self.page_size, device=self.device
            ),
        }

        if (
            self.speculative_num_draft_tokens is not None
            and self.speculative_num_draft_tokens > 0
        ):
            self.decode_cuda_graph_metadata["cu_seqlens_q"] = torch.arange(
                0, max_bs + 1, dtype=torch.int32, device=self.device
            )
            self.decode_cuda_graph_metadata["cu_seqlens_k"] = torch.zeros(
                max_bs + 1, dtype=torch.int32, device=self.device
            )
            self.decode_cuda_graph_metadata["page_table_draft_decode"] = torch.zeros(
                max_bs,
                max_num_pages,
                dtype=torch.int32,
                device=self.device,
            )
            self.target_verify_metadata = {
                "cache_seqlens": torch.zeros(
                    max_bs, dtype=torch.int32, device=self.device
                ),
                "cu_seqlens_q": torch.arange(
                    0,
                    max_bs * self.speculative_num_draft_tokens + 1,
                    step=self.speculative_num_draft_tokens,
                    dtype=torch.int32,
                    device=self.device,
                ),
                "cu_seqlens_k": torch.zeros(
                    max_bs + 1, dtype=torch.int32, device=self.device
                ),
                "page_table": torch.zeros(
                    max_bs,
                    max_num_pages,
                    dtype=torch.int32,
                    device=self.device,
                ),
                "strided_indices": torch.arange(
                    0, self.max_context_len, self.page_size, device=self.device
                ),
            }
            self.draft_extend_metadata = {
                "cache_seqlens": torch.zeros(
                    max_bs, dtype=torch.int32, device=self.device
                ),
                "cu_seqlens_q": torch.zeros(
                    max_bs + 1,
                    dtype=torch.int32,
                    device=self.device,
                ),
                "cu_seqlens_k": torch.zeros(
                    max_bs + 1, dtype=torch.int32, device=self.device
                ),
                "page_table": torch.zeros(
                    max_bs,
                    max_num_pages,
                    dtype=torch.int32,
                    device=self.device,
                ),
@@ -122,16 +205,105 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
    ):
        """Initialize metadata for CUDA graph capture."""
        metadata = TRTLLMMHAMetadata()
        device = seq_lens.device

        if forward_mode.is_decode_or_idle():
            if spec_info is not None:
                # Draft Decode
                # Here we only support topk = 1 for now.
                metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
                    "cache_seqlens"
                ][:bs]
                metadata.max_seq_len_k = seq_lens.max().item() + (
                    self.speculative_step_id + 1
                )
                metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
                    : bs + 1
                ]
                metadata.cu_seqlens_k = torch.nn.functional.pad(
                    torch.cumsum(
                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                    ),
                    (1, 0),
                )
                metadata.page_table = self.decode_cuda_graph_metadata[
                    "page_table_draft_decode"
                ][:bs, :]
                self.decode_cuda_graph_metadata[bs] = metadata
            else:
                # Normal Decode
                # Get sequence information
                metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)
                batch_size = len(seq_lens)
                metadata.cu_seqlens_k = torch.nn.functional.pad(
                    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
                )

                # Precompute maximum sequence length
                metadata.max_seq_len_k = seq_lens.max().item()

                # Precompute cumulative sequence lengths
                metadata.cu_seqlens_q = torch.arange(
                    0, batch_size + 1, dtype=torch.int32, device=device
                )

                # Precompute page table
                metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
                    :bs, :
                ]
                self.decode_cuda_graph_metadata[bs] = metadata
        elif forward_mode.is_target_verify():
            # Target Verify
            # Here we only support topk = 1 for now.
            metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
                :bs
            ]
            metadata.cache_seqlens_int32.copy_(
                seq_lens + self.speculative_num_draft_tokens
            )

            metadata.cu_seqlens_q = torch.arange(
                0,
                bs * self.speculative_num_draft_tokens + 1,
                self.speculative_num_draft_tokens,
                dtype=torch.int32,
                device=device,
            )
            metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
                : (bs + 1)
            ]

            metadata.max_seq_len_q = self.speculative_num_draft_tokens
            metadata.max_seq_len_k = (
                seq_lens.max().item() + self.speculative_num_draft_tokens
            )

            metadata.page_table = self.target_verify_metadata["page_table"][:bs, :]
            self.target_verify_metadata[bs] = metadata
        elif forward_mode.is_draft_extend():
            metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
                :bs
            ]
            metadata.cache_seqlens_int32.copy_(seq_lens)

            num_tokens_per_bs = num_tokens // bs
            metadata.cu_seqlens_q = torch.arange(
                0,
                bs * num_tokens_per_bs + 1,
                num_tokens_per_bs,
                dtype=torch.int32,
                device=device,
            )
            metadata.cu_seqlens_k = self.draft_extend_metadata["cu_seqlens_k"][
                : (bs + 1)
            ]

            metadata.max_seq_len_q = num_tokens_per_bs
            metadata.max_seq_len_k = seq_lens.max().item()

            metadata.page_table = self.draft_extend_metadata["page_table"][:bs, :]
            self.draft_extend_metadata[bs] = metadata

        self.forward_metadata = metadata

    def init_forward_metadata_replay_cuda_graph(
@@ -149,9 +321,23 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
        seq_lens = seq_lens[:bs]
        seq_lens_cpu = seq_lens_cpu[:bs]
        req_pool_indices = req_pool_indices[:bs]
        device = seq_lens.device
        metadata = None

        if forward_mode.is_decode_or_idle():
            if spec_info is not None:
                # Draft Decode
                # Here we only support topk = 1 for now.
                metadata = self.decode_cuda_graph_metadata[bs]
                max_len = seq_lens_cpu.max().item()
                metadata.max_seq_len_k = max_len + self.speculative_step_id + 1
                max_seq_pages = (
                    metadata.max_seq_len_k + self.page_size - 1
                ) // self.page_size
                metadata.cache_seqlens_int32.copy_(
                    seq_lens + self.speculative_step_id + 1
                )
            else:
                # Normal Decode
                metadata = self.decode_cuda_graph_metadata[bs]
                max_len = seq_lens_cpu.max().item()
@@ -159,9 +345,65 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
                metadata.max_seq_len_k = max_len
                metadata.cache_seqlens_int32.copy_(seq_lens)

            metadata.cu_seqlens_k[1:].copy_(
                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
            )

            page_indices = self.req_to_token[
                req_pool_indices[:, None],
                self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][
                    None, :
                ],
            ]
            metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
        elif forward_mode.is_target_verify():
            # Here we only support topk = 1 for now.
            metadata = self.target_verify_metadata[bs]
            metadata.cache_seqlens_int32.copy_(
                seq_lens + self.speculative_num_draft_tokens
            )
            metadata.max_seq_len_k = (
                seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
            )
            metadata.cu_seqlens_k[1:].copy_(
                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
            )

            max_seq_pages = (
                metadata.max_seq_len_k + self.page_size - 1
            ) // self.page_size
            page_indices = self.req_to_token[
                req_pool_indices[:, None],
                self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
            ]
            page_indices //= self.page_size
            metadata.page_table[:, :max_seq_pages].copy_(page_indices)
        elif forward_mode.is_draft_extend():
            metadata = self.draft_extend_metadata[bs]
            metadata.cache_seqlens_int32.copy_(seq_lens)
            metadata.max_seq_len_k = seq_lens_cpu.max().item()
            metadata.cu_seqlens_k[1:].copy_(
                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
            )

            accept_length = spec_info.accept_length[:bs]
            if spec_info.accept_length_cpu:
                metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
            else:
                metadata.max_seq_len_q = 1
            metadata.cu_seqlens_q[1:].copy_(
                torch.cumsum(accept_length, dim=0, dtype=torch.int32)
            )

            max_seq_pages = (
                metadata.max_seq_len_k + self.page_size - 1
            ) // self.page_size
            page_indices = self.req_to_token[
                req_pool_indices[:, None],
                self.draft_extend_metadata["strided_indices"][:max_seq_pages],
            ]
            metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)

        self.forward_metadata = metadata
@@ -179,12 +421,65 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
        device = seqlens_in_batch.device
        if forward_batch.forward_mode.is_decode_or_idle():
            if forward_batch.spec_info is not None:
                # Draft Decode
                # Here we only support topk = 1 for now.
                metadata.cache_seqlens_int32 = (
                    seqlens_in_batch + (self.speculative_step_id + 1)
                ).to(torch.int32)
                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
                    self.speculative_step_id + 1
                )
                metadata.cu_seqlens_q = torch.arange(
                    0, batch_size + 1, dtype=torch.int32, device=device
                )
                metadata.cu_seqlens_k = torch.nn.functional.pad(
                    torch.cumsum(
                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                    ),
                    (1, 0),
                )
                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
                ]
            else:
                # Normal Decode
                metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
                metadata.cu_seqlens_q = torch.arange(
                    0, batch_size + 1, dtype=torch.int32, device=device
                )
                metadata.cu_seqlens_k = torch.nn.functional.pad(
                    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
                )
                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
                ]
        elif forward_batch.forward_mode.is_target_verify():
            # Only support topk = 1 for now.
            metadata.cache_seqlens_int32 = (
                forward_batch.seq_lens + self.speculative_num_draft_tokens
            ).to(torch.int32)
            metadata.max_seq_len_q = self.speculative_num_draft_tokens
            metadata.max_seq_len_k = (
                forward_batch.seq_lens_cpu.max().item()
                + self.speculative_num_draft_tokens
            )
            metadata.cu_seqlens_q = torch.arange(
                0,
                batch_size * self.speculative_num_draft_tokens + 1,
                self.speculative_num_draft_tokens,
                dtype=torch.int32,
                device=device,
            )
            metadata.cu_seqlens_k = torch.nn.functional.pad(
                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
                (1, 0),
            )
            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                forward_batch.req_pool_indices, : metadata.max_seq_len_k
            ]
        else:
            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
@@ -195,7 +490,10 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
                forward_batch.req_pool_indices, : metadata.max_seq_len_k
            ]
            if (
                any(forward_batch.extend_prefix_lens_cpu)
                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
            ):
                extend_seq_lens = forward_batch.extend_seq_lens
                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
                metadata.cu_seqlens_q = torch.nn.functional.pad(
@@ -332,3 +630,65 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
        )
        return o.view(-1, layer.tp_q_head_num * layer.head_dim)

class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
    """Multi-step TRTLLM MHA attention kernel used by EAGLE."""

    def __init__(
        self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
    ):
        super().__init__(model_runner, topk, speculative_num_steps)

        for i in range(speculative_num_steps):
            self.attn_backends[i] = TRTLLMHAAttnBackend(
                model_runner,
                skip_prefill=True,
                kv_indptr_buf=self.kv_indptr[i],
                kv_last_page_len_buf=self.kv_last_page_len,
                speculative_step_id=i,
            )

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        for i in range(self.speculative_num_steps - 1):
            self.attn_backends[i].init_forward_metadata(forward_batch)

    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        for i in range(self.speculative_num_steps):
            self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)

    def init_forward_metadata_capture_cuda_graph(
        self,
        forward_batch: ForwardBatch,
    ):
        assert forward_batch.spec_info is not None
        assert isinstance(forward_batch.spec_info, EagleDraftInput)

        for i in range(self.speculative_num_steps - 1):
            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                forward_batch.batch_size,
                forward_batch.batch_size * self.topk,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                encoder_lens=forward_batch.encoder_lens,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
            )

    def init_forward_metadata_replay_cuda_graph(
        self, forward_batch: ForwardBatch, bs: int
    ):
        assert forward_batch.spec_info is not None
        assert isinstance(forward_batch.spec_info, EagleDraftInput)

        for i in range(self.speculative_num_steps - 1):
            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                bs,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                encoder_lens=forward_batch.encoder_lens,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
                seq_lens_cpu=forward_batch.seq_lens_cpu,
            )
@@ -500,11 +500,6 @@ class ServerArgs:
             )
             self.page_size = 64
 
-            if self.speculative_algorithm is not None:
-                raise ValueError(
-                    "trtllm_mha backend does not support speculative decoding yet."
-                )
-
         if self.attention_backend == "dual_chunk_flash_attn":
             logger.warning(
                 "Mixed chunk, radix cache, and cuda graphs are disabled because of using dual chunk flash attention backend"
@@ -653,6 +648,16 @@ class ServerArgs:
                self.speculative_num_draft_tokens,
            ) = auto_choose_speculative_params(self)

            if (
                self.attention_backend == "trtllm_mha"
                or self.decode_attention_backend == "trtllm_mha"
                or self.prefill_attention_backend == "trtllm_mha"
            ):
                if self.speculative_eagle_topk > 1:
                    raise ValueError(
                        "trtllm_mha backend only supports topk = 1 for speculative decoding."
                    )

            if (
                self.speculative_eagle_topk == 1
                and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
@@ -266,6 +266,22 @@ class EAGLEWorker(TpModelWorker):
                self.topk,
                self.speculative_num_steps,
            )
        elif self.server_args.attention_backend == "trtllm_mha":
            from sglang.srt.layers.attention.trtllm_mha_backend import (
                TRTLLMHAAttnBackend,
                TRTLLMHAAttnMultiStepDraftBackend,
            )

            self.draft_attn_backend = TRTLLMHAAttnMultiStepDraftBackend(
                self.draft_model_runner,
                self.topk,
                self.speculative_num_steps,
            )
            self.draft_extend_attn_backend = TRTLLMHAAttnBackend(
                self.draft_model_runner,
                skip_prefill=False,
            )
            self.has_prefill_wrapper_verify = True
        elif self.server_args.attention_backend == "trtllm_mla":
            if not global_server_args_dict["use_mla_backend"]:
                raise ValueError(
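
For reference, the CUDA graph replay paths above convert rows of req_to_token into page IDs by sampling one token slot per page (the strided_indices buffers) and dividing by the page size. A minimal sketch of that conversion, assuming toy tensors and hypothetical helper names rather than the sglang memory pools:

import torch

def build_page_table(req_to_token: torch.Tensor,
                     req_pool_indices: torch.Tensor,
                     max_seq_len_k: int,
                     page_size: int) -> torch.Tensor:
    """Toy sketch: derive a per-request page table from a token-index map.

    req_to_token: [num_req_slots, max_context_len], token slot id per position.
    Token slots are assumed to be allocated page-contiguously, so the page id
    of a position is its token slot id // page_size.
    """
    max_seq_pages = (max_seq_len_k + page_size - 1) // page_size
    # One representative position per page: 0, page_size, 2*page_size, ...
    strided_indices = torch.arange(0, req_to_token.shape[1], page_size)
    page_indices = req_to_token[
        req_pool_indices[:, None], strided_indices[:max_seq_pages][None, :]
    ]
    return page_indices // page_size

# Example: 4 request slots, context of 16 tokens, page size 4.
req_to_token = torch.arange(4 * 16, dtype=torch.int32).reshape(4, 16)
table = build_page_table(
    req_to_token, torch.tensor([2, 0]), max_seq_len_k=10, page_size=4
)
print(table)  # page ids for requests 2 and 0, three pages each
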