sglang · Commits · 6f858930

Unverified commit 6f858930, authored Nov 01, 2025 by Johnsonms, committed by GitHub on Nov 01, 2025

[Bug] test_flashattn_mla_backend errors in Hopper #12487 (#12488)

Parent: 229256c5
Showing 1 changed file with 58 additions and 13 deletions (+58, -13)
python/sglang/test/attention/test_flashattn_mla_backend.py (view file @ 6f858930)
@@ -4,6 +4,7 @@ import torch
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
+from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
@@ -19,6 +20,7 @@ class MockModelRunner:
         attention_arch = AttentionArch.MLA
         self.device = "cuda"
         self.dtype = torch.float16
+        self.is_hybrid = False
         context_len = 2048
         self.model_config = type(
             "ModelConfig",
@@ -29,6 +31,18 @@ class MockModelRunner:
             },
         )
         self.sliding_window_size = None
+        # Add server_args attribute
+        self.server_args = type(
+            "ServerArgs",
+            (),
+            {
+                "kv_cache_dtype": torch.float16,
+                "speculative_eagle_topk": None,
+                "speculative_num_draft_tokens": 0,
+                "enable_deterministic_inference": False,
+            },
+        )
+        self.kv_cache_dtype = self.server_args.kv_cache_dtype
         batch_size = 160
         # Create a proper req_to_token_pool with the req_to_token attribute
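For reference, the server_args stand-in added above relies on the three-argument form of type(), which builds a throwaway class whose dict entries become class attributes. A minimal sketch of the same pattern follows; "DummyArgs" and its fields are illustrative placeholders, not names from the test.

# Sketch of the type()-based stub pattern; names are hypothetical.
import torch

DummyArgs = type(
    "DummyArgs",
    (),
    {"kv_cache_dtype": torch.float16, "speculative_num_draft_tokens": 0},
)

# Dict entries become class attributes, so they can be read directly off the
# class object, exactly as the mock does with self.server_args.kv_cache_dtype.
assert DummyArgs.kv_cache_dtype is torch.float16
assert DummyArgs.speculative_num_draft_tokens == 0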
@@ -49,7 +63,7 @@ class MockModelRunner:
         self.token_to_kv_pool = MLATokenToKVPool(
             size=max_total_num_tokens,
             page_size=self.page_size,
-            dtype=self.dtype,
+            dtype=self.kv_cache_dtype,
             kv_lora_rank=kv_lora_rank,
             qk_rope_head_dim=qk_rope_head_dim,
             layer_num=1,  # only consider layer=1 for unit test
@@ -70,6 +84,15 @@ class MockReqToTokenPool:
 @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
 class TestFlashAttentionMLABackend(CustomTestCase):
     def setUp(self):
+        # MLA with different V headdim requires Hopper architecture (compute capability >= 9.0)
+        if torch.cuda.is_available():
+            compute_capability = torch.cuda.get_device_capability()
+            if compute_capability[0] < 9:
+                self.skipTest(
+                    f"MLA requires Hopper GPU (compute capability >= 9.0), "
+                    f"but found compute capability {compute_capability[0]}.{compute_capability[1]}"
+                )
+
         # Test parameters
         self.batch_size = 2
         self.seq_len = 360
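As a standalone illustration of the gate added above (a sketch only; the real test calls self.skipTest inside setUp): torch.cuda.get_device_capability() returns a (major, minor) tuple, and Hopper GPUs report major version 9.

# Standalone sketch of the compute-capability gate; prints instead of skipping.
import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    if major < 9:
        print(f"MLA path unsupported here: compute capability {major}.{minor} < 9.0")
    else:
        print(f"Hopper or newer detected: compute capability {major}.{minor}")
else:
    print("CUDA not available; unittest.skipIf already skips the whole test class")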
@@ -85,6 +108,7 @@ class TestFlashAttentionMLABackend(CustomTestCase):
         # Initialize model runner and backend
         self._init_model_runner()
         self.backend = FlashAttentionBackend(self.model_runner)
+        self.ref_backend = TorchNativeAttnBackend(self.model_runner)
         self.num_local_heads = 2
 
     def _init_model_runner(self):
@@ -92,7 +116,6 @@ class TestFlashAttentionMLABackend(CustomTestCase):
             kv_lora_rank=self.kv_lora_rank,
             qk_rope_head_dim=self.qk_rope_head_dim,
         )
-        self.backend = FlashAttentionBackend(self.model_runner)
 
     def _create_attention_layer(self):
         """Create attention layer for testing."""
@@ -207,21 +230,29 @@ class TestFlashAttentionMLABackend(CustomTestCase):
         if cache_len <= 0:
             return
-        # Create constant values for the prefix cache for easy debugging
-        latent_cache = torch.ones(
-            self.batch_size * cache_len,
-            1,  # latent cache has only one head in MQA
-            self.kv_lora_rank + self.qk_rope_head_dim,
-            dtype=self.dtype,
-            device=self.device,
-        )
-        # Set the prefix KV cache
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer,
-            torch.arange(self.batch_size * cache_len, device=self.device),
-            latent_cache,
-            None,
-        )
+        # For MLA, create separate nope and rope caches
+        cache_k_nope = torch.ones(
+            self.batch_size * cache_len,
+            1,  # latent cache has only one head in MQA
+            self.kv_lora_rank,
+            dtype=self.dtype,
+            device=self.device,
+        )
+        cache_k_rope = torch.ones(
+            self.batch_size * cache_len,
+            1,  # latent cache has only one head in MQA
+            self.qk_rope_head_dim,
+            dtype=self.dtype,
+            device=self.device,
+        )
+        # Set the prefix KV cache using MLA-specific method
+        forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+            layer,
+            torch.arange(self.batch_size * cache_len, device=self.device),
+            cache_k_nope,
+            cache_k_rope,
+        )
 
     def _run_attention_test(self, mode, q_len, prefix_len=0):
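A minimal shape check for the cache split above, using hypothetical MLA dimensions; the test's actual kv_lora_rank and qk_rope_head_dim come from its setUp, which is outside this hunk.

# Sketch only: kv_lora_rank / qk_rope_head_dim values are placeholders.
import torch

batch_size, cache_len = 2, 4
kv_lora_rank, qk_rope_head_dim = 512, 64  # hypothetical MLA dimensions

# Old layout: one latent buffer holding the nope and rope parts concatenated.
latent_cache = torch.ones(batch_size * cache_len, 1, kv_lora_rank + qk_rope_head_dim)

# New layout: the two parts are kept separate, matching the two tensors passed
# to set_mla_kv_buffer (cache_k_nope, cache_k_rope) in the hunk above.
cache_k_nope = latent_cache[..., :kv_lora_rank]
cache_k_rope = latent_cache[..., kv_lora_rank:]

assert cache_k_nope.shape == (batch_size * cache_len, 1, kv_lora_rank)
assert cache_k_rope.shape == (batch_size * cache_len, 1, qk_rope_head_dim)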
@@ -242,8 +273,18 @@ class TestFlashAttentionMLABackend(CustomTestCase):
         kv_shape = (self.batch_size * q_len, self.qk_head_dim)
         q = torch.randn(q_shape, dtype=self.dtype, device=self.device)
         kv_compressed = torch.randn(kv_shape, dtype=self.dtype, device=self.device)
-        # v is not used for mqa, all values passed in through k
-        k = kv_compressed.unsqueeze(1)
+        # For MLA, split kv_compressed into k_nope and k_rope
+        # k_nope has dimension kv_lora_rank, k_rope has dimension qk_rope_head_dim
+        k_nope = kv_compressed[:, : self.kv_lora_rank]
+        k_rope = kv_compressed[:, self.kv_lora_rank :]
+        # k_nope needs to be unsqueezed for the num_heads dimension
+        k = k_nope.unsqueeze(1)
+        # k_rope also needs to be unsqueezed
+        k_rope = k_rope.unsqueeze(1)
+        # v is not used for mqa
         v = torch.randn((1), dtype=self.dtype, device=self.device)
 
         self._setup_kv_cache(forward_batch, layer, prefix_len)
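The per-step q/k preparation above applies the same split; here is a small self-contained sketch with placeholder sizes (again, the real dimensions come from setUp).

# Sketch only: sizes are placeholders, not the test's actual configuration.
import torch

batch_size, q_len = 2, 8
kv_lora_rank, qk_rope_head_dim = 512, 64
qk_head_dim = kv_lora_rank + qk_rope_head_dim

kv_compressed = torch.randn(batch_size * q_len, qk_head_dim)

# Split the compressed KV into its no-position (nope) and rotary (rope) parts.
k_nope = kv_compressed[:, :kv_lora_rank]
k_rope = kv_compressed[:, kv_lora_rank:]

# Both parts gain a singleton head dimension, matching the MQA-style layout.
k = k_nope.unsqueeze(1)       # (batch_size * q_len, 1, kv_lora_rank)
k_rope = k_rope.unsqueeze(1)  # (batch_size * q_len, 1, qk_rope_head_dim)

assert k.shape == (batch_size * q_len, 1, kv_lora_rank)
assert k_rope.shape == (batch_size * q_len, 1, qk_rope_head_dim)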
@@ -256,9 +297,13 @@ class TestFlashAttentionMLABackend(CustomTestCase):
         )
 
         if mode == ForwardMode.EXTEND:
-            output = self.backend.forward_extend(q, k, v, layer, forward_batch)
+            output = self.backend.forward_extend(
+                q, k, v, layer, forward_batch, k_rope=k_rope
+            )
         else:
-            output = self.backend.forward_decode(q, k, v, layer, forward_batch)
+            output = self.backend.forward_decode(
+                q, k, v, layer, forward_batch, k_rope=k_rope
+            )
 
         self._verify_output(output, expected_shape)
         return output