Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8e27663b
Unverified
Commit
8e27663b
authored
Jan 09, 2026
by
R3hankhan
Committed by
GitHub
Jan 09, 2026
Browse files
[CPU] Add head sizes 80 and 112 with vec16 fallback (#31968)
Signed-off-by:
Rehan Khan
<
Rehan.Khan7@ibm.com
>
parent
7cdf7e2f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
12 additions
and
5 deletions
+12
-5
csrc/cpu/cpu_attn.cpp
csrc/cpu/cpu_attn.cpp
+3
-0
csrc/cpu/cpu_attn_amx.hpp
csrc/cpu/cpu_attn_amx.hpp
+1
-1
csrc/cpu/cpu_attn_neon.hpp
csrc/cpu/cpu_attn_neon.hpp
+1
-1
vllm/v1/attention/backends/cpu_attn.py
vllm/v1/attention/backends/cpu_attn.py
+7
-3
No files found.
csrc/cpu/cpu_attn.cpp
View file @
8e27663b
...
...
@@ -15,6 +15,7 @@
#ifdef __aarch64__
#include "cpu_attn_neon.hpp"
// NEON requires head_dim to be a multiple of 32
#define NEON_DISPATCH(...) \
case cpu_attention::ISA::NEON: { \
using attn_impl = cpu_attention::AttentionImpl<cpu_attention::ISA::NEON, \
...
...
@@ -36,7 +37,9 @@
switch (HEAD_DIM) { \
CPU_ATTN_DISPATCH_CASE(32, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(64, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(80, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(96, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(112, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(128, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(160, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(192, __VA_ARGS__) \
...
...
csrc/cpu/cpu_attn_amx.hpp
View file @
8e27663b
...
...
@@ -377,7 +377,7 @@ class AttentionImpl<ISA::AMX, scalar_t, head_dim> {
const
int32_t
q_heads_per_kv
,
const
int64_t
q_num_stride
,
const
int64_t
q_head_stride
,
const
float
scale
)
{
constexpr
int64_t
bytes_per_head
=
head_dim
*
sizeof
(
scalar_t
);
static_assert
(
bytes_per_head
%
AMX_TILE_ROW_BYTES
==
0
);
//
static_assert(bytes_per_head % AMX_TILE_ROW_BYTES == 0);
constexpr
int64_t
head_size_block_num
=
bytes_per_head
/
AMX_TILE_ROW_BYTES
;
constexpr
int64_t
head_elem_num_pre_block
=
AMX_TILE_ROW_BYTES
/
sizeof
(
scalar_t
);
...
...
csrc/cpu/cpu_attn_neon.hpp
View file @
8e27663b
...
...
@@ -264,7 +264,7 @@ class AttentionImpl<ISA::NEON, scalar_t, head_dim> {
constexpr
static
ISA
ISAType
=
ISA
::
NEON
;
constexpr
static
bool
scale_on_logits
=
false
;
// apply scale on q_buffer
static_assert
(
HeadDim
%
HeadDimAlignment
==
0
);
//
static_assert(HeadDim % HeadDimAlignment == 0);
// the gemm micro kernel is Mx8
static_assert
(
HeadDimAlignment
%
8
==
0
);
static_assert
(
BlockSizeAlignment
%
8
==
0
);
...
...
vllm/v1/attention/backends/cpu_attn.py
View file @
8e27663b
...
...
@@ -42,7 +42,7 @@ class CPUAttentionBackend(AttentionBackend):
@
classmethod
def
get_supported_head_sizes
(
cls
)
->
list
[
int
]:
return
[
32
,
64
,
96
,
128
,
160
,
192
,
224
,
256
]
return
[
32
,
64
,
80
,
96
,
112
,
128
,
160
,
192
,
224
,
256
]
@
staticmethod
def
get_name
()
->
str
:
...
...
@@ -137,7 +137,7 @@ class CPUAttentionMetadataBuilder(AttentionMetadataBuilder[CPUAttentionMetadata]
if
self
.
window_size
is
None
:
self
.
window_size
=
-
1
self
.
block_size
=
vllm_config
.
cache_config
.
block_size
self
.
isa
=
_get_attn_isa
(
self
.
dtype
,
self
.
block_size
)
self
.
isa
=
_get_attn_isa
(
self
.
dtype
,
self
.
block_size
,
self
.
head_dim
)
self
.
is_cross_attention
=
isinstance
(
kv_cache_spec
,
CrossAttentionSpec
)
def
build
(
...
...
@@ -484,7 +484,11 @@ def _make_sliding_window_bias(
return
attn_biases
def
_get_attn_isa
(
dtype
:
torch
.
dtype
,
block_size
:
int
)
->
str
:
def
_get_attn_isa
(
dtype
:
torch
.
dtype
,
block_size
:
int
,
head_size
:
int
|
None
=
None
)
->
str
:
if
head_size
is
not
None
and
head_size
%
32
!=
0
and
head_size
%
16
==
0
:
return
"vec16"
supports_amx
=
torch
.
_C
.
_cpu
.
_is_amx_tile_supported
()
if
supports_amx
and
dtype
in
(
torch
.
bfloat16
,)
and
block_size
%
32
==
0
:
return
"amx"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment