Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
434ac76a
Unverified
Commit
434ac76a
authored
Dec 10, 2025
by
Fadi Arafeh
Committed by
GitHub
Dec 10, 2025
Browse files
[cpu][ci] Add CPU Attention Tests for Neon Backend (#30347)
Signed-off-by:
Fadi Arafeh
<
fadi.arafeh@arm.com
>
parent
ed7af317
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
63 additions
and
10 deletions
+63
-10
tests/kernels/attention/test_cpu_attn.py
tests/kernels/attention/test_cpu_attn.py
+63
-10
No files found.
tests/kernels/attention/test_cpu_attn.py
View file @
434ac76a
...
...
@@ -7,7 +7,8 @@ import math
import
pytest
import
torch
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
CpuArchEnum
,
current_platform
from
vllm.v1.attention.backends.cpu_attn
import
_get_attn_isa
if
not
current_platform
.
is_cpu
():
pytest
.
skip
(
"skipping CPU-only tests"
,
allow_module_level
=
True
)
...
...
@@ -36,6 +37,21 @@ SEQ_LENS = [ # (q_len, kv_len)
]
def get_attn_isa(
    block_size: int | None = None,
    dtype: torch.dtype | None = None,
):
    """Pick the attention ISA name used by the CPU attention tests.

    When both ``block_size`` and ``dtype`` are supplied (truthy), the choice
    is delegated to ``_get_attn_isa(dtype, block_size)``. Otherwise the ISA
    is derived from the host: ``"neon"`` on an ARM CPU, ``"amx"`` when torch
    reports AMX tile support, and ``"vec"`` in every other case.
    """
    if block_size and dtype:
        return _get_attn_isa(dtype, block_size)

    # No explicit configuration given: fall back to probing the host CPU.
    if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
        return "neon"
    if torch._C._cpu._is_amx_tile_supported():
        return "amx"
    return "vec"
# Random number generation is slow, so cache the random tensors.
@
functools
.
lru_cache
(
maxsize
=
128
,
typed
=
False
)
def
tensor_cache
(
...
...
@@ -452,6 +468,49 @@ def test_varlen_with_paged_kv_normal_vec16(
)
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", [96, 128])
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
@pytest.mark.parametrize("dtype", QTYPES)
@pytest.mark.parametrize("soft_cap", [None])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("use_alibi", [False])
@pytest.mark.parametrize("use_sink", [False])
@pytest.mark.parametrize("isa", ["neon"])
@pytest.mark.skipif(
    current_platform.get_cpu_architecture() != CpuArchEnum.ARM,
    reason="Not an Arm CPU.",
)
def test_varlen_with_paged_kv_normal_neon(
    seq_lens: list[tuple[int, int]],
    num_heads: tuple[int, int],
    head_size: int,
    sliding_window: int | None,
    dtype: torch.dtype,
    block_size: int,
    soft_cap: float | None,
    num_blocks: int,
    use_alibi: bool,
    use_sink: bool,
    isa: str,
) -> None:
    """Run the shared paged-KV varlen check with the ``"neon"`` ISA.

    The test is skipped unless the current CPU architecture is ARM. Every
    parametrized value is forwarded unchanged, by keyword, to
    ``varlen_with_paged_kv``.
    """
    # Collect the parametrized inputs once, then forward them as keywords.
    case = {
        "seq_lens": seq_lens,
        "num_heads": num_heads,
        "head_size": head_size,
        "sliding_window": sliding_window,
        "dtype": dtype,
        "block_size": block_size,
        "soft_cap": soft_cap,
        "num_blocks": num_blocks,
        "use_alibi": use_alibi,
        "use_sink": use_sink,
        "isa": isa,
    }
    varlen_with_paged_kv(**case)
@
pytest
.
mark
.
parametrize
(
"seq_lens"
,
SEQ_LENS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
[
96
])
...
...
@@ -462,9 +521,7 @@ def test_varlen_with_paged_kv_normal_vec16(
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"use_alibi"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"use_sink"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"isa"
,
[
"amx"
]
if
torch
.
_C
.
_cpu
.
_is_amx_tile_supported
()
else
[
"vec"
]
)
@
pytest
.
mark
.
parametrize
(
"isa"
,
[
get_attn_isa
()])
def
test_varlen_with_paged_kv_softcap
(
seq_lens
:
list
[
tuple
[
int
,
int
]],
num_heads
:
tuple
[
int
,
int
],
...
...
@@ -503,9 +560,7 @@ def test_varlen_with_paged_kv_softcap(
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"use_alibi"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"use_sink"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"isa"
,
[
"amx"
]
if
torch
.
_C
.
_cpu
.
_is_amx_tile_supported
()
else
[
"vec"
]
)
@
pytest
.
mark
.
parametrize
(
"isa"
,
[
get_attn_isa
()])
def
test_varlen_with_paged_kv_alibi
(
seq_lens
:
list
[
tuple
[
int
,
int
]],
num_heads
:
tuple
[
int
,
int
],
...
...
@@ -544,9 +599,7 @@ def test_varlen_with_paged_kv_alibi(
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"use_alibi"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"use_sink"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"isa"
,
[
"amx"
]
if
torch
.
_C
.
_cpu
.
_is_amx_tile_supported
()
else
[
"vec"
]
)
@
pytest
.
mark
.
parametrize
(
"isa"
,
[
get_attn_isa
()])
def
test_varlen_with_paged_kv_sink
(
seq_lens
:
list
[
tuple
[
int
,
int
]],
num_heads
:
tuple
[
int
,
int
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment