Unverified Commit f508cd3c authored by Faraz, committed by GitHub

TRTLLM-MLA FP8 path (#8638)


Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
parent 44e86480
...@@ -60,6 +60,11 @@ python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attenti
python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend trtllm_mla --trust-remote-code
```
- TRTLLM MLA with FP8 KV Cache (Higher concurrency, lower memory footprint)
```bash
python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend trtllm_mla --kv-cache-dtype fp8_e4m3 --trust-remote-code
```
- Ascend
```bash
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend ascend
...
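Once one of the servers above is running, a single chat-completion request is a quick smoke test. This is a minimal sketch that assumes sglang's default port 30000 and its OpenAI-compatible `/v1/chat/completions` route; adjust `base_url` and `model` to match the launch command you used.

```python
# Minimal smoke test against a running server (assumes the default port 30000
# and the OpenAI-compatible chat completions route; adjust to your launch flags).
import requests

base_url = "http://localhost:30000"
payload = {
    "model": "deepseek-ai/DeepSeek-R1",
    "messages": [{"role": "user", "content": "Summarize FP8 KV cache in one sentence."}],
    "max_tokens": 64,
}
resp = requests.post(f"{base_url}/v1/chat/completions", json=payload, timeout=300)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```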
...@@ -287,38 +287,135 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
)
forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
def quantize_and_rope_for_fp8(
self,
q_nope: torch.Tensor,
q_rope: torch.Tensor,
k_nope: torch.Tensor,
k_rope: torch.Tensor,
forward_batch: ForwardBatch,
cos_sin_cache: torch.Tensor,
is_neox: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Quantize and apply RoPE for FP8 attention path.
This function handles the FP8 quantization and RoPE application for MLA attention.
It takes separate query/key nope and rope components, applies RoPE to the rope parts,
quantizes all components to FP8, and merges the query components into a single tensor.
Args:
q_nope: Query no-position-encoding component [seq_len, num_heads, kv_lora_rank]
- expected dtype: torch.bfloat16
q_rope: Query RoPE component [seq_len, num_heads, qk_rope_head_dim]
- expected dtype: torch.bfloat16
k_nope: Key no-position-encoding component [seq_len, num_heads, kv_lora_rank]
- expected dtype: torch.bfloat16
k_rope: Key RoPE component [seq_len, num_heads, qk_rope_head_dim]
- expected dtype: torch.bfloat16
forward_batch: Forward batch containing position information
cos_sin_cache: Precomputed cosine/sine cache for RoPE
- expected dtype: matches q_/k_ input dtype (torch.bfloat16)
is_neox: Whether to use NeoX-style RoPE (interleaved) or GPT-style (half rotation)
Returns:
tuple: (merged_q_out, k_nope_out, k_rope_out) quantized to FP8
- merged_q_out: [seq_len, num_heads, kv_lora_rank + qk_rope_head_dim], dtype=torch.float8_e4m3fn
- k_nope_out: [seq_len, num_heads, kv_lora_rank], dtype=torch.float8_e4m3fn
- k_rope_out: [seq_len, num_heads, qk_rope_head_dim], dtype=torch.float8_e4m3fn
"""
attn_dtype = torch.float8_e4m3fn
q_len, num_heads = q_rope.shape[0], q_rope.shape[1]
# Allocate output tensors with FP8 dtype
# Query output will contain merged nope + rope components
q_out = q_rope.new_empty(
q_len,
num_heads,
self.kv_lora_rank + self.qk_rope_head_dim,
dtype=attn_dtype,
)
# Key outputs maintain original shapes but with FP8 dtype
k_rope_out = k_rope.new_empty(k_rope.shape, dtype=attn_dtype)
k_nope_out = k_nope.new_empty(k_nope.shape, dtype=attn_dtype)
# Apply RoPE and quantize all components in a single fused kernel call
# This kernel handles:
# 1. RoPE application to q_rope and k_rope using cos_sin_cache and positions
# 2. Quantization of all components to FP8 format
# 3. Output placement into pre-allocated tensors
flashinfer.rope.mla_rope_quantize_fp8(
q_rope=q_rope,
k_rope=k_rope,
q_nope=q_nope,
k_nope=k_nope,
cos_sin_cache=cos_sin_cache,
pos_ids=forward_batch.positions,
is_neox=is_neox,
quantize_dtype=attn_dtype,
# Output tensor slicing: q_out contains [nope_part, rope_part]
q_rope_out=q_out[..., self.kv_lora_rank :], # RoPE part goes to end
k_rope_out=k_rope_out,
q_nope_out=q_out[..., : self.kv_lora_rank], # Nope part goes to beginning
k_nope_out=k_nope_out,
# Quantization scales (set to 1.0 for no additional scaling)
quant_scale_q=1.0,
quant_scale_kv=1.0,
)
return q_out, k_nope_out, k_rope_out
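To make the layout above concrete, here is a plain-PyTorch sketch of what the merged query produced by `quantize_and_rope_for_fp8` looks like. The dimensions are illustrative (not DeepSeek-R1's real sizes), and the fused `flashinfer.rope.mla_rope_quantize_fp8` call is replaced by a simple concatenate-and-cast, so RoPE is omitted; the point is only the `[nope | rope]` channel layout and the FP8 output dtype.

```python
import torch

# Illustrative sizes only; the real kernel writes into pre-allocated FP8 slices instead.
seq_len, num_heads, kv_lora_rank, qk_rope_head_dim = 4, 8, 16, 4

q_nope = torch.randn(seq_len, num_heads, kv_lora_rank, dtype=torch.bfloat16)
q_rope = torch.randn(seq_len, num_heads, qk_rope_head_dim, dtype=torch.bfloat16)

# Same layout as q_out above: nope part first, rope part in the trailing channels,
# everything quantized to float8_e4m3fn.
q_out = torch.cat([q_nope, q_rope], dim=-1).to(torch.float8_e4m3fn)

assert q_out.shape == (seq_len, num_heads, kv_lora_rank + qk_rope_head_dim)
assert q_out.dtype == torch.float8_e4m3fn
# q_out[..., :kv_lora_rank]  <- quantized q_nope
# q_out[..., kv_lora_rank:]  <- quantized (and, in the real kernel, RoPE-rotated) q_rope
```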
def forward_decode(
self,
q: torch.Tensor,  # q_nope
k: torch.Tensor,  # k_nope
v: torch.Tensor,  # not used in this backend
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
cos_sin_cache: Optional[torch.Tensor] = None,
is_neox: Optional[bool] = False,
) -> torch.Tensor:
"""Run forward for decode using TRTLLM MLA kernel."""
merge_query = q_rope is not None
if self.data_type == torch.float8_e4m3fn:
# For the FP8 path, quantize the query nope/rope parts and merge them into a single tensor
# Note: RoPE application in deepseek_v2.py:forward_absorb_prepare is skipped for the FP8 decode path of this trtllm_mla backend
assert all(
x is not None for x in [q_rope, k_rope, cos_sin_cache]
), "For FP8 path and using flashinfer.rope.mla_rope_quantize we need all of q_rope, k_rope and cos_sin_cache to be not None."
q, k, k_rope = self.quantize_and_rope_for_fp8(
q,
q_rope,
k.squeeze(1),
k_rope.squeeze(1),
forward_batch,
cos_sin_cache,
is_neox,
)
merge_query = False
# Save KV cache if requested
if save_kv_cache:
assert (
k is not None and k_rope is not None
), "For populating trtllm_mla kv cache, both k_nope and k_rope should be not None."
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
layer, forward_batch.out_cache_loc, k, k_rope
)
# Prepare query tensor inline
if merge_query:
# For FP16 path, we merge the query and rope parts into a single tensor
q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
q_rope_reshaped = q_rope.view(
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
)
query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
else:
# For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
# Ensure query has shape [bs, acc_q_len, num_q_heads, head_dim] when seq_len 1
...@@ -327,9 +424,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
# Prepare KV cache inline
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)
# Get metadata
metadata = (
...@@ -337,11 +432,13 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
or self.forward_metadata
)
# Scale computation for TRTLLM MLA kernel BMM1 operation:
# The final BMM1 scale is computed as: q_scale * k_scale * softmax_scale
# Scale components:
# - q_scale: Query scaling factor (set to 1.0 for both FP16/FP8 paths)
# - k_scale: Key scaling factor from model checkpoint (defaults to 1.0 if not available)
# - softmax_scale: Attention softmax scaling = 1/sqrt(head_dim), pre-computed as layer.scaling
# This unified approach works for both FP16 and FP8 quantized attention paths.
q_scale = 1.0
k_scale = (
layer.k_scale_float
...
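As a worked example of the comment block above, the sketch below just multiplies the three factors together. Everything here is a stand-in: the helper name and `head_dim` value are placeholders, and in the backend `softmax_scale` arrives pre-computed as `layer.scaling` while `k_scale` falls back to 1.0 when the checkpoint provides none.

```python
import math

def bmm1_scale(head_dim, k_scale_from_ckpt=None, q_scale=1.0):
    # softmax_scale = 1 / sqrt(head_dim); the backend receives this pre-computed as layer.scaling.
    softmax_scale = 1.0 / math.sqrt(head_dim)
    # k_scale comes from the model checkpoint when available, otherwise defaults to 1.0.
    k_scale = 1.0 if k_scale_from_ckpt is None else k_scale_from_ckpt
    # Final BMM1 scale used by the TRTLLM MLA kernel: q_scale * k_scale * softmax_scale.
    return q_scale * k_scale * softmax_scale

print(bmm1_scale(head_dim=128))                         # q_scale = k_scale = 1.0
print(bmm1_scale(head_dim=128, k_scale_from_ckpt=0.5))  # with a checkpoint-provided k_scale
```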
...@@ -1196,6 +1196,16 @@ class DeepseekV2AttentionMLA(nn.Module):
output, _ = self.o_proj(attn_output)
return output
def _fuse_rope_for_trtllm_mla(self, forward_batch: ForwardBatch) -> bool:
"""
Check if we should skip rope here and do fused rope+quantize for TRTLLM MLA decode in the fp8_e4m3 path.
"""
return (
self.current_attention_backend == "trtllm_mla"
and forward_batch.forward_mode.is_decode_or_idle()
and forward_batch.attn_backend.data_type == torch.float8_e4m3fn
)
def forward_absorb_prepare(
self,
positions: torch.Tensor,
...@@ -1275,7 +1285,9 @@ class DeepseekV2AttentionMLA(nn.Module):
q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
q_nope_out = q_nope_out.transpose(0, 1)
if not self._fuse_rope_for_trtllm_mla(forward_batch):
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
...@@ -1288,8 +1300,20 @@ class DeepseekV2AttentionMLA(nn.Module):
or self.current_attention_backend == "cutlass_mla"
or self.current_attention_backend == "trtllm_mla"
):
extra_args = {}
if self._fuse_rope_for_trtllm_mla(forward_batch):
extra_args = {
"cos_sin_cache": self.rotary_emb.cos_sin_cache,
"is_neox": self.rotary_emb.is_neox_style,
}
attn_output = self.attn_mqa(
q_nope_out,
k_nope,
k_nope,
forward_batch,
q_rope=q_pe,
k_rope=k_pe,
**extra_args,
)
else:
q = torch.cat([q_nope_out, q_pe], dim=-1)
...
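A self-contained sketch of the gating introduced above, with stub classes standing in for the real sglang objects (names prefixed with `_Stub` are illustrative, not actual sglang APIs): `forward_absorb_prepare` skips RoPE only when all three conditions hold, and the call site then forwards the rope cache to the backend via `extra_args` so the TRTLLM kernel can apply RoPE and FP8 quantization in one fused step.

```python
import torch
from dataclasses import dataclass

@dataclass
class _StubAttnBackend:      # stands in for the attention backend instance
    data_type: torch.dtype

@dataclass
class _StubForwardBatch:     # stands in for ForwardBatch
    decode_or_idle: bool
    attn_backend: _StubAttnBackend

def fuse_rope_for_trtllm_mla(attention_backend: str, batch: _StubForwardBatch) -> bool:
    # Mirrors the three conditions in DeepseekV2AttentionMLA._fuse_rope_for_trtllm_mla above.
    return (
        attention_backend == "trtllm_mla"
        and batch.decode_or_idle
        and batch.attn_backend.data_type == torch.float8_e4m3fn
    )

batch = _StubForwardBatch(True, _StubAttnBackend(torch.float8_e4m3fn))
extra_args = {}
if fuse_rope_for_trtllm_mla("trtllm_mla", batch):
    # In the real code these come from self.rotary_emb.
    extra_args = {"cos_sin_cache": None, "is_neox": False}
print("fused rope+quantize path:", bool(extra_args))  # True -> RoPE is left to the backend
```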
...@@ -432,7 +432,10 @@ class ServerArgs:
)
self.page_size = 128
if (
self.attention_backend == "trtllm_mla"
or self.decode_attention_backend == "trtllm_mla"
):
if not is_sm100_supported():
raise ValueError(
"TRTLLM MLA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
...@@ -443,11 +446,17 @@ class ServerArgs:
f"TensorRT-LLM MLA only supports page_size of 32 or 64, changing page_size from {self.page_size} to 64."
)
self.page_size = 64
if self.speculative_algorithm is not None:
raise ValueError(
"trtllm_mla backend does not support speculative decoding yet."
)
if self.kv_cache_dtype not in ["fp8_e4m3", "auto"]:
raise ValueError(
"TensorRT-LLM MLA backend only supports kv-cache-dtype of fp8_e4m3 or auto."
)
if (
self.attention_backend == "trtllm_mha"
or self.decode_attention_backend == "trtllm_mha"
...
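The sketch below mirrors the validation added above as a standalone function, so the accepted combinations are easy to see at a glance. It is an illustrative stand-in, not the real `ServerArgs` code: `trtllm_mla` coerces `page_size` to 64 when it is neither 32 nor 64, rejects speculative decoding, and accepts only `fp8_e4m3` or `auto` for `--kv-cache-dtype`.

```python
def check_trtllm_mla_args(page_size, speculative_algorithm, kv_cache_dtype):
    # Illustrative stand-in for the ServerArgs checks above.
    if page_size not in (32, 64):
        print(f"changing page_size from {page_size} to 64")
        page_size = 64
    if speculative_algorithm is not None:
        raise ValueError("trtllm_mla backend does not support speculative decoding yet.")
    if kv_cache_dtype not in ("fp8_e4m3", "auto"):
        raise ValueError("TensorRT-LLM MLA backend only supports kv-cache-dtype of fp8_e4m3 or auto.")
    return page_size

print(check_trtllm_mla_args(128, None, "fp8_e4m3"))  # -> 64
print(check_trtllm_mla_args(32, None, "auto"))       # -> 32
# check_trtllm_mla_args(64, None, "fp8_e5m2")        # -> ValueError
```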
...@@ -43,6 +43,37 @@ DEFAULT_CONFIG = {
"layer_id": 0,
}
ROPE_BASE = 10000
ROPE_SCALING_CONFIG = {
"beta_fast": 32,
"beta_slow": 1,
"factor": 40,
"mscale": 1.0,
"mscale_all_dim": 1.0,
"original_max_position_embeddings": 4096,
"type": "yarn",
"rope_type": "deepseek_yarn",
}
def build_rotary_emb(config, device=None):
from sglang.srt.layers.rotary_embedding import get_rope_wrapper
dev = device or config["device"]
rope_scaling = config.get("rope_scaling", ROPE_SCALING_CONFIG)
rotary = get_rope_wrapper(
head_size=config["qk_rope_head_dim"],
rotary_dim=config["qk_rope_head_dim"],
max_position=config["context_len"],
base=ROPE_BASE,
rope_scaling=rope_scaling,
is_neox_style=False,
device=dev,
)
rotary.cos_sin_cache = rotary.cos_sin_cache.to(dev)
return rotary
# Centralized test cases for different test scenarios
TEST_CASES = {
"basic_functionality": [
...@@ -63,18 +94,36 @@ TEST_CASES = {
],
"decode_output_match": [
{
"name": "single_fp16",
"batch_size": 1,
"max_seq_len": 64,
"page_size": 32,
"description": "Single FP16 vs reference",
},
{
"name": "single_fp8",
"batch_size": 1,
"max_seq_len": 64,
"page_size": 64,
"tolerance": 1e-1,
"kv_cache_dtype": torch.float8_e4m3fn,
"description": "Single FP8 vs reference",
},
{
"name": "batch_fp16",
"batch_size": 32,
"max_seq_len": 64,
"page_size": 32,
"description": "Batch FP16 vs reference",
},
{
"name": "batch_fp8",
"batch_size": 32,
"max_seq_len": 64,
"page_size": 64,
"tolerance": 1e-1,
"kv_cache_dtype": torch.float8_e4m3fn,
"description": "Batch FP8 vs reference",
},
],
"page_size_consistency": [
...@@ -293,26 +342,52 @@ class TestTRTLLMMLA(CustomTestCase):
layer,
)
def _create_qkv_tensors(self, batch_size, config, dtype_override=None):
"""Create Q, K, V random tensors for given batch size with separate MLA components.
Args:
batch_size: Batch size.
config: Configuration dict with model dims and device.
dtype_override: Optional torch dtype to override config["dtype"].
Returns:
Tuple of (q_nope, q_rope, k_nope, k_rope, v)
"""
device = config["device"] device = config["device"]
dtype = config["dtype"] target_dtype = dtype_override or config["dtype"]
q = torch.randn( # Create separate nope and rope components for Q
(batch_size, config["num_attention_heads"], head_dim), q_nope = torch.randn(
dtype=dtype, (batch_size, config["num_attention_heads"], config["kv_lora_rank"]),
dtype=config["dtype"],
device=device, device=device,
) )
k = torch.randn( q_rope = torch.randn(
(batch_size, config["num_kv_heads"], head_dim), dtype=dtype, device=device (batch_size, config["num_attention_heads"], config["qk_rope_head_dim"]),
dtype=config["dtype"],
device=device,
)
# Create separate nope and rope components for K
k_nope = torch.randn(
(batch_size, config["num_kv_heads"], config["kv_lora_rank"]),
dtype=config["dtype"],
device=device,
)
k_rope = torch.randn(
(batch_size, config["num_kv_heads"], config["qk_rope_head_dim"]),
dtype=config["dtype"],
device=device,
) )
# V tensor (unchanged)
v = torch.randn( v = torch.randn(
(batch_size, config["num_kv_heads"], config["v_head_dim"]), (batch_size, config["num_kv_heads"], config["v_head_dim"]),
dtype=dtype, dtype=config["dtype"],
device=device, device=device,
) )
return q, k, v
return q_nope, q_rope, k_nope, k_rope, v
def _create_forward_batch(
self, batch_size, seq_lens, backend, model_runner, config
...@@ -331,6 +406,10 @@ class TestTRTLLMMLA(CustomTestCase):
)
fb.req_to_token_pool = model_runner.req_to_token_pool
fb.token_to_kv_pool = model_runner.token_to_kv_pool
# Add position information for RoPE
fb.positions = torch.arange(batch_size, device=config["device"])
return fb
def _populate_kv_cache(self, batch_size, seq_lens, model_runners, layer, config):
...@@ -344,7 +423,7 @@ class TestTRTLLMMLA(CustomTestCase):
for token_idx in range(seq_len - 1):
# Create random K components for MLA
cache_k_nope = torch.randn(
(1, config["kv_lora_rank"]),
dtype=config["dtype"],
device=config["device"],
)
...@@ -411,12 +490,16 @@ class TestTRTLLMMLA(CustomTestCase):
batch_size, seq_lens, [model_runner_trtllm], layer, config
)
# Create Q, K, V tensors with separate MLA components
torch.manual_seed(config["seed_qkv"])
q_nope, q_rope, k_nope, k_rope, v = self._create_qkv_tensors(
batch_size, config
)
# Run forward decode with separate MLA components
output = trtllm_backend.forward_decode(
q_nope, k_nope, None, layer, fb, q_rope=q_rope, k_rope=k_rope
)
# Basic checks
expected_shape = (
...@@ -439,6 +522,7 @@ class TestTRTLLMMLA(CustomTestCase):
config = self._merge_config(test_case)
batch_size = config["batch_size"]
max_seq_len = config["max_seq_len"]
use_fp8 = config["kv_cache_dtype"] == torch.float8_e4m3fn
# Create components
(
...@@ -487,19 +571,66 @@ class TestTRTLLMMLA(CustomTestCase):
# Create Q, K, V tensors for current decode step
torch.manual_seed(config["seed_qkv"])
q_nope_ref, q_rope_ref, k_nope_ref, k_rope_ref, v_ref = (
self._create_qkv_tensors(batch_size, config)
)
q_nope_trt, q_rope_trt, k_nope_trt, k_rope_trt, v_trt = (
q_nope_ref.clone(),
q_rope_ref.clone(),
k_nope_ref.clone(),
k_rope_ref.clone(),
v_ref.clone(),
)
tolerance = config["tolerance"]
extra_args = {}
if use_fp8:
# The TRT kernel applies RoPE + FP8 quantization internally, so
# pre-apply RoPE on the reference (FlashInfer) path here; both paths
# then share the same rope params/cache while the TRT path stays unrotated.
rotary_emb = build_rotary_emb(config)
q_rope_ref, k_rope_ref = rotary_emb(
fb_reference.positions, q_rope_ref, k_rope_ref
)
extra_args = {
"cos_sin_cache": rotary_emb.cos_sin_cache,
"is_neox": rotary_emb.is_neox_style,
}
dtype = q_rope_ref.dtype
q_rope_ref = q_rope_ref.to(torch.float8_e4m3fn).to(dtype)
q_nope_ref = q_nope_ref.to(torch.float8_e4m3fn).to(dtype)
k_rope_ref = k_rope_ref.to(torch.float8_e4m3fn).to(dtype)
k_nope_ref = k_nope_ref.to(torch.float8_e4m3fn).to(dtype)
# Run forward decode on both backends
out_trtllm = trtllm_backend.forward_decode(
q_nope_trt,
k_nope_trt,
None,
layer,
fb_trtllm,
q_rope=q_rope_trt,
k_rope=k_rope_trt,
**extra_args,
)
# Reference backend should also take separate components, not concatenated
out_reference = reference_backend.forward_decode(
q_nope_ref,
k_nope_ref,
v_ref,
layer,
fb_reference,
q_rope=q_rope_ref,
k_rope=k_rope_ref,
)
# Compare outputs
comparison_passed = compare_outputs(
out_trtllm, out_reference, tolerance=tolerance
)
self.assertTrue(
...@@ -544,12 +675,16 @@ class TestTRTLLMMLA(CustomTestCase):
batch_size, seq_lens, [model_runner], layer, config
)
# Create Q, K, V tensors with separate MLA components
torch.manual_seed(config["seed_qkv"])
q_nope, q_rope, k_nope, k_rope, v = self._create_qkv_tensors(
batch_size, config
)
# Run forward decode with separate MLA components
output = backend.forward_decode(
q_nope, k_nope, None, layer, fb, q_rope=q_rope, k_rope=k_rope
)
expected_shape = (
batch_size,
...@@ -591,23 +726,38 @@ class TestTRTLLMMLA(CustomTestCase):
)
backend.init_forward_metadata(fb)
# Create Q, K, V tensors with separate MLA components
torch.manual_seed(config["seed_qkv"])
q_nope = torch.randn(
(batch_size, config["num_attention_heads"], config["kv_lora_rank"]),
dtype=config["dtype"],
device=config["device"],
)
k_nope = torch.randn(
(batch_size, config["num_kv_heads"], config["kv_lora_rank"]),
dtype=config["dtype"],
device=config["device"],
)
q_rope = torch.randn(
(
batch_size,
config["num_attention_heads"],
config["qk_rope_head_dim"],
),
dtype=config["dtype"],
device=config["device"],
)
k_rope = torch.randn(
(batch_size, config["num_kv_heads"], config["qk_rope_head_dim"]),
dtype=config["dtype"],
device=config["device"],
)
v = None # Test with None v
# Run forward decode
output = backend.forward_decode(
q_nope, k_nope, v, layer, fb, q_rope=q_rope, k_rope=k_rope
)
# Shape and sanity checks
expected_shape = (
...
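The FP8 reference path in the decode-output test works by "fake quantization": cast the reference inputs to `torch.float8_e4m3fn` and back so they carry the same rounding error the TRT kernel's real FP8 inputs see, then compare the two backends under a loose tolerance. The standalone sketch below illustrates just that round trip with illustrative tensor sizes; it is not part of the test file.

```python
import torch

x = torch.randn(4, 8, 64, dtype=torch.bfloat16)

# Same idea as q_nope_ref.to(torch.float8_e4m3fn).to(dtype) in the test above:
# the round trip keeps the bfloat16 dtype but injects FP8 rounding error.
x_fake_quant = x.to(torch.float8_e4m3fn).to(torch.bfloat16)

# This lossiness is why the FP8 decode cases set tolerance=1e-1 when comparing
# the TRTLLM and FlashInfer attention outputs.
abs_err = (x - x_fake_quant).abs()
print("max abs error :", abs_err.max().item())
print("mean abs error:", abs_err.mean().item())
```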