Unverified Commit 6f858930 authored by Johnsonms, committed by GitHub

[Bug] test_flashattn_mla_backend errors in Hopper #12487 (#12488)

parent 229256c5
@@ -4,6 +4,7 @@ import torch

 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
+from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
@@ -19,6 +20,7 @@ class MockModelRunner:
         attention_arch = AttentionArch.MLA
         self.device = "cuda"
         self.dtype = torch.float16
+        self.is_hybrid = False
         context_len = 2048
         self.model_config = type(
             "ModelConfig",
@@ -29,6 +31,18 @@
             },
         )
         self.sliding_window_size = None
+        # Add server_args attribute
+        self.server_args = type(
+            "ServerArgs",
+            (),
+            {
+                "kv_cache_dtype": torch.float16,
+                "speculative_eagle_topk": None,
+                "speculative_num_draft_tokens": 0,
+                "enable_deterministic_inference": False,
+            },
+        )
+        self.kv_cache_dtype = self.server_args.kv_cache_dtype

         batch_size = 160
         # Create a proper req_to_token_pool with the req_to_token attribute
@@ -49,7 +63,7 @@
         self.token_to_kv_pool = MLATokenToKVPool(
             size=max_total_num_tokens,
             page_size=self.page_size,
-            dtype=self.dtype,
+            dtype=self.kv_cache_dtype,
             kv_lora_rank=kv_lora_rank,
             qk_rope_head_dim=qk_rope_head_dim,
             layer_num=1,  # only consider layer=1 for unit test
@@ -70,6 +84,15 @@ class MockReqToTokenPool:
 @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
 class TestFlashAttentionMLABackend(CustomTestCase):
     def setUp(self):
+        # MLA with different V headdim requires Hopper architecture (compute capability >= 9.0)
+        if torch.cuda.is_available():
+            compute_capability = torch.cuda.get_device_capability()
+            if compute_capability[0] < 9:
+                self.skipTest(
+                    f"MLA requires Hopper GPU (compute capability >= 9.0), "
+                    f"but found compute capability {compute_capability[0]}.{compute_capability[1]}"
+                )
+
         # Test parameters
         self.batch_size = 2
         self.seq_len = 360
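For context: the gate added above relies on torch.cuda.get_device_capability(), which returns a (major, minor) pair, and Hopper parts report major 9. Below is a minimal standalone sketch of the same check; the helper name is hypothetical and not part of the test.

```python
import torch


def is_hopper_or_newer() -> bool:
    """Return True only when a CUDA device with compute capability >= 9.0 is visible.

    Rough mapping of capability majors: 7 = Volta/Turing, 8 = Ampere/Ada, 9 = Hopper.
    """
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 9


if __name__ == "__main__":
    print("Hopper or newer:", is_hopper_or_newer())
```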
@@ -85,6 +108,7 @@ class TestFlashAttentionMLABackend(CustomTestCase):
         # Initialize model runner and backend
         self._init_model_runner()
         self.backend = FlashAttentionBackend(self.model_runner)
+        self.ref_backend = TorchNativeAttnBackend(self.model_runner)
         self.num_local_heads = 2

     def _init_model_runner(self):
@@ -92,7 +116,6 @@
             kv_lora_rank=self.kv_lora_rank,
             qk_rope_head_dim=self.qk_rope_head_dim,
         )
-        self.backend = FlashAttentionBackend(self.model_runner)

     def _create_attention_layer(self):
         """Create attention layer for testing."""
@@ -207,21 +230,29 @@ class TestFlashAttentionMLABackend(CustomTestCase):
         if cache_len <= 0:
             return

         # Create constant values for the prefix cache for easy debugging
-        latent_cache = torch.ones(
+        # For MLA, create separate nope and rope caches
+        cache_k_nope = torch.ones(
+            self.batch_size * cache_len,
+            1,  # latent cache has only one head in MQA
+            self.kv_lora_rank,
+            dtype=self.dtype,
+            device=self.device,
+        )
+        cache_k_rope = torch.ones(
             self.batch_size * cache_len,
             1,  # latent cache has only one head in MQA
-            self.kv_lora_rank + self.qk_rope_head_dim,
+            self.qk_rope_head_dim,
             dtype=self.dtype,
             device=self.device,
         )

-        # Set the prefix KV cache
-        forward_batch.token_to_kv_pool.set_kv_buffer(
+        # Set the prefix KV cache using MLA-specific method
+        forward_batch.token_to_kv_pool.set_mla_kv_buffer(
             layer,
             torch.arange(self.batch_size * cache_len, device=self.device),
-            latent_cache,
-            None,
+            cache_k_nope,
+            cache_k_rope,
         )

     def _run_attention_test(self, mode, q_len, prefix_len=0):
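The switch from set_kv_buffer to set_mla_kv_buffer reflects how the MLA pool stores each token: a kv_lora_rank-wide latent ("nope") part plus a qk_rope_head_dim-wide rotary part, rather than one packed vector. A rough shape sketch under assumed dimensions (512 and 64 are illustrative, not taken from this test):

```python
import torch

# Illustrative MLA dimensions; the real values come from the test's setUp.
kv_lora_rank, qk_rope_head_dim = 512, 64
num_tokens = 8

# Old layout: one packed latent per token, single MQA head.
latent_cache = torch.ones(num_tokens, 1, kv_lora_rank + qk_rope_head_dim)

# New layout: the same data split into the two buffers the MLA pool keeps.
cache_k_nope = latent_cache[..., :kv_lora_rank]   # (num_tokens, 1, kv_lora_rank)
cache_k_rope = latent_cache[..., kv_lora_rank:]   # (num_tokens, 1, qk_rope_head_dim)

# Concatenating the halves recovers the packed layout exactly.
assert torch.equal(torch.cat([cache_k_nope, cache_k_rope], dim=-1), latent_cache)
```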
@@ -242,8 +273,18 @@
         kv_shape = (self.batch_size * q_len, self.qk_head_dim)
         q = torch.randn(q_shape, dtype=self.dtype, device=self.device)
         kv_compressed = torch.randn(kv_shape, dtype=self.dtype, device=self.device)

-        # v is not used for mqa, all values passed in through k
-        k = kv_compressed.unsqueeze(1)
+        # For MLA, split kv_compressed into k_nope and k_rope
+        # k_nope has dimension kv_lora_rank, k_rope has dimension qk_rope_head_dim
+        k_nope = kv_compressed[:, : self.kv_lora_rank]
+        k_rope = kv_compressed[:, self.kv_lora_rank :]
+
+        # k_nope needs to be unsqueezed for the num_heads dimension
+        k = k_nope.unsqueeze(1)
+        # k_rope also needs to be unsqueezed
+        k_rope = k_rope.unsqueeze(1)
+
+        # v is not used for mqa
         v = torch.randn((1), dtype=self.dtype, device=self.device)

         self._setup_kv_cache(forward_batch, layer, prefix_len)
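The same split happens on the freshly generated inputs: kv_compressed has width kv_lora_rank + qk_rope_head_dim, is cut column-wise, and each piece gets a singleton num_heads axis (MQA uses a single KV head) before being handed to the backend. A small sketch with assumed sizes (the batch and dimension values below are illustrative):

```python
import torch

kv_lora_rank, qk_rope_head_dim = 512, 64  # assumed, mirroring typical MLA configs
tokens = 2 * 4                            # batch_size * q_len in the test's terms

kv_compressed = torch.randn(tokens, kv_lora_rank + qk_rope_head_dim)

# Column-wise split into the latent ("nope") part and the rotary part.
k_nope = kv_compressed[:, :kv_lora_rank]  # (tokens, kv_lora_rank)
k_rope = kv_compressed[:, kv_lora_rank:]  # (tokens, qk_rope_head_dim)

# Add the singleton num_heads axis the attention backend expects (MQA: 1 KV head).
k = k_nope.unsqueeze(1)                   # (tokens, 1, kv_lora_rank)
k_rope = k_rope.unsqueeze(1)              # (tokens, 1, qk_rope_head_dim)

assert k.shape == (tokens, 1, kv_lora_rank)
assert k_rope.shape == (tokens, 1, qk_rope_head_dim)
```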
@@ -256,9 +297,13 @@
         )

         if mode == ForwardMode.EXTEND:
-            output = self.backend.forward_extend(q, k, v, layer, forward_batch)
+            output = self.backend.forward_extend(
+                q, k, v, layer, forward_batch, k_rope=k_rope
+            )
         else:
-            output = self.backend.forward_decode(q, k, v, layer, forward_batch)
+            output = self.backend.forward_decode(
+                q, k, v, layer, forward_batch, k_rope=k_rope
+            )

         self._verify_output(output, expected_shape)
         return output