Unverified Commit 4b04998d authored by Faraz, committed by GitHub

TRTLLM Gen MLA Decode Kernel Integration (same as #7938) (#8632)


Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
parent 3dde8619
...@@ -9,8 +9,12 @@
| **Triton** | ❌ | ✅ | ✅ | ✅ | ❌ |
| **Torch Native** | ❌ | ❌ | ❌ | ❌ | ❌ |
| **FlashMLA** | ✅ | ✅ | ✅ | ❌ | ❌ |
| **TRTLLM MLA** | ✅ | ❌ | ✅ | ✅ | ❌ |
| **Ascend** | ✅ | ❌ | ❌ | ❌ | ❌ |
**Notes:**
- TRTLLM MLA only implements decode operations. For prefill operations (including multimodal inputs), it falls back to the FlashInfer MLA backend.
Note: Every kernel backend is compatible with a page size > 1 by specifying an argument such as `--page-size 16`.
This is because a page size of 16 can be converted to a page size of 1 in the kernel backend.
The "❌" and "✅" symbols in the table above under "Page Size > 1" indicate whether the kernel actually operates with a page size greater than 1, rather than treating a page size of 16 as a page size of 1.
...@@ -48,6 +52,11 @@ python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attenti
python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend flashmla --kv-cache-dtype fp8_e4m3 --trust-remote-code
```
- TRTLLM MLA (Optimized for Blackwell Architecture, e.g., B200)
```bash
python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend trtllm_mla --trust-remote-code
```
- Ascend
```bash
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend ascend
...
...@@ -90,7 +90,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be
- **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase (see the sketch after this list).
- **MLA Attention Backends**: Currently SGLang supports different optimized MLA attention backends, including [FlashAttention3](https://github.com/Dao-AILab/flash-attention), [Flashinfer](https://docs.flashinfer.ai/api/mla.html), [FlashMLA](https://github.com/deepseek-ai/FlashMLA), [CutlassMLA](https://github.com/sgl-project/sglang/pull/5390), **TRTLLM MLA** (optimized for Blackwell architecture), and [Triton](https://github.com/triton-lang/triton) backends. The default FA3 backend provides good performance across a wide range of workloads.
- **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enable efficient FP8 inference. Additionally, we have implemented a Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption.
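A minimal sketch of the weight-absorption idea (illustrative shapes and names such as `w_uk` and `c_kv`; not SGLang's actual implementation): because matrix multiplication is associative, queries can be scored directly against the compressed KV latent after folding the key up-projection into the query, instead of first expanding the latent into full keys.

```python
import torch

torch.manual_seed(0)
num_heads, d_head, d_latent, seq_len = 8, 128, 512, 16

q = torch.randn(num_heads, d_head, dtype=torch.float64)     # one decode-step query per head
w_uk = torch.randn(d_head, d_latent, dtype=torch.float64)   # key up-projection (latent <-> key space)
c_kv = torch.randn(seq_len, d_latent, dtype=torch.float64)  # compressed KV latent cache

# Naive order: expand the latent cache into full keys, then compute attention scores.
k = c_kv @ w_uk.T                       # (seq_len, d_head)
scores_naive = q @ k.T                  # (num_heads, seq_len)

# Absorbed order: fold w_uk into the query once, then score against the latent cache.
q_absorbed = q @ w_uk                   # (num_heads, d_latent)
scores_absorbed = q_absorbed @ c_kv.T   # (num_heads, seq_len)

assert torch.allclose(scores_naive, scores_absorbed)
```

In the decode phase, scoring against the latent cache avoids materializing full per-head keys, which is where the memory-access savings come from.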
...@@ -104,7 +104,7 @@ Overall, with these optimizations, we have achieved up to **7x** acceleration in
<img src="https://lmsys.org/images/blog/sglang_v0_3/deepseek_mla.svg" alt="Multi-head Latent Attention for DeepSeek Series Models">
</p>
**Usage**: MLA optimization is enabled by default. For MLA models on Blackwell architecture (e.g., B200), the default backend is FlashInfer. To use the optimized TRTLLM MLA backend for decode operations, explicitly specify `--attention-backend trtllm_mla`. Note that TRTLLM MLA only optimizes decode operations; prefill operations (including multimodal inputs) fall back to the FlashInfer MLA backend.
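For example, on a B200 node (mirroring the launch command in the attention-backend docs updated in this PR):

```bash
python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend trtllm_mla --trust-remote-code
```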
**Reference**: Check [Blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/#deepseek-multi-head-latent-attention-mla-throughput-optimizations) and [Slides](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/lmsys_1st_meetup_deepseek_mla.pdf) for more details.
...@@ -161,7 +161,7 @@ Add arguments `--speculative-algorithm`, `--speculative-num-steps`, `--speculati
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --speculative-algorithm EAGLE --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2 --trust-remote-code --tp 8
```
- The best configuration for `--speculative-num-steps`, `--speculative-eagle-topk`, and `--speculative-num-draft-tokens` can be searched for with the [bench_speculative.py](https://github.com/sgl-project/sglang/blob/main/scripts/playground/bench_speculative.py) script for a given batch size. The minimum configuration is `--speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2`, which can achieve speedup for larger batch sizes.
- FlashAttention3, FlashMLA, and Triton backends fully support MTP usage. For the FlashInfer backend (`--attention-backend flashinfer`) with speculative decoding, the `--speculative-eagle-topk` parameter should be set to `1`. MTP support for the CutlassMLA and TRTLLM MLA backends is still under development.
- To enable DeepSeek MTP for large batch sizes (>32), some parameters should be changed (see [this discussion](https://github.com/sgl-project/sglang/issues/4543#issuecomment-2737413756)):
  - Adjust `--max-running-requests` to a larger number. The default value is `32` for MTP; for larger batch sizes, increase this value beyond the default.
  - Set `--cuda-graph-bs`. It is a list of batch sizes for CUDA graph capture. The default captured batch sizes for speculative decoding are set [here](https://github.com/sgl-project/sglang/blob/49420741746c8f3e80e0eb17e7d012bfaf25793a/python/sglang/srt/model_executor/cuda_graph_runner.py#L126). You can add more batch sizes to it, as shown in the example below.
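For example (illustrative values, extending the MTP command above; assuming `--cuda-graph-bs` accepts a space-separated list of batch sizes):

```bash
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --speculative-algorithm EAGLE --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2 --trust-remote-code --tp 8 --max-running-requests 64 --cuda-graph-bs 64 128 256
```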
...
from __future__ import annotations
"""
Support attention backend for TRTLLM MLA kernels from flashinfer.
"""
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union
import torch
import triton
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
from sglang.srt.layers.attention.utils import (
TRITON_PAD_NUM_PAGE_PER_BLOCK,
create_flashmla_kv_indices_triton,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_flashinfer_available
if is_flashinfer_available():
import flashinfer
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInfo
# Constants
DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB
# Block constraint from flashinfer requirements
# From flashinfer.decode._check_trtllm_gen_mla_shape:
# block_num % (128 / block_size) == 0
# This imposes that the total number of blocks must be divisible by
# (128 / block_size). We capture the 128 constant here so we can
# compute the LCM with other padding constraints.
TRTLLM_BLOCK_CONSTRAINT = 128
@dataclass
class TRTLLMMLADecodeMetadata:
"""Metadata for TRTLLM MLA decode operations."""
workspace: Optional[torch.Tensor] = None
block_kv_indices: Optional[torch.Tensor] = None
class TRTLLMMLABackend(FlashInferMLAAttnBackend):
"""TRTLLM MLA attention kernel from flashinfer."""
def __init__(
self,
model_runner: ModelRunner,
skip_prefill: bool = False,
kv_indptr_buf: Optional[torch.Tensor] = None,
q_indptr_decode_buf: Optional[torch.Tensor] = None,
):
super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf)
config = model_runner.model_config
# Model parameters
self.num_q_heads = config.num_attention_heads // get_attention_tp_size()
self.num_kv_heads = config.get_num_kv_heads(get_attention_tp_size())
self.num_local_heads = config.num_attention_heads // get_attention_tp_size()
# MLA-specific dimensions
self.kv_lora_rank = config.kv_lora_rank
self.qk_nope_head_dim = config.qk_nope_head_dim
self.qk_rope_head_dim = config.qk_rope_head_dim
self.v_head_dim = config.v_head_dim
self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
# Runtime parameters
self.scaling = config.scaling
self.data_type = model_runner.kv_cache_dtype
self.q_data_type = model_runner.dtype
self.page_size = model_runner.page_size
self.req_to_token = model_runner.req_to_token_pool.req_to_token
# Workspace allocation
self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
self.workspace_buffer = torch.empty(
self.workspace_size, dtype=torch.int8, device=self.device
)
# CUDA graph state
self.decode_cuda_graph_metadata = {}
self.cuda_graph_kv_indices = None
self.forward_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
def _calc_padded_blocks(self, max_seq_len: int) -> int:
"""
Calculate padded block count that satisfies both TRT-LLM and Triton constraints.
Args:
max_seq_len: Maximum sequence length in tokens
Returns:
Number of blocks padded to satisfy all constraints
"""
blocks = triton.cdiv(max_seq_len, self.page_size)
# Apply dual constraints (take LCM to satisfy both):
# 1. TRT-LLM: block_num % (128 / page_size) == 0
# 2. Triton: page table builder uses 64-index bursts, needs multiple of 64
trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // self.page_size
constraint_lcm = math.lcm(trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK)
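        # Illustrative numbers: max_seq_len=1000 with page_size=32 gives blocks = cdiv(1000, 32) = 32,
        # trtllm_constraint = 128 // 32 = 4, and constraint_lcm = lcm(4, 64) = 64, so blocks pads up to 64.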
if blocks % constraint_lcm != 0:
blocks = triton.cdiv(blocks, constraint_lcm) * constraint_lcm
return blocks
def _create_block_kv_indices(
self,
batch_size: int,
max_blocks: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
device: torch.device,
) -> torch.Tensor:
"""
Create block KV indices tensor using Triton kernel.
Args:
batch_size: Batch size
max_blocks: Maximum number of blocks per sequence
req_pool_indices: Request pool indices
seq_lens: Sequence lengths
device: Target device
Returns:
Block KV indices tensor
"""
block_kv_indices = torch.full(
(batch_size, max_blocks), -1, dtype=torch.int32, device=device
)
create_flashmla_kv_indices_triton[(batch_size,)](
self.req_to_token,
req_pool_indices,
seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_blocks,
TRITON_PAD_NUM_PAGE_PER_BLOCK,
self.page_size,
)
return block_kv_indices
def init_cuda_graph_state(
self,
max_bs: int,
max_num_tokens: int,
kv_indices_buf: Optional[torch.Tensor] = None,
):
"""Initialize CUDA graph state for TRTLLM MLA."""
max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
self.cuda_graph_kv_indices = torch.full(
(max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
)
self.cuda_graph_workspace = torch.empty(
self.workspace_size, dtype=torch.int8, device=self.device
)
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
):
"""Initialize metadata for CUDA graph capture."""
# Delegate to parent for non-decode modes or when speculative execution is used.
if not (forward_mode.is_decode_or_idle() and spec_info is None):
return super().init_forward_metadata_capture_cuda_graph(
bs,
num_tokens,
req_pool_indices,
seq_lens,
encoder_lens,
forward_mode,
spec_info,
)
# Custom fast-path for decode/idle without speculative execution.
max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item())
block_kv_indices = self.cuda_graph_kv_indices[:bs, :max_seqlen_pad]
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
TRITON_PAD_NUM_PAGE_PER_BLOCK,
self.page_size,
)
metadata = TRTLLMMLADecodeMetadata(self.cuda_graph_workspace, block_kv_indices)
self.decode_cuda_graph_metadata[bs] = metadata
self.forward_metadata = metadata
def init_forward_metadata_replay_cuda_graph(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
seq_lens_cpu: Optional[torch.Tensor],
):
"""Replay CUDA graph with new inputs."""
# Delegate to parent for non-decode modes or when speculative execution is used.
if not (forward_mode.is_decode_or_idle() and spec_info is None):
return super().init_forward_metadata_replay_cuda_graph(
bs,
req_pool_indices,
seq_lens,
seq_lens_sum,
encoder_lens,
forward_mode,
spec_info,
seq_lens_cpu,
)
metadata = self.decode_cuda_graph_metadata[bs]
# Update block indices for new sequences.
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices[:bs],
seq_lens[:bs],
None,
metadata.block_kv_indices,
self.req_to_token.stride(0),
metadata.block_kv_indices.shape[1],
TRITON_PAD_NUM_PAGE_PER_BLOCK,
self.page_size,
)
def get_cuda_graph_seq_len_fill_value(self) -> int:
"""Get the fill value for sequence lengths in CUDA graph."""
return 1
def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Initialize the metadata for a forward pass."""
# Delegate to parent for non-decode modes or when speculative execution is used.
if not (
forward_batch.forward_mode.is_decode_or_idle()
and forward_batch.spec_info is None
):
return super().init_forward_metadata(forward_batch)
bs = forward_batch.batch_size
# Get maximum sequence length.
if getattr(forward_batch, "seq_lens_cpu", None) is not None:
max_seq = forward_batch.seq_lens_cpu.max().item()
else:
max_seq = forward_batch.seq_lens.max().item()
max_seqlen_pad = self._calc_padded_blocks(max_seq)
block_kv_indices = self._create_block_kv_indices(
bs,
max_seqlen_pad,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens.device,
)
self.forward_metadata = TRTLLMMLADecodeMetadata(
self.workspace_buffer, block_kv_indices
)
forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
def forward_decode(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Run forward for decode using TRTLLM MLA kernel."""
# Save KV cache if requested
if k is not None and save_kv_cache:
cache_loc = forward_batch.out_cache_loc
if k_rope is not None:
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
layer, cache_loc, k, k_rope
)
elif v is not None:
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
# Prepare query tensor inline
if q_rope is not None:
# q contains NOPE part (v_head_dim)
q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
q_rope_reshaped = q_rope.view(
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
)
query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
else:
# q already has both parts
query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
# Ensure query has shape [bs, acc_q_len, num_q_heads, head_dim] when seq_len 1
if query.dim() == 3:
query = query.unsqueeze(1)
# Prepare KV cache inline
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
pages = k_cache.view(-1, self.page_size, self.kv_cache_dim)
# TRT-LLM expects single KV data with extra dimension
kv_cache = pages.unsqueeze(1)
# Get metadata
metadata = (
getattr(forward_batch, "decode_trtllm_mla_metadata", None)
or self.forward_metadata
)
# Scale computation for TRTLLM MLA kernel:
# - BMM1 scale = q_scale * k_scale * softmax_scale
# - For FP16 path we keep q_scale = 1.0, softmax_scale = 1/sqrt(head_dim) which is pre-computed as layer.scaling
# - k_scale is read from model checkpoint if available
# TODO: Change once fp8 path is supported
q_scale = 1.0
k_scale = (
layer.k_scale_float
if getattr(layer, "k_scale_float", None) is not None
else 1.0
)
bmm1_scale = q_scale * k_scale * layer.scaling
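        # Illustrative example: in the current bf16 path with q_scale = 1.0 and (absent a checkpoint
        # k_scale) k_scale = 1.0, bmm1_scale reduces to layer.scaling, the pre-computed softmax scale.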
# Call TRT-LLM kernel
raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
query=query,
kv_cache=kv_cache,
workspace_buffer=metadata.workspace,
qk_nope_head_dim=self.qk_nope_head_dim,
kv_lora_rank=self.kv_lora_rank,
qk_rope_head_dim=self.qk_rope_head_dim,
block_tables=metadata.block_kv_indices,
seq_lens=forward_batch.seq_lens.to(torch.int32),
max_seq_len=int(metadata.block_kv_indices.shape[1] * self.page_size),
bmm1_scale=bmm1_scale,
)
# Extract value projection part and reshape
raw_out_v = raw_out[..., : layer.v_head_dim].contiguous()
output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)
return output
import triton
import triton.language as tl
# Keep this in sync with the Triton kernel inside `create_flashmla_kv_indices_triton`.
# Number of pages that the kernel writes per iteration.
# Exposed here so other Python modules can import it instead of hard-coding 64.
TRITON_PAD_NUM_PAGE_PER_BLOCK = 64
@triton.jit
def create_flashinfer_kv_indices_triton(
...@@ -50,10 +55,10 @@ def create_flashmla_kv_indices_triton(
kv_indices_ptr,
req_to_token_ptr_stride: tl.constexpr,
kv_indices_ptr_stride: tl.constexpr,
NUM_PAGE_PER_BLOCK: tl.constexpr = TRITON_PAD_NUM_PAGE_PER_BLOCK,
PAGED_SIZE: tl.constexpr = 64,
):
BLOCK_SIZE: tl.constexpr = 4096
pid = tl.program_id(axis=0)
# find the req pool idx, this is for batch to token
...
...@@ -436,6 +436,7 @@ class ModelRunner:
"triton",
"flashmla",
"cutlass_mla",
"trtllm_mla",
"ascend", "ascend",
]: ]:
logger.info( logger.info(
...@@ -1437,6 +1438,12 @@ class ModelRunner: ...@@ -1437,6 +1438,12 @@ class ModelRunner:
) )
return CutlassMLABackend(self) return CutlassMLABackend(self)
elif self.server_args.attention_backend == "trtllm_mla":
if not self.use_mla_backend:
raise ValueError("trtllm_mla backend can only be used with MLA models.")
from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
return TRTLLMMLABackend(self)
elif self.server_args.attention_backend == "intel_amx": elif self.server_args.attention_backend == "intel_amx":
from sglang.srt.layers.attention.intel_amx_backend import ( from sglang.srt.layers.attention.intel_amx_backend import (
IntelAMXAttnBackend, IntelAMXAttnBackend,
......
...@@ -1259,6 +1259,7 @@ class DeepseekV2AttentionMLA(nn.Module):
self.current_attention_backend == "fa3"
or self.current_attention_backend == "flashinfer"
or self.current_attention_backend == "cutlass_mla"
or self.current_attention_backend == "trtllm_mla"
):
attn_output = self.attn_mqa(
q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
...
...@@ -24,6 +24,7 @@ import tempfile
from typing import List, Literal, Optional, Union
from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.utils import (
...@@ -402,6 +403,22 @@ class ServerArgs:
)
self.page_size = 128
if self.attention_backend == "trtllm_mla":
if not is_sm100_supported():
raise ValueError(
"TRTLLM MLA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
)
if self.page_size not in [32, 64]:
logger.warning(
f"TensorRT-LLM MLA only supports page_size of 32 or 64, changing page_size from {self.page_size} to 64."
)
self.page_size = 64
if self.speculative_algorithm is not None:
raise ValueError(
"trtllm_mla backend does not support speculative decoding yet."
)
# Set page size
if self.page_size is None:
self.page_size = 1
...@@ -1225,6 +1242,7 @@ class ServerArgs:
"torch_native",
"ascend",
"triton",
"trtllm_mla",
],
default=ServerArgs.attention_backend,
help="Choose the kernels for attention layers.",
...
import math
import unittest
import numpy as np
import torch
from sglang.srt.layers import dp_attention as _dp_attn
# Patch DP-attention globals before importing backends
# TODO: change the interface of both trtllm_mla and flashinfer backends to take tp_size as an argument instead of patching
_dp_attn.get_attention_tp_size = lambda: 1 # TP size = 1 for unit test
from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
from sglang.srt.layers.attention.trtllm_mla_backend import (
TRTLLMMLABackend,
TRTLLMMLADecodeMetadata,
)
from sglang.srt.layers.attention.utils import TRITON_PAD_NUM_PAGE_PER_BLOCK
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_flashinfer_available
from sglang.test.test_utils import CustomTestCase
# Global configuration for all tests
DEFAULT_CONFIG = {
"device": "cuda",
"dtype": torch.bfloat16,
"kv_cache_dtype": torch.bfloat16,
"context_len": 2048,
"max_bs": 64,
"tolerance": 1e-2,
"seed_cache": 42,
"seed_qkv": 123,
# MLA model config (TRTLLM MLA has fixed constraints)
"num_attention_heads": 128,
"kv_lora_rank": 512,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"v_head_dim": 512,
"num_kv_heads": 1,
"layer_id": 0,
}
# Centralized test cases for different test scenarios
TEST_CASES = {
"basic_functionality": [
{
"name": "single",
"batch_size": 1,
"max_seq_len": 32,
"page_size": 32,
"description": "Minimal smoke test",
},
{
"name": "batch",
"batch_size": 32,
"max_seq_len": 128,
"page_size": 32,
"description": "Medium-scale batch",
},
],
"decode_output_match": [
{
"name": "single",
"batch_size": 1,
"max_seq_len": 64,
"page_size": 32,
"description": "Single vs reference",
},
{
"name": "batch",
"batch_size": 32,
"max_seq_len": 64,
"page_size": 32,
"description": "Batch vs reference",
},
],
"page_size_consistency": [
# Only 32 and 64 supported for now in flashinfer TRTLLM-GEN MLA kernel
{
"name": "page_32",
"batch_size": 8,
"max_seq_len": 128,
"page_size": 32,
"description": "32-token pages",
},
{
"name": "page_64",
"batch_size": 8,
"max_seq_len": 128,
"page_size": 64,
"description": "64-token pages",
},
],
"shape_sanity_tests": [
{
"name": "basic",
"batch_size": 1,
"max_seq_len": 128,
"page_size": 32,
"description": "Single sequence",
},
{
"name": "basic_different_pagesize",
"batch_size": 1,
"max_seq_len": 128,
"page_size": 64,
"description": "Different page size",
},
{
"name": "batch",
"batch_size": 8,
"max_seq_len": 128,
"page_size": 32,
"description": "Batch shapes",
},
],
"metadata_tests": [
{
"name": "single_sequence",
"batch_size": 1,
"max_seq_len": 64,
"page_size": 32,
"description": "Single sequence metadata",
},
{
"name": "batch_mixed_lengths",
"batch_size": 8,
"max_seq_len": 128,
"page_size": 32,
"description": "Mixed sequence lengths",
},
{
"name": "large_batch",
"batch_size": 32,
"max_seq_len": 256,
"page_size": 64,
"description": "Large batch stress test",
},
{
"name": "edge_case_short",
"batch_size": 4,
"max_seq_len": 16,
"page_size": 32,
"description": "Sub-page sequences",
},
],
}
class MockModelRunner:
"""Minimal fake ModelRunner for testing MLA backends."""
def __init__(self, config):
self.device = config["device"]
self.dtype = config["dtype"]
self.kv_cache_dtype = config["kv_cache_dtype"]
self.page_size = config["page_size"]
# Model-config stub with MLA attributes
self.model_config = type(
"ModelConfig",
(),
{
"context_len": config["context_len"],
"attention_arch": AttentionArch.MLA,
"num_attention_heads": config["num_attention_heads"],
"kv_lora_rank": config["kv_lora_rank"],
"qk_nope_head_dim": config["qk_nope_head_dim"],
"qk_rope_head_dim": config["qk_rope_head_dim"],
"v_head_dim": config["v_head_dim"],
"scaling": 1.0
/ ((config["qk_nope_head_dim"] + config["qk_rope_head_dim"]) ** 0.5),
"get_num_kv_heads": staticmethod(lambda _: config["num_kv_heads"]),
},
)
# Req-to-token pool
max_bs = config["max_bs"]
max_ctx = self.model_config.context_len
req_to_token = torch.arange(
max_bs * max_ctx, dtype=torch.int32, device=self.device
).reshape(max_bs, max_ctx)
self.req_to_token_pool = type(
"TokenPool",
(),
{
"size": max_bs,
"req_to_token": req_to_token,
},
)
# KV-token pool (MLA)
self.token_to_kv_pool = MLATokenToKVPool(
size=max_bs * max_ctx,
page_size=config["page_size"],
dtype=self.kv_cache_dtype,
kv_lora_rank=config["kv_lora_rank"],
qk_rope_head_dim=config["qk_rope_head_dim"],
layer_num=1,
device=self.device,
enable_memory_saver=False,
)
def compare_outputs(trtllm_out, reference_out, tolerance=1e-2):
"""Compare outputs with detailed analysis."""
# Basic checks
assert (
trtllm_out.shape == reference_out.shape
), f"Shape mismatch: {trtllm_out.shape} vs {reference_out.shape}"
assert (
trtllm_out.dtype == reference_out.dtype
), f"Dtype mismatch: {trtllm_out.dtype} vs {reference_out.dtype}"
# Check for NaN/Inf
assert not torch.isnan(trtllm_out).any(), "TRTLLM output contains NaN"
assert not torch.isnan(reference_out).any(), "Reference output contains NaN"
assert not torch.isinf(trtllm_out).any(), "TRTLLM output contains Inf"
assert not torch.isinf(reference_out).any(), "Reference output contains Inf"
# Element-wise differences
diff = (trtllm_out - reference_out).abs()
max_diff = diff.max().item()
mean_diff = diff.mean().item()
# Check numerical equivalence
all_close = torch.allclose(
trtllm_out, reference_out, rtol=tolerance, atol=tolerance
)
if not all_close:
print(
f"Comparison failed: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}, tolerance={tolerance}"
)
# Find top differences for debugging
flat_diff = diff.flatten()
top_diff_indices = torch.topk(flat_diff, k=min(5, flat_diff.numel())).indices
print("Top 5 differences:")
for i, idx in enumerate(top_diff_indices):
idx_tuple = np.unravel_index(idx.cpu().numpy(), trtllm_out.shape)
trt_val = trtllm_out[idx_tuple].item()
ref_val = reference_out[idx_tuple].item()
print(
f" [{idx_tuple}]: TRTLLM={trt_val:.6f}, Reference={ref_val:.6f}, diff={abs(trt_val-ref_val):.6f}"
)
return all_close
@unittest.skipIf(
not torch.cuda.is_available() or not is_flashinfer_available(),
"CUDA + flashinfer required",
)
class TestTRTLLMMLA(CustomTestCase):
"""Test suite for TRTLLM MLA backend with centralized configuration."""
def _merge_config(self, test_case):
"""Merge test case with default configuration."""
config = DEFAULT_CONFIG.copy()
config.update(test_case)
return config
def _create_model_components(self, config):
"""Create model runners, backends, and layer for testing."""
# Create model runners
model_runner_trtllm = MockModelRunner(config)
model_runner_reference = MockModelRunner(config)
# Create backends
trtllm_backend = TRTLLMMLABackend(model_runner_trtllm)
reference_backend = FlashInferMLAAttnBackend(model_runner_reference)
# Create RadixAttention layer
layer = RadixAttention(
num_heads=config["num_attention_heads"],
head_dim=config["kv_lora_rank"] + config["qk_rope_head_dim"],
scaling=model_runner_trtllm.model_config.scaling,
num_kv_heads=config["num_kv_heads"],
layer_id=config["layer_id"],
v_head_dim=config["v_head_dim"],
prefix="attn_mqa",
)
return (
model_runner_trtllm,
model_runner_reference,
trtllm_backend,
reference_backend,
layer,
)
def _create_qkv_tensors(self, batch_size, config):
"""Create Q, K, V tensors for testing."""
head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"]
device = config["device"]
dtype = config["dtype"]
q = torch.randn(
(batch_size, config["num_attention_heads"], head_dim),
dtype=dtype,
device=device,
)
k = torch.randn(
(batch_size, config["num_kv_heads"], head_dim), dtype=dtype, device=device
)
v = torch.randn(
(batch_size, config["num_kv_heads"], config["v_head_dim"]),
dtype=dtype,
device=device,
)
return q, k, v
def _create_forward_batch(
self, batch_size, seq_lens, backend, model_runner, config
):
"""Create a forward batch for the given backend."""
fb = ForwardBatch(
batch_size=batch_size,
input_ids=torch.randint(0, 100, (batch_size, 1), device=config["device"]),
out_cache_loc=torch.arange(batch_size, device=config["device"]),
seq_lens_sum=int(seq_lens.sum().item()),
forward_mode=ForwardMode.DECODE,
req_pool_indices=torch.arange(batch_size, device=config["device"]),
seq_lens=seq_lens,
seq_lens_cpu=seq_lens.cpu(),
attn_backend=backend,
)
fb.req_to_token_pool = model_runner.req_to_token_pool
fb.token_to_kv_pool = model_runner.token_to_kv_pool
return fb
def _populate_kv_cache(self, batch_size, seq_lens, model_runners, layer, config):
"""Populate KV cache with identical data for both backends."""
torch.manual_seed(config["seed_cache"]) # Fixed seed for reproducible cache
for model_runner in model_runners:
torch.manual_seed(config["seed_cache"]) # Reset seed for each backend
for i in range(batch_size):
seq_len = int(seq_lens[i].item())
for token_idx in range(seq_len - 1):
# Create random K components for MLA
cache_k_nope = torch.randn(
(1, config["qk_nope_head_dim"]),
dtype=config["dtype"],
device=config["device"],
)
cache_k_rope = torch.randn(
(1, config["qk_rope_head_dim"]),
dtype=config["dtype"],
device=config["device"],
)
# Calculate cache location
cache_loc = model_runner.req_to_token_pool.req_to_token[
i, token_idx
]
# Save to KV cache
model_runner.token_to_kv_pool.set_mla_kv_buffer(
layer,
cache_loc.unsqueeze(0),
cache_k_nope.squeeze(0),
cache_k_rope.squeeze(0),
)
def test_basic_functionality(self):
"""Test basic functionality with minimal setup."""
print(f"\nRunning basic functionality tests...")
for test_case in TEST_CASES["basic_functionality"]:
with self.subTest(test_case=test_case["name"]):
print(f" Testing {test_case['name']}: {test_case['description']}")
config = self._merge_config(test_case)
batch_size = config["batch_size"]
max_seq_len = config["max_seq_len"]
# Create components
model_runner_trtllm, _, trtllm_backend, _, layer = (
self._create_model_components(config)
)
# Create sequence lengths - properly handle different batch sizes
if batch_size == 2:
seq_lens = torch.tensor(
[max_seq_len, max_seq_len // 2], device=config["device"]
)
else:
# For larger batch sizes, create varied sequence lengths
torch.manual_seed(config["seed_cache"])
seq_lens = torch.randint(
max_seq_len // 2,
max_seq_len + 1,
(batch_size,),
device=config["device"],
)
seq_lens[0] = max_seq_len # Ensure at least one max length
# Create forward batch
fb = self._create_forward_batch(
batch_size, seq_lens, trtllm_backend, model_runner_trtllm, config
)
trtllm_backend.init_forward_metadata(fb)
# Populate KV cache
self._populate_kv_cache(
batch_size, seq_lens, [model_runner_trtllm], layer, config
)
# Create Q, K, V tensors
torch.manual_seed(config["seed_qkv"])
q, k, v = self._create_qkv_tensors(batch_size, config)
# Run forward decode
output = trtllm_backend.forward_decode(q, k, v, layer, fb)
# Basic checks
expected_shape = (
batch_size,
config["num_attention_heads"] * config["v_head_dim"],
)
self.assertEqual(output.shape, expected_shape)
self.assertEqual(output.dtype, config["dtype"])
self.assertFalse(torch.isnan(output).any())
self.assertFalse(torch.isinf(output).any())
def test_decode_output_match(self):
"""Test that TRTLLM and FlashInfer MLA backends produce matching outputs."""
print(f"\nRunning decode output matching tests...")
for test_case in TEST_CASES["decode_output_match"]:
with self.subTest(test_case=test_case["name"]):
print(f" Testing {test_case['name']}: {test_case['description']}")
config = self._merge_config(test_case)
batch_size = config["batch_size"]
max_seq_len = config["max_seq_len"]
# Create components
(
model_runner_trtllm,
model_runner_reference,
trtllm_backend,
reference_backend,
layer,
) = self._create_model_components(config)
# Create identical sequence lengths for both backends
torch.manual_seed(config["seed_cache"])
seq_lens = torch.randint(
1, max_seq_len, (batch_size,), device=config["device"]
)
seq_lens[0] = max_seq_len # Ensure at least one max length
# Create forward batches with identical inputs
fb_trtllm = self._create_forward_batch(
batch_size,
seq_lens.clone(),
trtllm_backend,
model_runner_trtllm,
config,
)
fb_reference = self._create_forward_batch(
batch_size,
seq_lens.clone(),
reference_backend,
model_runner_reference,
config,
)
# Initialize metadata for both backends
trtllm_backend.init_forward_metadata(fb_trtllm)
reference_backend.init_forward_metadata(fb_reference)
# Populate both KV caches identically
self._populate_kv_cache(
batch_size,
seq_lens,
[model_runner_trtllm, model_runner_reference],
layer,
config,
)
# Create Q, K, V tensors for current decode step
torch.manual_seed(config["seed_qkv"])
q, k, v = self._create_qkv_tensors(batch_size, config)
# Run forward decode on both backends
out_trtllm = trtllm_backend.forward_decode(
q.clone(), k.clone(), v.clone(), layer, fb_trtllm
)
out_reference = reference_backend.forward_decode(
q.clone(), k.clone(), v.clone(), layer, fb_reference
)
# Compare outputs
comparison_passed = compare_outputs(
out_trtllm, out_reference, tolerance=config["tolerance"]
)
self.assertTrue(
comparison_passed,
f"TRTLLM and Reference outputs differ beyond tolerance. "
f"Config: {test_case['name']}, "
f"Max diff: {(out_trtllm - out_reference).abs().max().item()}",
)
def test_page_size_consistency(self):
"""Test output consistency across different page sizes."""
print(f"\nRunning page size consistency tests...")
for test_case in TEST_CASES["page_size_consistency"]:
with self.subTest(test_case=test_case["name"]):
print(f" Testing {test_case['name']}: {test_case['description']}")
config = self._merge_config(test_case)
batch_size = config["batch_size"]
max_seq_len = config["max_seq_len"]
# Create components
model_runner, _, backend, _, layer = self._create_model_components(
config
)
# Create sequence lengths
torch.manual_seed(config["seed_cache"])
seq_lens = torch.randint(
1, max_seq_len, (batch_size,), device=config["device"]
)
seq_lens[0] = max_seq_len
# Create forward batch
fb = self._create_forward_batch(
batch_size, seq_lens, backend, model_runner, config
)
backend.init_forward_metadata(fb)
# Populate KV cache
self._populate_kv_cache(
batch_size, seq_lens, [model_runner], layer, config
)
# Create Q, K, V tensors
torch.manual_seed(config["seed_qkv"])
q, k, v = self._create_qkv_tensors(batch_size, config)
# Run forward decode
output = backend.forward_decode(q, k, v, layer, fb)
expected_shape = (
batch_size,
config["num_attention_heads"] * config["v_head_dim"],
)
self.assertEqual(
output.shape,
expected_shape,
f"Output shape mismatch: {output.shape} vs {expected_shape}",
)
self.assertFalse(torch.isnan(output).any(), "Output contains NaN")
self.assertFalse(torch.isinf(output).any(), "Output contains Inf")
def test_shape_sanity(self):
"""Smoke test decode across several configurations."""
print(f"\nRunning shape sanity tests...")
for test_case in TEST_CASES["shape_sanity_tests"]:
with self.subTest(test_case=test_case["name"]):
print(f" Testing {test_case['name']}: {test_case['description']}")
config = self._merge_config(test_case)
batch_size = config["batch_size"]
max_seq_len = config["max_seq_len"]
model_runner, _, backend, _, layer = self._create_model_components(
config
)
# Random seq lens (ensure one matches max)
torch.manual_seed(config["seed_cache"])
seq_lens = torch.randint(
1, max_seq_len, (batch_size,), device=config["device"]
)
seq_lens[0] = max_seq_len
fb = self._create_forward_batch(
batch_size, seq_lens, backend, model_runner, config
)
backend.init_forward_metadata(fb)
# Create Q, K, V tensors
torch.manual_seed(config["seed_qkv"])
head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"]
q = torch.randn(
(batch_size, config["num_attention_heads"], head_dim),
dtype=config["dtype"],
device=config["device"],
)
k = torch.randn(
(batch_size, config["num_kv_heads"], head_dim),
dtype=config["dtype"],
device=config["device"],
)
v = None
# Run forward decode
output = backend.forward_decode(q, k, v, layer, fb)
# Shape and sanity checks
expected_shape = (
batch_size,
config["num_attention_heads"] * config["v_head_dim"],
)
self.assertEqual(
output.shape,
expected_shape,
f"Output shape mismatch for {test_case['name']}",
)
self.assertEqual(output.dtype, config["dtype"])
self.assertEqual(output.device.type, "cuda")
self.assertFalse(
torch.isnan(output).any(),
f"Output contains NaN for {test_case['name']}",
)
self.assertFalse(
torch.isinf(output).any(),
f"Output contains Inf for {test_case['name']}",
)
def test_metadata_initialization(self):
"""Test TRTLLM MLA metadata initialization and structure."""
print(f"\nRunning metadata initialization tests...")
for test_case in TEST_CASES["metadata_tests"]:
with self.subTest(test_case=test_case["name"]):
print(f" Testing {test_case['name']}: {test_case['description']}")
config = self._merge_config(test_case)
batch_size = config["batch_size"]
max_seq_len = config["max_seq_len"]
# Create components
model_runner, _, backend, _, layer = self._create_model_components(
config
)
# Create varied sequence lengths
torch.manual_seed(config["seed_cache"])
if batch_size == 1:
seq_lens = torch.tensor([max_seq_len], device=config["device"])
else:
seq_lens = torch.randint(
max(1, max_seq_len // 4),
max_seq_len + 1,
(batch_size,),
device=config["device"],
)
seq_lens[0] = max_seq_len # Ensure at least one max length
# Create forward batch
fb = self._create_forward_batch(
batch_size, seq_lens, backend, model_runner, config
)
# Initialize metadata
backend.init_forward_metadata(fb)
# Verify metadata exists
self.assertIsNotNone(backend.forward_metadata)
self.assertIsInstance(backend.forward_metadata, TRTLLMMLADecodeMetadata)
# Test metadata structure
metadata = backend.forward_metadata
self.assertIsNotNone(
metadata.workspace, "Workspace should be allocated"
)
self.assertIsNotNone(
metadata.block_kv_indices, "Block KV indices should be created"
)
# Test workspace properties
self.assertEqual(metadata.workspace.device.type, "cuda")
self.assertEqual(metadata.workspace.dtype, torch.int8)
self.assertGreater(
metadata.workspace.numel(), 0, "Workspace should have non-zero size"
)
# Test block KV indices properties
self.assertEqual(metadata.block_kv_indices.device.type, "cuda")
self.assertEqual(metadata.block_kv_indices.dtype, torch.int32)
self.assertEqual(metadata.block_kv_indices.shape[0], batch_size)
# Verify block indices are valid (>= -1, since -1 is padding)
self.assertTrue(
(metadata.block_kv_indices >= -1).all(),
"All block indices should be >= -1 (with -1 as padding)",
)
def test_metadata_block_calculation(self):
"""Test block count calculation logic."""
print(f"\nRunning metadata block calculation tests...")
test_scenarios = [
{"seq_len": 31, "page_size": 32, "expected_min_blocks": 1},
{"seq_len": 32, "page_size": 32, "expected_min_blocks": 1},
{"seq_len": 33, "page_size": 32, "expected_min_blocks": 2},
{"seq_len": 128, "page_size": 32, "expected_min_blocks": 4},
{"seq_len": 128, "page_size": 64, "expected_min_blocks": 2},
]
for scenario in test_scenarios:
with self.subTest(scenario=scenario):
config = self._merge_config(
{
"batch_size": 1,
"max_seq_len": scenario["seq_len"],
"page_size": scenario["page_size"],
}
)
model_runner, _, backend, _, _ = self._create_model_components(config)
# Test internal block calculation
calculated_blocks = backend._calc_padded_blocks(scenario["seq_len"])
# Should be at least the minimum required
self.assertGreaterEqual(
calculated_blocks,
scenario["expected_min_blocks"],
f"Calculated blocks ({calculated_blocks}) should be >= minimum required ({scenario['expected_min_blocks']})",
)
# Should satisfy page_size constraint
total_tokens = calculated_blocks * scenario["page_size"]
self.assertGreaterEqual(
total_tokens,
scenario["seq_len"],
f"Total tokens ({total_tokens}) should cover sequence length ({scenario['seq_len']})",
)
# Should satisfy TRT-LLM and Triton constraints
trtllm_constraint = 128 // scenario["page_size"]
constraint_lcm = math.lcm(
trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK
)
self.assertEqual(
calculated_blocks % constraint_lcm,
0,
f"Block count should be multiple of LCM of constraints ({constraint_lcm})",
)
def test_metadata_kv_indices_correctness(self):
"""Test KV indices creation and correctness."""
print(f"\nRunning KV indices correctness tests...")
for test_case in TEST_CASES["metadata_tests"][
:2
]: # Test subset for performance
with self.subTest(test_case=test_case["name"]):
print(f" Testing {test_case['name']}: {test_case['description']}")
config = self._merge_config(test_case)
batch_size = config["batch_size"]
max_seq_len = config["max_seq_len"]
model_runner, _, backend, _, layer = self._create_model_components(
config
)
# Create known sequence lengths
torch.manual_seed(config["seed_cache"])
if batch_size == 1:
seq_lens = torch.tensor([max_seq_len], device=config["device"])
else:
seq_lens = torch.randint(
max_seq_len // 2,
max_seq_len + 1,
(batch_size,),
device=config["device"],
)
fb = self._create_forward_batch(
batch_size, seq_lens, backend, model_runner, config
)
# Populate some KV cache to have valid indices
self._populate_kv_cache(
batch_size, seq_lens, [model_runner], layer, config
)
# Initialize metadata
backend.init_forward_metadata(fb)
metadata = backend.forward_metadata
# Verify KV indices structure
block_kv_indices = metadata.block_kv_indices
for i in range(batch_size):
seq_len = seq_lens[i].item()
expected_blocks = backend._calc_padded_blocks(seq_len)
# Count valid (non -1) indices for this sequence
valid_indices = (block_kv_indices[i] >= 0).sum().item()
# Should have at least enough blocks for the sequence
min_required_blocks = (seq_len + config["page_size"] - 1) // config[
"page_size"
]
self.assertGreaterEqual(
valid_indices,
min_required_blocks,
f"Sequence {i} should have at least {min_required_blocks} valid blocks, got {valid_indices}",
)
# Verify indices are within valid range
valid_block_indices = block_kv_indices[i][block_kv_indices[i] >= 0]
if len(valid_block_indices) > 0:
max_possible_blocks = (
model_runner.token_to_kv_pool.size // config["page_size"]
)
self.assertTrue(
(valid_block_indices < max_possible_blocks).all(),
f"All block indices should be < {max_possible_blocks}",
)
def test_metadata_cuda_graph_compatibility(self):
"""Test metadata compatibility with CUDA graph capture/replay."""
print(f"\nRunning CUDA graph compatibility tests...")
config = self._merge_config(
{"batch_size": 4, "max_seq_len": 64, "page_size": 32}
)
model_runner, _, backend, _, layer = self._create_model_components(config)
batch_size = config["batch_size"]
# Initialize CUDA graph state
backend.init_cuda_graph_state(
max_bs=batch_size, max_num_tokens=config["max_seq_len"] * batch_size
)
# Verify CUDA graph buffers are allocated
self.assertIsNotNone(backend.cuda_graph_kv_indices)
self.assertIsNotNone(backend.cuda_graph_workspace)
# Test capture metadata
seq_lens = torch.full(
(batch_size,), config["max_seq_len"], device=config["device"]
)
req_pool_indices = torch.arange(batch_size, device=config["device"])
backend.init_forward_metadata_capture_cuda_graph(
bs=batch_size,
num_tokens=batch_size,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=None,
)
# Verify capture metadata
self.assertIn(batch_size, backend.decode_cuda_graph_metadata)
capture_metadata = backend.decode_cuda_graph_metadata[batch_size]
self.assertIsNotNone(capture_metadata.workspace)
self.assertIsNotNone(capture_metadata.block_kv_indices)
# Test replay with different sequence lengths
new_seq_lens = torch.randint(
config["max_seq_len"] // 2,
config["max_seq_len"] + 1,
(batch_size,),
device=config["device"],
)
backend.init_forward_metadata_replay_cuda_graph(
bs=batch_size,
req_pool_indices=req_pool_indices,
seq_lens=new_seq_lens,
seq_lens_sum=new_seq_lens.sum().item(),
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=None,
seq_lens_cpu=new_seq_lens.cpu(),
)
# Verify replay updated the metadata
replay_metadata = backend.forward_metadata
self.assertIsNotNone(replay_metadata)
self.assertEqual(
replay_metadata.workspace.data_ptr(), capture_metadata.workspace.data_ptr()
)
def test_metadata_consistency_across_calls(self):
"""Test metadata consistency across multiple forward calls."""
print(f"\nRunning metadata consistency tests...")
config = self._merge_config(
{"batch_size": 2, "max_seq_len": 64, "page_size": 32}
)
model_runner, _, backend, _, layer = self._create_model_components(config)
# First call
seq_lens_1 = torch.tensor([32, 48], device=config["device"])
fb_1 = self._create_forward_batch(
config["batch_size"], seq_lens_1, backend, model_runner, config
)
backend.init_forward_metadata(fb_1)
metadata_1 = backend.forward_metadata
# Second call with same sequence lengths
seq_lens_2 = torch.tensor([32, 48], device=config["device"])
fb_2 = self._create_forward_batch(
config["batch_size"], seq_lens_2, backend, model_runner, config
)
backend.init_forward_metadata(fb_2)
metadata_2 = backend.forward_metadata
# Metadata structure should be consistent
self.assertEqual(metadata_1.workspace.shape, metadata_2.workspace.shape)
self.assertEqual(
metadata_1.block_kv_indices.shape, metadata_2.block_kv_indices.shape
)
# Third call with different sequence lengths
seq_lens_3 = torch.tensor([16, 64], device=config["device"])
fb_3 = self._create_forward_batch(
config["batch_size"], seq_lens_3, backend, model_runner, config
)
backend.init_forward_metadata(fb_3)
metadata_3 = backend.forward_metadata
# Should still have valid structure
self.assertIsNotNone(metadata_3.workspace)
self.assertIsNotNone(metadata_3.block_kv_indices)
self.assertEqual(metadata_3.block_kv_indices.shape[0], config["batch_size"])
if __name__ == "__main__":
unittest.main()