Unverified Commit 6f858930 authored by Johnsonms, committed by GitHub

[Bug] test_flashattn_mla_backend errors in Hopper #12487 (#12488)

parent 229256c5
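
Summary of the change, as reflected in the diff below:

- Skip the test on GPUs below compute capability 9.0, since MLA with a different V head dim requires Hopper.
- Give MockModelRunner the attributes the backend now reads: is_hybrid, a mocked server_args, and kv_cache_dtype (now used as the MLATokenToKVPool dtype).
- Construct a TorchNativeAttnBackend as self.ref_backend and remove a duplicate FlashAttentionBackend construction in _init_model_runner().
- Write the prefix KV cache through the MLA-specific set_mla_kv_buffer() with separate nope and rope tensors instead of one combined latent cache.
- Split kv_compressed into k_nope and k_rope, and pass k_rope explicitly to forward_extend() and forward_decode().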
@@ -4,6 +4,7 @@ import torch
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
+from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
@@ -19,6 +20,7 @@ class MockModelRunner:
         attention_arch = AttentionArch.MLA
         self.device = "cuda"
         self.dtype = torch.float16
+        self.is_hybrid = False
         context_len = 2048
         self.model_config = type(
             "ModelConfig",
@@ -29,6 +31,18 @@ class MockModelRunner:
             },
         )
         self.sliding_window_size = None
+        # Add server_args attribute
+        self.server_args = type(
+            "ServerArgs",
+            (),
+            {
+                "kv_cache_dtype": torch.float16,
+                "speculative_eagle_topk": None,
+                "speculative_num_draft_tokens": 0,
+                "enable_deterministic_inference": False,
+            },
+        )
+        self.kv_cache_dtype = self.server_args.kv_cache_dtype
         batch_size = 160
         # Create a proper req_to_token_pool with the req_to_token attribute
@@ -49,7 +63,7 @@ class MockModelRunner:
         self.token_to_kv_pool = MLATokenToKVPool(
             size=max_total_num_tokens,
             page_size=self.page_size,
-            dtype=self.dtype,
+            dtype=self.kv_cache_dtype,
             kv_lora_rank=kv_lora_rank,
             qk_rope_head_dim=qk_rope_head_dim,
             layer_num=1,  # only consider layer=1 for unit test
@@ -70,6 +84,15 @@ class MockReqToTokenPool:
 @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
 class TestFlashAttentionMLABackend(CustomTestCase):
     def setUp(self):
+        # MLA with different V headdim requires Hopper architecture (compute capability >= 9.0)
+        if torch.cuda.is_available():
+            compute_capability = torch.cuda.get_device_capability()
+            if compute_capability[0] < 9:
+                self.skipTest(
+                    f"MLA requires Hopper GPU (compute capability >= 9.0), "
+                    f"but found compute capability {compute_capability[0]}.{compute_capability[1]}"
+                )
+
         # Test parameters
         self.batch_size = 2
         self.seq_len = 360
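As an aside, the same gate could be written once as a reusable decorator rather than inline in setUp(). A minimal sketch, assuming a hypothetical requires_hopper helper; torch.cuda.get_device_capability() returns a (major, minor) tuple:

```python
import unittest

import torch


def requires_hopper(test_item):
    """Skip unless a CUDA device with compute capability >= 9.0 is available.

    Hypothetical helper; the committed test performs the equivalent check
    inline in setUp() instead.
    """
    if not torch.cuda.is_available():
        return unittest.skip("Test requires CUDA")(test_item)
    major, minor = torch.cuda.get_device_capability()
    return unittest.skipIf(
        major < 9,
        f"MLA requires Hopper GPU (compute capability >= 9.0), "
        f"but found compute capability {major}.{minor}",
    )(test_item)
```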
@@ -85,6 +108,7 @@ class TestFlashAttentionMLABackend(CustomTestCase):
         # Initialize model runner and backend
         self._init_model_runner()
         self.backend = FlashAttentionBackend(self.model_runner)
+        self.ref_backend = TorchNativeAttnBackend(self.model_runner)
         self.num_local_heads = 2

     def _init_model_runner(self):
@@ -92,7 +116,6 @@ class TestFlashAttentionMLABackend(CustomTestCase):
             kv_lora_rank=self.kv_lora_rank,
             qk_rope_head_dim=self.qk_rope_head_dim,
         )
-        self.backend = FlashAttentionBackend(self.model_runner)

     def _create_attention_layer(self):
         """Create attention layer for testing."""
@@ -207,21 +230,29 @@ class TestFlashAttentionMLABackend(CustomTestCase):
         if cache_len <= 0:
             return
-        # Create constant values for the prefix cache for easy debugging
-        latent_cache = torch.ones(
+        # For MLA, create separate nope and rope caches
+        cache_k_nope = torch.ones(
+            self.batch_size * cache_len,
+            1,  # latent cache has only one head in MQA
+            self.kv_lora_rank,
+            dtype=self.dtype,
+            device=self.device,
+        )
+        cache_k_rope = torch.ones(
             self.batch_size * cache_len,
             1,  # latent cache has only one head in MQA
-            self.kv_lora_rank + self.qk_rope_head_dim,
+            self.qk_rope_head_dim,
             dtype=self.dtype,
             device=self.device,
         )
-        # Set the prefix KV cache
-        forward_batch.token_to_kv_pool.set_kv_buffer(
+        # Set the prefix KV cache using MLA-specific method
+        forward_batch.token_to_kv_pool.set_mla_kv_buffer(
             layer,
             torch.arange(self.batch_size * cache_len, device=self.device),
-            latent_cache,
-            None,
+            cache_k_nope,
+            cache_k_rope,
         )

     def _run_attention_test(self, mode, q_len, prefix_len=0):
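The two tensors written through set_mla_kv_buffer() carry the same data as the single combined latent cache the old code built: concatenating the nope part (width kv_lora_rank) and the rope part (width qk_rope_head_dim) along the last dimension reproduces the old kv_lora_rank + qk_rope_head_dim layout. A minimal sketch with illustrative dimensions; 512 and 64 are assumptions, not the test's values, and the storage order is an assumption about the pool's internals:

```python
import torch

# Illustrative MLA dimensions; the test's actual values are set in setUp().
kv_lora_rank, qk_rope_head_dim = 512, 64
num_tokens = 4

# Separate caches, as the updated test builds them (one head for MQA).
cache_k_nope = torch.ones(num_tokens, 1, kv_lora_rank, dtype=torch.float16)
cache_k_rope = torch.ones(num_tokens, 1, qk_rope_head_dim, dtype=torch.float16)

# Concatenation recovers the combined latent cache the old code passed to
# set_kv_buffer() (an assumption about MLATokenToKVPool's storage order).
latent_cache = torch.cat([cache_k_nope, cache_k_rope], dim=-1)
assert latent_cache.shape == (num_tokens, 1, kv_lora_rank + qk_rope_head_dim)
```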
@@ -242,8 +273,18 @@ class TestFlashAttentionMLABackend(CustomTestCase):
         kv_shape = (self.batch_size * q_len, self.qk_head_dim)
         q = torch.randn(q_shape, dtype=self.dtype, device=self.device)
         kv_compressed = torch.randn(kv_shape, dtype=self.dtype, device=self.device)
-        # v is not used for mqa, all values passed in through k
-        k = kv_compressed.unsqueeze(1)
+
+        # For MLA, split kv_compressed into k_nope and k_rope
+        # k_nope has dimension kv_lora_rank, k_rope has dimension qk_rope_head_dim
+        k_nope = kv_compressed[:, : self.kv_lora_rank]
+        k_rope = kv_compressed[:, self.kv_lora_rank :]
+        # k_nope needs to be unsqueezed for the num_heads dimension
+        k = k_nope.unsqueeze(1)
+        # k_rope also needs to be unsqueezed
+        k_rope = k_rope.unsqueeze(1)
+
+        # v is not used for mqa
         v = torch.randn((1), dtype=self.dtype, device=self.device)

         self._setup_kv_cache(forward_batch, layer, prefix_len)
@@ -256,9 +297,13 @@ class TestFlashAttentionMLABackend(CustomTestCase):
         )

         if mode == ForwardMode.EXTEND:
-            output = self.backend.forward_extend(q, k, v, layer, forward_batch)
+            output = self.backend.forward_extend(
+                q, k, v, layer, forward_batch, k_rope=k_rope
+            )
         else:
-            output = self.backend.forward_decode(q, k, v, layer, forward_batch)
+            output = self.backend.forward_decode(
+                q, k, v, layer, forward_batch, k_rope=k_rope
+            )

         self._verify_output(output, expected_shape)
         return output