Unverified Commit 40e3b2be authored by eigen, committed by GitHub

feat: add trtllm-gen mha from direct call (#8782)


Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
parent 75df31b6
"""
Support attention backend for TRTLLM MHA kernels from flashinfer.
"""

from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
import torch
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_flashinfer_available
if is_flashinfer_available():
import flashinfer
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInfo
# Constants
DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB
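# A byte buffer of this size is allocated in __init__ and passed as `workspace_buffer`
# to the TRT-LLM decode and prefill kernels below.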
@dataclass
class TRTLLMMHAMetadata:
# Sequence lengths for the forward batch
cache_seqlens_int32: torch.Tensor = None
# Maximum sequence length for query
max_seq_len_q: int = 1
# Maximum sequence length for key
max_seq_len_k: int = 0
# Cumulative sequence lengths for query
cu_seqlens_q: torch.Tensor = None
# Cumulative sequence lengths for key
cu_seqlens_k: torch.Tensor = None
# Page table, the index of KV Cache Tables/Blocks
page_table: torch.Tensor = None
class TRTLLMHAAttnBackend(FlashInferAttnBackend):
"""TRTLLM MHA attention kernel from flashinfer."""
def __init__(
self,
model_runner: ModelRunner,
skip_prefill: bool = False,
kv_indptr_buf: Optional[torch.Tensor] = None,
q_indptr_decode_buf: Optional[torch.Tensor] = None,
):
super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf)
config = model_runner.model_config
# MHA-specific dimensions
self.max_context_len = model_runner.model_config.context_len
self.sliding_window_size = (
model_runner.sliding_window_size
if model_runner.sliding_window_size is not None
else -1 # -1 indicates full attention
)
self.hidden_size = config.hidden_size
# Runtime parameters
self.data_type = model_runner.kv_cache_dtype
self.q_data_type = model_runner.dtype
self.page_size = model_runner.page_size
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.device = model_runner.device
# Workspace allocation
self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
self.workspace_buffer = torch.empty(
self.workspace_size, dtype=torch.int8, device=self.device
)
# CUDA graph state
self.decode_cuda_graph_metadata = {}
# Forward metadata
self.forward_metadata: Optional[TRTLLMMHAMetadata] = None
def init_cuda_graph_state(
self,
max_bs: int,
max_num_tokens: int,
kv_indices_buf: Optional[torch.Tensor] = None,
):
"""Initialize CUDA graph state for TRTLLM MHA."""
self.decode_cuda_graph_metadata = {
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
"page_table": torch.zeros(
max_bs,
(self.max_context_len + self.page_size - 1) // self.page_size,
dtype=torch.int32,
device=self.device,
),
"strided_indices": torch.arange(
0, self.max_context_len, self.page_size, device=self.device
),
}
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
):
"""Initialize metadata for CUDA graph capture."""
metadata = TRTLLMMHAMetadata()
# Get sequence information
metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
# Precompute maximum sequence length
metadata.max_seq_len_k = seq_lens.max().item()
# Precompute page table
metadata.page_table = self.decode_cuda_graph_metadata["page_table"][:bs, :]
self.decode_cuda_graph_metadata[bs] = metadata
self.forward_metadata = metadata
def init_forward_metadata_replay_cuda_graph(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
seq_lens_cpu: Optional[torch.Tensor],
):
"""Replay CUDA graph with new inputs."""
seq_lens = seq_lens[:bs]
seq_lens_cpu = seq_lens_cpu[:bs]
req_pool_indices = req_pool_indices[:bs]
# Normal Decode
metadata = self.decode_cuda_graph_metadata[bs]
max_len = seq_lens_cpu.max().item()
max_seq_pages = (max_len + self.page_size - 1) // self.page_size
metadata.max_seq_len_k = max_len
metadata.cache_seqlens_int32.copy_(seq_lens)
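# req_to_token maps (request, token position) -> KV slot index. Sampling one
# position per page via strided_indices and dividing by page_size recovers the
# page IDs written into the block table.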
page_indices = self.req_to_token[
req_pool_indices[:, None],
self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][None, :],
]
metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
self.forward_metadata = metadata
def get_cuda_graph_seq_len_fill_value(self) -> int:
"""Get the fill value for sequence lengths in CUDA graph."""
return 1
def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Initialize the metadata for a forward pass."""
metadata = TRTLLMMHAMetadata()
seqlens_in_batch = forward_batch.seq_lens
if forward_batch.forward_mode.is_decode_or_idle():
# Normal Decode
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
else:
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
if any(forward_batch.extend_prefix_lens_cpu):
extend_seq_lens = forward_batch.extend_seq_lens
metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
metadata.cu_seqlens_q = torch.nn.functional.pad(
torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
)
else:
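# No cached prefix: every token in the batch is a new query token, so the
# q cumulative lengths coincide with the k cumulative lengths.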
metadata.max_seq_len_q = metadata.max_seq_len_k
metadata.cu_seqlens_q = metadata.cu_seqlens_k
# Convert the page table to a strided format
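# req_to_token is indexed per token; with page_size > 1, keep one entry per page
# and divide by page_size so the table holds page indices rather than token slots.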
if self.page_size > 1:
self.strided_indices = torch.arange(
0, metadata.page_table.shape[1], self.page_size, device=self.device
)
metadata.page_table = (
metadata.page_table[:, self.strided_indices] // self.page_size
)
self.forward_metadata = metadata
def forward_decode(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
) -> torch.Tensor:
"""Run forward for decode using TRTLLM MHA kernel."""
cache_loc = forward_batch.out_cache_loc
if save_kv_cache and k is not None:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
# shape conversion:
# [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
k_cache = k_cache.view(
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
).permute(0, 2, 1, 3)
v_cache = v_cache.view(
-1, self.page_size, layer.tp_v_head_num, layer.head_dim
).permute(0, 2, 1, 3)
kv_cache = (k_cache, v_cache)
# TODO: bmm1_scale and bmm2_scale might require modification
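# bmm1_scale folds the softmax scaling (layer.scaling) and the K dequantization
# scale into the Q*K^T GEMM; bmm2_scale stays 1.0 until quantized KV-cache
# output scaling is supported.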
q_scale = 1.0
k_scale = (
layer.k_scale_float
if getattr(layer, "k_scale_float", None) is not None
else 1.0
)
bmm1_scale = q_scale * k_scale * layer.scaling
bmm2_scale = 1.0
# Call TRT-LLM kernel
# o: same layout as q, [num_tokens, num_q_heads, head_dim], but in the output dtype
o = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
query=q,
kv_cache=kv_cache,
workspace_buffer=self.workspace_buffer,
block_tables=self.forward_metadata.page_table,
seq_lens=self.forward_metadata.cache_seqlens_int32,
max_seq_len=self.forward_metadata.max_seq_len_k,
bmm1_scale=bmm1_scale,
bmm2_scale=bmm2_scale,
window_left=self.sliding_window_size,
# TODO: add attention_sink operation or nvfp4 scale factor if needed
)
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
def forward_extend(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
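"""Run forward for extend (prefill) using TRTLLM MHA kernel."""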
cache_loc = forward_batch.out_cache_loc
if save_kv_cache and k is not None:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
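# Same shape conversion as in forward_decode:
# [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]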
k_cache = k_cache.view(
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
).permute(0, 2, 1, 3)
v_cache = v_cache.view(
-1, self.page_size, layer.tp_v_head_num, layer.head_dim
).permute(0, 2, 1, 3)
kv_cache = (k_cache, v_cache)
# TODO: bmm1_scale and bmm2_scale might require modification
# TODO: Change once quantization is supported
q_scale = 1.0
k_scale = (
layer.k_scale_float
if getattr(layer, "k_scale_float", None) is not None
else 1.0
)
bmm1_scale = q_scale * k_scale * layer.scaling
bmm2_scale = 1.0
o = flashinfer.prefill.trtllm_batch_context_with_kv_cache(
query=q,
kv_cache=kv_cache,
workspace_buffer=self.workspace_buffer,
block_tables=self.forward_metadata.page_table,
seq_lens=self.forward_metadata.cache_seqlens_int32,
max_q_len=self.forward_metadata.max_seq_len_q,
max_kv_len=self.forward_metadata.max_seq_len_k,
bmm1_scale=bmm1_scale,
bmm2_scale=bmm2_scale,
batch_size=forward_batch.batch_size,
cum_seq_lens_q=self.forward_metadata.cu_seqlens_q,
cum_seq_lens_kv=self.forward_metadata.cu_seqlens_k,
window_left=self.sliding_window_size,
# TODO: add attention_sink operation or nvfp4 scale factor if needed
)
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -1705,6 +1705,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
or attention_backend_str == "flashmla"
or attention_backend_str == "cutlass_mla"
or attention_backend_str == "ascend"
or attention_backend_str == "trtllm_mha"
or global_server_args_dict["enable_two_batch_overlap"]
):
seq_lens_cpu = (
...
@@ -1449,6 +1449,17 @@ class ModelRunner:
from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
return TRTLLMMLABackend(self)
elif self.server_args.attention_backend == "trtllm_mha":
if self.use_mla_backend:
raise ValueError(
"trtllm_mha backend can only be used with non-MLA models."
)
from sglang.srt.layers.attention.trtllm_mha_backend import (
TRTLLMHAAttnBackend,
)
return TRTLLMHAAttnBackend(self)
elif self.server_args.attention_backend == "intel_amx":
from sglang.srt.layers.attention.intel_amx_backend import (
IntelAMXAttnBackend,
...
@@ -441,6 +441,23 @@ class ServerArgs:
"trtllm_mla backend does not support speculative decoding yet."
)
if self.attention_backend == "trtllm_mha":
if not is_sm100_supported():
raise ValueError(
"TRTLLM MHA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
)
if self.page_size not in [16, 32, 64]:
logger.warning(
f"TensorRT-LLM MHA only supports page_size of 16, 32 or 64, changing page_size from {self.page_size} to 64."
)
self.page_size = 64
if self.speculative_algorithm is not None:
raise ValueError(
"trtllm_mla backend does not support speculative decoding yet."
)
# Set page size
if self.page_size is None:
self.page_size = 1
...
@@ -1275,6 +1292,7 @@ class ServerArgs:
"ascend",
"triton",
"trtllm_mla",
"trtllm_mha",
],
default=ServerArgs.attention_backend,
help="Choose the kernels for attention layers.",
...
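Note: a minimal launch sketch for exercising the new backend, assuming the standard sglang server entry point; the model path is illustrative:

python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --attention-backend trtllm_mha --page-size 64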