Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4dffc5e0
Unverified
Commit
4dffc5e0
authored
Feb 04, 2026
by
R3hankhan
Committed by
GitHub
Feb 03, 2026
Browse files
[CPU] Split attention dispatch by head_dim alignment (#32161)
Signed-off-by:
Rehan Khan
<
Rehan.Khan7@ibm.com
>
parent
e1bf04b6
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
241 additions
and
107 deletions
+241
-107
cmake/cpu_extension.cmake
cmake/cpu_extension.cmake
+13
-0
csrc/cpu/cpu_attn.cpp
csrc/cpu/cpu_attn.cpp
+21
-104
csrc/cpu/cpu_attn_amx.hpp
csrc/cpu/cpu_attn_amx.hpp
+1
-1
csrc/cpu/cpu_attn_neon.hpp
csrc/cpu/cpu_attn_neon.hpp
+1
-1
csrc/cpu/generate_cpu_attn_dispatch.py
csrc/cpu/generate_cpu_attn_dispatch.py
+203
-0
tests/kernels/attention/test_cpu_attn.py
tests/kernels/attention/test_cpu_attn.py
+2
-1
No files found.
cmake/cpu_extension.cmake
View file @
4dffc5e0
...
...
@@ -359,6 +359,19 @@ else()
add_compile_definitions
(
-DVLLM_NUMA_DISABLED
)
endif
()
#
# Generate CPU attention dispatch header
#
message
(
STATUS
"Generating CPU attention dispatch header"
)
execute_process
(
COMMAND
${
Python_EXECUTABLE
}
${
CMAKE_SOURCE_DIR
}
/csrc/cpu/generate_cpu_attn_dispatch.py
WORKING_DIRECTORY
${
CMAKE_SOURCE_DIR
}
/csrc/cpu
RESULT_VARIABLE GEN_RESULT
)
if
(
NOT GEN_RESULT EQUAL 0
)
message
(
FATAL_ERROR
"Failed to generate CPU attention dispatch header"
)
endif
()
#
# _C extension
#
...
...
csrc/cpu/cpu_attn.cpp
View file @
4dffc5e0
#include "cpu_attn_vec.hpp"
#include "cpu_attn_vec16.hpp"
#ifdef CPU_CAPABILITY_AMXBF16
#include "cpu_attn_amx.hpp"
#define AMX_DISPATCH(...) \
case cpu_attention::ISA::AMX: { \
using attn_impl = cpu_attention::AttentionImpl<cpu_attention::ISA::AMX, \
scalar_t, head_dim>; \
return __VA_ARGS__(); \
}
#else
#define AMX_DISPATCH(...) case cpu_attention::ISA::AMX:
#endif
#ifdef __aarch64__
#include "cpu_attn_neon.hpp"
// NEON requires head_dim to be a multiple of 32
#define NEON_DISPATCH(...) \
case cpu_attention::ISA::NEON: { \
using attn_impl = cpu_attention::AttentionImpl<cpu_attention::ISA::NEON, \
scalar_t, head_dim>; \
return __VA_ARGS__(); \
}
#else
#define NEON_DISPATCH(...) case cpu_attention::ISA::NEON:
#endif // #ifdef __aarch64__
#define CPU_ATTN_DISPATCH_CASE(HEAD_DIM, ...) \
case HEAD_DIM: { \
constexpr size_t head_dim = HEAD_DIM; \
return __VA_ARGS__(); \
}
#define CPU_ATTN_DISPATCH_CASE_HEADDIM(HEAD_DIM, ...) \
[&] { \
switch (HEAD_DIM) { \
CPU_ATTN_DISPATCH_CASE(32, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(64, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(80, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(96, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(112, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(128, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(160, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(192, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(224, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(256, __VA_ARGS__) \
default: { \
TORCH_CHECK(false, "Invalid CPU attention head_dim: " + \
std::to_string(HEAD_DIM)); \
} \
} \
}()
#define CPU_ATTN_DISPATCH_IMPL(ISA_TYPE, ...) \
[&] { \
switch (ISA_TYPE) { \
AMX_DISPATCH(__VA_ARGS__) \
NEON_DISPATCH(__VA_ARGS__) \
case cpu_attention::ISA::VEC: { \
using attn_impl = \
cpu_attention::AttentionImpl<cpu_attention::ISA::VEC, scalar_t, \
head_dim>; \
return __VA_ARGS__(); \
} \
case cpu_attention::ISA::VEC16: { \
using attn_impl = \
cpu_attention::AttentionImpl<cpu_attention::ISA::VEC16, scalar_t, \
head_dim>; \
return __VA_ARGS__(); \
} \
default: { \
TORCH_CHECK(false, "Invalid CPU attention ISA type."); \
} \
} \
}()
#include "cpu_attn_dispatch_generated.h"
torch
::
Tensor
get_scheduler_metadata
(
const
int64_t
num_req
,
const
int64_t
num_heads_q
,
...
...
@@ -122,16 +47,14 @@ torch::Tensor get_scheduler_metadata(
input
.
enable_kv_split
=
enable_kv_split
;
VLLM_DISPATCH_FLOATING_TYPES
(
dtype
,
"get_scheduler_metadata"
,
[
&
]()
{
CPU_ATTN_DISPATCH_CASE_HEADDIM
(
head_dim
,
[
&
]
{
CPU_ATTN_DISPATCH_IMPL
(
isa
,
[
&
]()
{
input
.
elem_size
=
sizeof
(
scalar_t
);
input
.
q_buffer_elem_size
=
sizeof
(
attn_impl
::
q_buffer_t
);
input
.
logits_buffer_elem_size
=
sizeof
(
attn_impl
::
logits_buffer_t
);
input
.
output_buffer_elem_size
=
sizeof
(
attn_impl
::
partial_output_buffer_t
);
input
.
max_num_q_per_iter
=
attn_impl
::
MaxQHeadNumPerIteration
;
input
.
kv_block_alignment
=
attn_impl
::
BlockSizeAlignment
;
});
CPU_ATTN_DISPATCH
(
head_dim
,
isa
,
[
&
]()
{
input
.
elem_size
=
sizeof
(
scalar_t
);
input
.
q_buffer_elem_size
=
sizeof
(
attn_impl
::
q_buffer_t
);
input
.
logits_buffer_elem_size
=
sizeof
(
attn_impl
::
logits_buffer_t
);
input
.
output_buffer_elem_size
=
sizeof
(
attn_impl
::
partial_output_buffer_t
);
input
.
max_num_q_per_iter
=
attn_impl
::
MaxQHeadNumPerIteration
;
input
.
kv_block_alignment
=
attn_impl
::
BlockSizeAlignment
;
});
});
...
...
@@ -184,18 +107,14 @@ void cpu_attn_reshape_and_cache(
VLLM_DISPATCH_FLOATING_TYPES
(
key
.
scalar_type
(),
"cpu_attn_reshape_and_cache"
,
[
&
]()
{
CPU_ATTN_DISPATCH_CASE_HEADDIM
(
head_dim
,
[
&
]
{
CPU_ATTN_DISPATCH_IMPL
(
isa_tag
,
[
&
]()
{
attn_impl
::
reshape_and_cache
(
key
.
data_ptr
<
scalar_t
>
(),
value
.
data_ptr
<
scalar_t
>
(),
key_cache
.
data_ptr
<
scalar_t
>
(),
value_cache
.
data_ptr
<
scalar_t
>
(),
slot_mapping
.
data_ptr
<
int64_t
>
(),
token_num
,
key_token_num_stride
,
value_token_num_stride
,
head_num
,
key_head_num_stride
,
value_head_num_stride
,
num_blocks
,
num_blocks_stride
,
cache_head_num_stride
,
block_size
,
block_size_stride
);
});
CPU_ATTN_DISPATCH
(
head_dim
,
isa_tag
,
[
&
]()
{
attn_impl
::
reshape_and_cache
(
key
.
data_ptr
<
scalar_t
>
(),
value
.
data_ptr
<
scalar_t
>
(),
key_cache
.
data_ptr
<
scalar_t
>
(),
value_cache
.
data_ptr
<
scalar_t
>
(),
slot_mapping
.
data_ptr
<
int64_t
>
(),
token_num
,
key_token_num_stride
,
value_token_num_stride
,
head_num
,
key_head_num_stride
,
value_head_num_stride
,
num_blocks
,
num_blocks_stride
,
cache_head_num_stride
,
block_size
,
block_size_stride
);
});
});
}
...
...
@@ -257,12 +176,10 @@ void cpu_attention_with_kv_cache(
VLLM_DISPATCH_FLOATING_TYPES
(
query
.
scalar_type
(),
"cpu_attention_with_kv_cache"
,
[
&
]()
{
CPU_ATTN_DISPATCH_CASE_HEADDIM
(
query
.
size
(
2
),
[
&
]
{
CPU_ATTN_DISPATCH_IMPL
(
input
.
metadata
->
isa
,
[
&
]()
{
TORCH_CHECK_EQ
(
input
.
block_size
%
attn_impl
::
BlockSizeAlignment
,
0
);
cpu_attention
::
AttentionMainLoop
<
attn_impl
>
mainloop
;
mainloop
(
&
input
);
});
CPU_ATTN_DISPATCH
(
query
.
size
(
2
),
input
.
metadata
->
isa
,
[
&
]()
{
TORCH_CHECK_EQ
(
input
.
block_size
%
attn_impl
::
BlockSizeAlignment
,
0
);
cpu_attention
::
AttentionMainLoop
<
attn_impl
>
mainloop
;
mainloop
(
&
input
);
});
});
}
csrc/cpu/cpu_attn_amx.hpp
View file @
4dffc5e0
...
...
@@ -377,7 +377,7 @@ class AttentionImpl<ISA::AMX, scalar_t, head_dim> {
const
int32_t
q_heads_per_kv
,
const
int64_t
q_num_stride
,
const
int64_t
q_head_stride
,
const
float
scale
)
{
constexpr
int64_t
bytes_per_head
=
head_dim
*
sizeof
(
scalar_t
);
//
static_assert(bytes_per_head % AMX_TILE_ROW_BYTES == 0);
static_assert
(
bytes_per_head
%
AMX_TILE_ROW_BYTES
==
0
);
constexpr
int64_t
head_size_block_num
=
bytes_per_head
/
AMX_TILE_ROW_BYTES
;
constexpr
int64_t
head_elem_num_pre_block
=
AMX_TILE_ROW_BYTES
/
sizeof
(
scalar_t
);
...
...
csrc/cpu/cpu_attn_neon.hpp
View file @
4dffc5e0
...
...
@@ -264,7 +264,7 @@ class AttentionImpl<ISA::NEON, scalar_t, head_dim> {
constexpr
static
ISA
ISAType
=
ISA
::
NEON
;
constexpr
static
bool
scale_on_logits
=
false
;
// apply scale on q_buffer
//
static_assert(HeadDim % HeadDimAlignment == 0);
static_assert
(
HeadDim
%
HeadDimAlignment
==
0
);
// the gemm micro kernel is Mx8
static_assert
(
HeadDimAlignment
%
8
==
0
);
static_assert
(
BlockSizeAlignment
%
8
==
0
);
...
...
csrc/cpu/generate_cpu_attn_dispatch.py
0 → 100644
View file @
4dffc5e0
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Generate CPU attention dispatch switch cases and kernel instantiations.
"""
import
os
# Head dimensions divisible by 32 (support all ISAs)
HEAD_DIMS_32
=
[
32
,
64
,
96
,
128
,
160
,
192
,
224
,
256
]
# Head dimensions divisible by 16 but not 32 (VEC16 only)
HEAD_DIMS_16
=
[
80
,
112
]
# ISA types
ISA_TYPES
=
{
"AMX"
:
0
,
"VEC"
:
1
,
"VEC16"
:
2
,
"NEON"
:
3
,
}
# ISAs supported for head_dims divisible by 32
ISA_FOR_32
=
[
"AMX"
,
"NEON"
,
"VEC"
,
"VEC16"
]
# ISAs supported for head_dims divisible by 16 only
ISA_FOR_16
=
[
"VEC16"
]
def
encode_params
(
head_dim
:
int
,
isa_type
:
str
)
->
int
:
"""Encode head_dim and ISA type into a single int64_t."""
isa_val
=
ISA_TYPES
[
isa_type
]
# Encoding: (head_dim << 8) | isa_type
# This allows head_dim up to 2^56 - 1 and 256 ISA types
return
(
head_dim
<<
8
)
|
isa_val
def
generate_cases_for_isa_group
(
isa_list
:
list
[
str
])
->
str
:
"""Generate switch cases for a specific ISA group."""
cases
=
[]
# Generate cases for head_dims divisible by 32
for
head_dim
in
HEAD_DIMS_32
:
for
isa
in
isa_list
:
if
isa
not
in
ISA_FOR_32
:
continue
encoded
=
encode_params
(
head_dim
,
isa
)
case_str
=
(
f
""" case
{
encoded
}
LL: {{ """
f
"""/* head_dim=
{
head_dim
}
, isa=
{
isa
}
*/
\\
"""
f
"""
constexpr size_t head_dim =
{
head_dim
}
;
\\
"""
f
"""
using attn_impl = cpu_attention::AttentionImpl<"""
f
"""cpu_attention::ISA::
{
isa
}
,
\\
"""
f
"""
"""
f
"""scalar_t, head_dim>;
\\
"""
f
"""
return __VA_ARGS__();
\\
"""
f
"""
}}
\\
"""
)
cases
.
append
(
case_str
)
# Generate cases for head_dims divisible by 16 only
for
head_dim
in
HEAD_DIMS_16
:
for
isa
in
isa_list
:
encoded
=
encode_params
(
head_dim
,
isa
)
case_str
=
(
f
""" case
{
encoded
}
LL: {{ """
f
"""/* head_dim=
{
head_dim
}
, isa=
{
isa
}
"""
f
"""(using VEC16) */
\\
"""
f
"""
constexpr size_t head_dim =
{
head_dim
}
;
\\
"""
f
"""
using attn_impl = cpu_attention::AttentionImpl<"""
f
"""cpu_attention::ISA::VEC16,
\\
"""
f
"""
"""
f
"""scalar_t, head_dim>;
\\
"""
f
"""
return __VA_ARGS__();
\\
"""
f
"""
}}
\\
"""
)
cases
.
append
(
case_str
)
return
"
\n
"
.
join
(
cases
)
def
generate_helper_function
()
->
str
:
"""Generate helper function to encode parameters."""
return
"""
inline int64_t encode_cpu_attn_params(int64_t head_dim, cpu_attention::ISA isa) {
return (head_dim << 8) | static_cast<int64_t>(isa);
}
"""
def
generate_header_file
()
->
str
:
"""Generate the complete header file content."""
header
=
"""// auto generated by generate_cpu_attn_dispatch.py
// clang-format off
#ifndef CPU_ATTN_DISPATCH_GENERATED_H
#define CPU_ATTN_DISPATCH_GENERATED_H
#include "cpu_attn_vec.hpp"
#include "cpu_attn_vec16.hpp"
#ifdef CPU_CAPABILITY_AMXBF16
#include "cpu_attn_amx.hpp"
#endif
#ifdef __aarch64__
#include "cpu_attn_neon.hpp"
#endif
"""
header
+=
generate_helper_function
()
# Generate dispatch macro with conditional compilation for different ISA sets
header
+=
"""
// Dispatch macro using encoded parameters
"""
# x86_64 with AMX
header
+=
"""#if defined(CPU_CAPABILITY_AMXBF16)
#define CPU_ATTN_DISPATCH(HEAD_DIM, ISA_TYPE, ...)
\\
[&] {
\\
int64_t encoded_params = encode_cpu_attn_params(HEAD_DIM, ISA_TYPE);
\\
switch (encoded_params) {
\\
"""
header
+=
generate_cases_for_isa_group
([
"AMX"
,
"VEC"
,
"VEC16"
])
header
+=
"""
default: {
\\
TORCH_CHECK(false, "Unsupported CPU attention configuration: head_dim=" +
\\
std::to_string(HEAD_DIM) + " isa=" +
\\
std::to_string(static_cast<int>(ISA_TYPE)));
\\
}
\\
}
\\
}()
"""
# ARM64 with NEON
header
+=
"""#elif defined(__aarch64__)
#define CPU_ATTN_DISPATCH(HEAD_DIM, ISA_TYPE, ...)
\\
[&] {
\\
int64_t encoded_params = encode_cpu_attn_params(HEAD_DIM, ISA_TYPE);
\\
switch (encoded_params) {
\\
"""
header
+=
generate_cases_for_isa_group
([
"NEON"
,
"VEC"
,
"VEC16"
])
header
+=
"""
default: {
\\
TORCH_CHECK(false, "Unsupported CPU attention configuration: head_dim=" +
\\
std::to_string(HEAD_DIM) + " isa=" +
\\
std::to_string(static_cast<int>(ISA_TYPE)));
\\
}
\\
}
\\
}()
"""
# Fallback: VEC and VEC16 only
header
+=
"""#else
#define CPU_ATTN_DISPATCH(HEAD_DIM, ISA_TYPE, ...)
\\
[&] {
\\
int64_t encoded_params = encode_cpu_attn_params(HEAD_DIM, ISA_TYPE);
\\
switch (encoded_params) {
\\
"""
header
+=
generate_cases_for_isa_group
([
"VEC"
,
"VEC16"
])
header
+=
"""
default: {
\\
TORCH_CHECK(false, "Unsupported CPU attention configuration: head_dim=" +
\\
std::to_string(HEAD_DIM) + " isa=" +
\\
std::to_string(static_cast<int>(ISA_TYPE)));
\\
}
\\
}
\\
}()
#endif /* CPU_CAPABILITY_AMXBF16 / __aarch64__ */
#endif // CPU_ATTN_DISPATCH_GENERATED_H
"""
return
header
def
main
():
output_path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"cpu_attn_dispatch_generated.h"
)
with
open
(
output_path
,
"w"
)
as
f
:
f
.
write
(
generate_header_file
())
if
__name__
==
"__main__"
:
main
()
tests/kernels/attention/test_cpu_attn.py
View file @
4dffc5e0
...
...
@@ -26,6 +26,7 @@ NUM_HEADS = [
(
9
,
3
),
]
HEAD_SIZES
=
[
96
,
128
]
HEAD_SIZES_VEC16
=
[
96
,
80
,
112
,
128
]
QTYPES
=
[
torch
.
bfloat16
,
torch
.
half
,
torch
.
float32
]
SLIDING_WINDOWS
=
[
None
,
256
]
NUM_BLOCKS
=
[
...
...
@@ -432,7 +433,7 @@ def test_varlen_with_paged_kv_normal_amx(
@
pytest
.
mark
.
parametrize
(
"seq_lens"
,
SEQ_LENS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
_VEC16
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
48
])
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
SLIDING_WINDOWS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment