Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5fe643fc
Unverified
Commit
5fe643fc
authored
Sep 12, 2025
by
Matthew Bonanni
Committed by
GitHub
Sep 12, 2025
Browse files
Add FLASHINFER_MLA to backend selector test (#24753)
Signed-off-by:
Matthew Bonanni
<
mbonanni001@gmail.com
>
parent
7ba32aa6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
43 additions
and
19 deletions
+43
-19
tests/kernels/attention/test_attention_selector.py
tests/kernels/attention/test_attention_selector.py
+41
-19
tests/v1/attention/utils.py
tests/v1/attention/utils.py
+2
-0
No files found.
tests/kernels/attention/test_attention_selector.py
View file @
5fe643fc
...
@@ -22,7 +22,10 @@ def clear_cache():
...
@@ -22,7 +22,10 @@ def clear_cache():
# Define MLA and non-MLA backends separately
# Define MLA and non-MLA backends separately
DEVICE_MLA_BACKENDS
=
{
DEVICE_MLA_BACKENDS
=
{
"cuda"
:
[
"TRITON_MLA"
,
"FLASHMLA"
,
"FLASH_ATTN_MLA"
,
"CUTLASS_MLA"
],
"cuda"
:
[
"TRITON_MLA"
,
"FLASHMLA"
,
"FLASHINFER_MLA"
,
"FLASH_ATTN_MLA"
,
"CUTLASS_MLA"
],
"hip"
:
[
"TRITON_MLA"
,
"ROCM_AITER_MLA"
],
"hip"
:
[
"TRITON_MLA"
,
"ROCM_AITER_MLA"
],
"cpu"
:
[],
"cpu"
:
[],
}
}
...
@@ -90,8 +93,8 @@ def test_env(
...
@@ -90,8 +93,8 @@ def test_env(
with
patch
(
"vllm.attention.selector.current_platform"
,
with
patch
(
"vllm.attention.selector.current_platform"
,
CpuPlatform
()):
CpuPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
block_size
,
block_size
,
False
)
False
)
assert
backend
.
get_name
()
==
"TORCH_SDPA_VLLM_V1"
assert
backend
.
get_name
()
==
"TORCH_SDPA_VLLM_V1"
elif
device
==
"hip"
:
elif
device
==
"hip"
:
...
@@ -109,7 +112,7 @@ def test_env(
...
@@ -109,7 +112,7 @@ def test_env(
with
pytest
.
raises
(
ValueError
)
as
exc_info
:
with
pytest
.
raises
(
ValueError
)
as
exc_info
:
get_attn_backend
(
16
,
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
torch
.
float16
,
None
,
block_size
,
block_size
,
False
,
False
,
use_mla
=
use_mla
)
use_mla
=
use_mla
)
...
@@ -120,7 +123,7 @@ def test_env(
...
@@ -120,7 +123,7 @@ def test_env(
with
pytest
.
raises
(
ValueError
)
as
exc_info
:
with
pytest
.
raises
(
ValueError
)
as
exc_info
:
get_attn_backend
(
16
,
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
torch
.
float16
,
None
,
block_size
,
block_size
,
False
,
False
,
use_mla
=
use_mla
)
use_mla
=
use_mla
)
...
@@ -130,7 +133,7 @@ def test_env(
...
@@ -130,7 +133,7 @@ def test_env(
# Valid backend-block_size combination
# Valid backend-block_size combination
backend
=
get_attn_backend
(
16
,
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
torch
.
float16
,
None
,
block_size
,
block_size
,
False
,
False
,
use_mla
=
use_mla
)
use_mla
=
use_mla
)
...
@@ -139,7 +142,7 @@ def test_env(
...
@@ -139,7 +142,7 @@ def test_env(
else
:
else
:
backend
=
get_attn_backend
(
16
,
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
torch
.
float16
,
None
,
block_size
,
block_size
,
False
,
False
,
use_mla
=
use_mla
)
use_mla
=
use_mla
)
...
@@ -153,6 +156,8 @@ def test_env(
...
@@ -153,6 +156,8 @@ def test_env(
# CUDA MLA backend logic:
# CUDA MLA backend logic:
# - CUTLASS_MLA: only supported with block_size == 128
# - CUTLASS_MLA: only supported with block_size == 128
# and Blackwell GPUs (SM 10.0), V1 only
# and Blackwell GPUs (SM 10.0), V1 only
# - FLASHINFER_MLA: only supported on Blackwell GPUs
# (SM 10.0+), V1 only
# - FLASHMLA: only supported with block_size == 64
# - FLASHMLA: only supported with block_size == 64
# - FLASH_ATTN_MLA: V1 only
# - FLASH_ATTN_MLA: V1 only
# - TRITON_MLA: fallback for other cases
# - TRITON_MLA: fallback for other cases
...
@@ -169,12 +174,31 @@ def test_env(
...
@@ -169,12 +174,31 @@ def test_env(
else
:
else
:
backend
=
get_attn_backend
(
16
,
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
torch
.
float16
,
None
,
block_size
,
block_size
,
False
,
False
,
use_mla
=
use_mla
)
use_mla
=
use_mla
)
expected
=
"CUTLASS_MLA_VLLM_V1"
expected
=
"CUTLASS_MLA_VLLM_V1"
assert
backend
.
get_name
()
==
expected
assert
backend
.
get_name
()
==
expected
elif
name
==
"FLASHINFER_MLA"
:
if
not
use_v1
:
# FlashInfer MLA only supported on V1 engine
pytest
.
skip
(
"FlashInfer MLA only supported on V1 engine"
)
elif
block_size
not
in
[
32
,
64
]:
# FlashInfer MLA only supports block_size 32 or 64
pytest
.
skip
(
"FlashInfer MLA only supports block_size 32 "
"or 64"
)
else
:
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
block_size
,
False
,
use_mla
=
use_mla
)
expected
=
"FLASHINFER_MLA"
assert
backend
.
get_name
()
==
expected
elif
name
==
"FLASHMLA"
:
elif
name
==
"FLASHMLA"
:
if
block_size
!=
64
:
if
block_size
!=
64
:
# FlashMLA only supports block_size == 64
# FlashMLA only supports block_size == 64
...
@@ -189,7 +213,7 @@ def test_env(
...
@@ -189,7 +213,7 @@ def test_env(
else
:
else
:
backend
=
get_attn_backend
(
16
,
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
torch
.
float16
,
None
,
block_size
,
block_size
,
False
,
False
,
use_mla
=
use_mla
)
use_mla
=
use_mla
)
...
@@ -204,7 +228,7 @@ def test_env(
...
@@ -204,7 +228,7 @@ def test_env(
else
:
else
:
backend
=
get_attn_backend
(
16
,
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
torch
.
float16
,
None
,
block_size
,
block_size
,
False
,
False
,
use_mla
=
use_mla
)
use_mla
=
use_mla
)
...
@@ -214,7 +238,7 @@ def test_env(
...
@@ -214,7 +238,7 @@ def test_env(
# TRITON_MLA or other fallback
# TRITON_MLA or other fallback
backend
=
get_attn_backend
(
16
,
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
torch
.
float16
,
None
,
block_size
,
block_size
,
False
,
False
,
use_mla
=
use_mla
)
use_mla
=
use_mla
)
...
@@ -224,7 +248,7 @@ def test_env(
...
@@ -224,7 +248,7 @@ def test_env(
elif
name
==
"FLASHINFER"
:
elif
name
==
"FLASHINFER"
:
backend
=
get_attn_backend
(
16
,
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
torch
.
float16
,
None
,
block_size
,
block_size
,
False
,
False
,
use_mla
=
use_mla
)
use_mla
=
use_mla
)
...
@@ -233,7 +257,7 @@ def test_env(
...
@@ -233,7 +257,7 @@ def test_env(
else
:
else
:
backend
=
get_attn_backend
(
32
,
backend
=
get_attn_backend
(
32
,
torch
.
float16
,
torch
.
float16
,
torch
.
float16
,
None
,
block_size
,
block_size
,
False
,
False
,
use_mla
=
use_mla
)
use_mla
=
use_mla
)
...
@@ -243,7 +267,7 @@ def test_env(
...
@@ -243,7 +267,7 @@ def test_env(
if
use_v1
:
if
use_v1
:
backend
=
get_attn_backend
(
16
,
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
torch
.
float16
,
None
,
block_size
,
block_size
,
False
,
False
,
use_mla
=
use_mla
)
use_mla
=
use_mla
)
...
@@ -269,15 +293,13 @@ def test_fp32_fallback(
...
@@ -269,15 +293,13 @@ def test_fp32_fallback(
with
patch
(
"vllm.attention.selector.current_platform"
,
with
patch
(
"vllm.attention.selector.current_platform"
,
CpuPlatform
()):
CpuPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float32
,
torch
.
float32
,
backend
=
get_attn_backend
(
16
,
torch
.
float32
,
None
,
16
,
False
)
16
,
False
)
assert
backend
.
get_name
()
==
"TORCH_SDPA_VLLM_V1"
assert
backend
.
get_name
()
==
"TORCH_SDPA_VLLM_V1"
elif
device
==
"cuda"
:
elif
device
==
"cuda"
:
with
patch
(
"vllm.attention.selector.current_platform"
,
with
patch
(
"vllm.attention.selector.current_platform"
,
CudaPlatform
()):
CudaPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float32
,
torch
.
float32
,
backend
=
get_attn_backend
(
16
,
torch
.
float32
,
None
,
16
,
False
)
16
,
False
)
assert
(
backend
.
get_name
()
==
"FLEX_ATTENTION"
assert
(
backend
.
get_name
()
==
"FLEX_ATTENTION"
if
use_v1
else
"XFORMERS"
)
if
use_v1
else
"XFORMERS"
)
...
@@ -331,7 +353,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
...
@@ -331,7 +353,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
# Attention-free models should bypass env and use PlaceholderAttention
# Attention-free models should bypass env and use PlaceholderAttention
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
True
)
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
16
,
True
)
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
...
...
tests/v1/attention/utils.py
View file @
5fe643fc
...
@@ -141,6 +141,8 @@ def get_attention_backend(backend_name: _Backend):
...
@@ -141,6 +141,8 @@ def get_attention_backend(backend_name: _Backend):
"vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
,
"vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
,
_Backend
.
FLASH_ATTN_MLA
:
_Backend
.
FLASH_ATTN_MLA
:
"vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
,
"vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
,
_Backend
.
FLASHINFER_MLA
:
"vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
,
_Backend
.
TRITON_MLA_VLLM_V1
:
_Backend
.
TRITON_MLA_VLLM_V1
:
"vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
,
"vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
,
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment