Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1363e3d6
Unverified
Commit
1363e3d6
authored
Feb 06, 2026
by
Gassan Salama
Committed by
GitHub
Feb 06, 2026
Browse files
[cpu][performance] CPU Paged Attention NEON BFMMLA BF16 Implementation (#32263)
Signed-off-by:
Gassan
<
gassan.salama@arm.com
>
parent
96552566
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
704 additions
and
4 deletions
+704
-4
csrc/cpu/cpu_attn_impl.hpp
csrc/cpu/cpu_attn_impl.hpp
+2
-1
csrc/cpu/cpu_attn_neon.hpp
csrc/cpu/cpu_attn_neon.hpp
+17
-2
csrc/cpu/cpu_attn_neon_bfmmla.hpp
csrc/cpu/cpu_attn_neon_bfmmla.hpp
+682
-0
vllm/v1/attention/backends/cpu_attn.py
vllm/v1/attention/backends/cpu_attn.py
+3
-1
No files found.
csrc/cpu/cpu_attn_impl.hpp
View file @
1363e3d6
...
@@ -1107,7 +1107,8 @@ class AttentionMainLoop {
...
@@ -1107,7 +1107,8 @@ class AttentionMainLoop {
if
(
sliding_window_left
!=
-
1
)
{
if
(
sliding_window_left
!=
-
1
)
{
pos
=
std
::
max
(
pos
,
curr_token_pos
-
sliding_window_left
);
pos
=
std
::
max
(
pos
,
curr_token_pos
-
sliding_window_left
);
}
}
return
pos
;
// Clamp to tile end to avoid OOB when window starts past the tile
return
std
::
min
(
pos
,
kv_tile_end_pos
);
}();
}();
int32_t
right_kv_pos
=
[
&
]()
{
int32_t
right_kv_pos
=
[
&
]()
{
...
...
csrc/cpu/cpu_attn_neon.hpp
View file @
1363e3d6
...
@@ -4,6 +4,9 @@
...
@@ -4,6 +4,9 @@
#include "cpu_attn_impl.hpp"
#include "cpu_attn_impl.hpp"
#include <arm_neon.h>
#include <arm_neon.h>
#include <type_traits>
#include <type_traits>
#ifdef ARM_BF16_SUPPORT
#include "cpu_attn_neon_bfmmla.hpp"
#endif
namespace
cpu_attention
{
namespace
cpu_attention
{
namespace
{
namespace
{
...
@@ -57,7 +60,7 @@ FORCE_INLINE void load_row8_B_as_f32<c10::BFloat16>(const c10::BFloat16* p,
...
@@ -57,7 +60,7 @@ FORCE_INLINE void load_row8_B_as_f32<c10::BFloat16>(const c10::BFloat16* p,
#endif
#endif
}
}
// Mx8, with 1 <= M <= 8 , K streamed, unroll-by-4 with
NEON
FMLAs
// Mx8, with 1 <= M <= 8 , K streamed, unroll-by-4 with
ASIMD
FMLAs
// #Loads = (K // 4) * (M + 4 * sizeof(kv_cache_t) / 2)
// #Loads = (K // 4) * (M + 4 * sizeof(kv_cache_t) / 2)
// #FMLAs = (K // 4) * (4 * 2 * M)
// #FMLAs = (K // 4) * (4 * 2 * M)
// We have (4 * 2 * M) FMLAs for (M + 4 * sizeof(kv_cache_t) / 2) loads
// We have (4 * 2 * M) FMLAs for (M + 4 * sizeof(kv_cache_t) / 2) loads
...
@@ -381,6 +384,18 @@ class AttentionImpl<ISA::NEON, scalar_t, head_dim> {
...
@@ -381,6 +384,18 @@ class AttentionImpl<ISA::NEON, scalar_t, head_dim> {
}
}
}
}
};
};
#ifdef ARM_BF16_SUPPORT
// For BF16 on Arm, reuse the BFMMLA kernels with 32-token alignment.
template
<
int64_t
head_dim
>
class
AttentionImpl
<
ISA
::
NEON
,
c10
::
BFloat16
,
head_dim
>
:
public
AttentionImplNEONBFMMLA
<
BLOCK_SIZE_ALIGNMENT
,
ISA
::
NEON
,
head_dim
>
{};
#endif
}
// namespace cpu_attention
}
// namespace cpu_attention
#endif // #ifndef CPU_ATTN_NEON_HPP
#undef BLOCK_SIZE_ALIGNMENT
#undef HEAD_SIZE_ALIGNMENT
#undef MAX_Q_HEAD_NUM_PER_ITER
#endif // #ifndef CPU_ATTN_ASIMD_HPP
csrc/cpu/cpu_attn_neon_bfmmla.hpp
0 → 100644
View file @
1363e3d6
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#ifndef CPU_ATTN_NEON_BFMMLA_HPP
#define CPU_ATTN_NEON_BFMMLA_HPP
#include "cpu_attn_impl.hpp"
#include <arm_neon.h>
#include <cstdint>
#include <vector>
namespace
cpu_attention
{
namespace
{
// BFMMLA tile dimensions
constexpr
int32_t
TILE_ROWS
=
2
;
// M dimension
constexpr
int32_t
TILE_K
=
4
;
// K reduction
constexpr
int32_t
TILE_COLS
=
2
;
// N dimension (column-pair)
// Derived constants
constexpr
int32_t
OUTPUT_COLS_PER_BLOCK
=
8
;
// 4 column-pairs
constexpr
int32_t
K_TOKENS_PER_GROUP
=
8
;
// Tokens grouped in K cache
constexpr
int32_t
V_TOKENS_PER_ROW_BLOCK
=
4
;
// Tokens per V cache row block
constexpr
int32_t
K_INNER_STRIDE
=
K_TOKENS_PER_GROUP
*
TILE_K
;
constexpr
int32_t
V_INNER_STRIDE
=
V_TOKENS_PER_ROW_BLOCK
*
TILE_COLS
;
constexpr
int32_t
PACK_ELEMENTS_PER_K_CHUNK
=
TILE_ROWS
*
TILE_K
;
// A packing
// Matrix Packing and Accumulator
// Reshape two rows of Q into BFMMLA-friendly interleaved
// Input: row0 = [a0,a1,a2,a3], row1 = [b0,b1,b2,b3]
// Output: [a0,a1,a2,a3,b0,b1,b2,b3, a4,a5,a6,a7,b4,b5,b6,b7]
// For K tail (K % TILE_K != 0): pads with zeros to complete the final chunk
FORCE_INLINE
void
reshape_Q_2xK_for_bfmmla
(
const
c10
::
BFloat16
*
__restrict
r0
,
const
c10
::
BFloat16
*
__restrict
r1
,
c10
::
BFloat16
*
__restrict
dst
,
int32_t
K
)
{
const
uint16_t
*
s0
=
reinterpret_cast
<
const
uint16_t
*>
(
r0
);
const
uint16_t
*
s1
=
reinterpret_cast
<
const
uint16_t
*>
(
r1
);
uint16_t
*
d
=
reinterpret_cast
<
uint16_t
*>
(
dst
);
// Process TILE_K elements at a time (PACK_ELEMENTS_PER_K_CHUNK output)
int32_t
k
=
0
;
for
(;
k
+
TILE_K
<=
K
;
k
+=
TILE_K
,
d
+=
PACK_ELEMENTS_PER_K_CHUNK
)
{
vst1q_u16
(
d
,
vcombine_u16
(
vld1_u16
(
s0
+
k
),
vld1_u16
(
s1
+
k
)));
}
// Handle K tail: pack remaining elements with zero-padding
const
int32_t
tail
=
K
-
k
;
if
(
tail
>
0
)
{
// Pack remaining tail elements: [r0[k..k+tail-1], pad, r1[k..k+tail-1],
// pad]
for
(
int32_t
t
=
0
;
t
<
tail
;
++
t
)
{
d
[
t
]
=
s0
[
k
+
t
];
d
[
t
+
TILE_K
]
=
s1
[
k
+
t
];
}
// Zero-pad the rest
for
(
int32_t
t
=
tail
;
t
<
TILE_K
;
++
t
)
{
d
[
t
]
=
0
;
d
[
t
+
TILE_K
]
=
0
;
}
}
}
// 2x2 accumulator load/store with compile-time row count
template
<
int32_t
m_rows
>
FORCE_INLINE
float32x4_t
load_acc_2x2
(
float
*
base
,
int64_t
ldc
,
int
col_off
)
{
static_assert
(
m_rows
==
1
||
m_rows
==
2
);
float32x2_t
row0
=
vld1_f32
(
base
+
col_off
);
float32x2_t
row1
=
(
m_rows
==
2
)
?
vld1_f32
(
base
+
ldc
+
col_off
)
:
vdup_n_f32
(
0.
f
);
return
vcombine_f32
(
row0
,
row1
);
}
template
<
int32_t
m_rows
>
FORCE_INLINE
void
store_acc_2x2
(
float32x4_t
acc
,
float
*
base
,
int64_t
ldc
,
int
col_off
)
{
static_assert
(
m_rows
==
1
||
m_rows
==
2
);
vst1_f32
(
base
+
col_off
,
vget_low_f32
(
acc
));
if
constexpr
(
m_rows
==
2
)
{
vst1_f32
(
base
+
ldc
+
col_off
,
vget_high_f32
(
acc
));
}
}
// Initialize 4 column-pair accumulators for 2 rows (8 columns total)
#define INIT_ACC_ROWPAIR_4(a0, a1, a2, a3, Crow, ldc, m_rows, accum) \
do { \
if (accum) { \
if (m_rows == 2) { \
a0 = load_acc_2x2<2>(Crow, ldc, 0); \
a1 = load_acc_2x2<2>(Crow, ldc, 2); \
a2 = load_acc_2x2<2>(Crow, ldc, 4); \
a3 = load_acc_2x2<2>(Crow, ldc, 6); \
} else { \
a0 = load_acc_2x2<1>(Crow, ldc, 0); \
a1 = load_acc_2x2<1>(Crow, ldc, 2); \
a2 = load_acc_2x2<1>(Crow, ldc, 4); \
a3 = load_acc_2x2<1>(Crow, ldc, 6); \
} \
} else { \
a0 = a1 = a2 = a3 = vdupq_n_f32(0.f); \
} \
} while (0)
// Store 4 column-pair accumulators back to C matrix
#define STORE_ACC_ROWPAIR_4(a0, a1, a2, a3, Crow, ldc, m_rows) \
do { \
if (m_rows == 2) { \
store_acc_2x2<2>(a0, Crow, ldc, 0); \
store_acc_2x2<2>(a1, Crow, ldc, 2); \
store_acc_2x2<2>(a2, Crow, ldc, 4); \
store_acc_2x2<2>(a3, Crow, ldc, 6); \
} else { \
store_acc_2x2<1>(a0, Crow, ldc, 0); \
store_acc_2x2<1>(a1, Crow, ldc, 2); \
store_acc_2x2<1>(a2, Crow, ldc, 4); \
store_acc_2x2<1>(a3, Crow, ldc, 6); \
} \
} while (0)
// Perform 4 BFMMLA operations: acc += A @ B for 4 column-pairs
#define BFMMLA_COMPUTE_4(r0, r1, r2, r3, a, b0, b1, b2, b3) \
do { \
r0 = vbfmmlaq_f32(r0, a, b0); \
r1 = vbfmmlaq_f32(r1, a, b1); \
r2 = vbfmmlaq_f32(r2, a, b2); \
r3 = vbfmmlaq_f32(r3, a, b3); \
} while (0)
// Micro-kernel: updates a small fixed tile using BFMMLA.
// RP = number of row-pairs (1,2,4)
// Computes C[TILE_ROWS*RP, OUTPUT_COLS_PER_BLOCK] += A_packed @ B.
// A_packed interleaves RP row-pairs; B layout is driven by the attention phase:
// - AttentionGemmPhase::QK -> token-column layout (Q @ K^T)
// - AttentionGemmPhase::PV -> token-row layout (P @ V)
// K_static < 0 enables runtime K (PV only)
template
<
int32_t
RP
,
int32_t
K_static
,
AttentionGemmPhase
phase
>
FORCE_INLINE
void
gemm_rowpairs_x8_bfmmla_neon
(
const
bfloat16_t
*
const
*
__restrict
A_packed_rp
,
const
int32_t
*
__restrict
m_rows_rp
,
const
bfloat16_t
*
__restrict
B_blk
,
float
*
__restrict
C
,
int64_t
ldc
,
bool
accumulate
,
int64_t
b_stride
,
int32_t
K_runtime
=
0
)
{
static_assert
(
RP
==
1
||
RP
==
2
||
RP
==
4
,
"RP must be 1,2,4"
);
static_assert
(
K_static
<
0
||
K_static
%
TILE_K
==
0
,
"K must be divisible by TILE_K"
);
static_assert
(
K_static
>=
0
||
phase
==
AttentionGemmPhase
::
PV
,
"Runtime K only supported for PV"
);
constexpr
bool
runtime_k
=
(
K_static
<
0
);
const
int32_t
K_iters
=
runtime_k
?
(
K_runtime
/
TILE_K
)
:
(
K_static
/
TILE_K
);
const
int32_t
K_tail
=
runtime_k
?
(
K_runtime
%
TILE_K
)
:
0
;
if
(
!
runtime_k
)
{
// Help the compiler fold away unused K_runtime when K is compile-time
(
void
)
K_runtime
;
}
auto
*
C_al
=
C
;
const
auto
*
B_al
=
B_blk
;
// Setup A pointers
const
bfloat16_t
*
a_ptr
[
4
]
=
{
A_packed_rp
[
0
],
(
RP
>=
2
)
?
A_packed_rp
[
1
]
:
nullptr
,
(
RP
>=
4
)
?
A_packed_rp
[
2
]
:
nullptr
,
(
RP
>=
4
)
?
A_packed_rp
[
3
]
:
nullptr
,
};
// Setup B pointers based on layout
const
bfloat16_t
*
b_ptr
[
4
];
if
constexpr
(
phase
==
AttentionGemmPhase
::
PV
)
{
b_ptr
[
0
]
=
B_blk
+
0
*
b_stride
;
b_ptr
[
1
]
=
B_blk
+
1
*
b_stride
;
b_ptr
[
2
]
=
B_blk
+
2
*
b_stride
;
b_ptr
[
3
]
=
B_blk
+
3
*
b_stride
;
}
float32x4_t
acc
[
4
][
4
];
// Initialize accumulators
#define INIT_RP(rp) \
if constexpr (RP > rp) { \
INIT_ACC_ROWPAIR_4(acc[rp][0], acc[rp][1], acc[rp][2], acc[rp][3], \
C_al + (rp * 2) * ldc, ldc, m_rows_rp[rp], accumulate); \
}
INIT_RP
(
0
);
INIT_RP
(
1
);
INIT_RP
(
2
);
INIT_RP
(
3
);
#undef INIT_RP
// Main compute loop
for
(
int32_t
ki
=
0
;
ki
<
K_iters
;
++
ki
)
{
bfloat16x8_t
b0
,
b1
,
b2
,
b3
;
if
constexpr
(
phase
==
AttentionGemmPhase
::
PV
)
{
b0
=
vld1q_bf16
(
b_ptr
[
0
]
+
ki
*
V_INNER_STRIDE
);
b1
=
vld1q_bf16
(
b_ptr
[
1
]
+
ki
*
V_INNER_STRIDE
);
b2
=
vld1q_bf16
(
b_ptr
[
2
]
+
ki
*
V_INNER_STRIDE
);
b3
=
vld1q_bf16
(
b_ptr
[
3
]
+
ki
*
V_INNER_STRIDE
);
}
else
{
const
bfloat16_t
*
b_base
=
B_al
+
ki
*
b_stride
;
b0
=
vld1q_bf16
(
b_base
+
0
*
V_INNER_STRIDE
);
b1
=
vld1q_bf16
(
b_base
+
1
*
V_INNER_STRIDE
);
b2
=
vld1q_bf16
(
b_base
+
2
*
V_INNER_STRIDE
);
b3
=
vld1q_bf16
(
b_base
+
3
*
V_INNER_STRIDE
);
}
#define COMPUTE_RP(rp) \
if constexpr (RP > rp) { \
bfloat16x8_t a = vld1q_bf16(a_ptr[rp] + ki * PACK_ELEMENTS_PER_K_CHUNK); \
BFMMLA_COMPUTE_4(acc[rp][0], acc[rp][1], acc[rp][2], acc[rp][3], a, b0, \
b1, b2, b3); \
}
COMPUTE_RP
(
0
);
COMPUTE_RP
(
1
);
COMPUTE_RP
(
2
);
COMPUTE_RP
(
3
);
#undef COMPUTE_RP
}
// K tail for runtime PV: fallback path
if
constexpr
(
runtime_k
)
{
if
(
K_tail
>
0
)
{
const
int32_t
tail_offset
=
K_iters
*
V_INNER_STRIDE
;
const
int32_t
a_tail_offset
=
K_iters
*
PACK_ELEMENTS_PER_K_CHUNK
;
for
(
int32_t
kt
=
0
;
kt
<
K_tail
;
++
kt
)
{
float32x4_t
b_vecs
[
4
];
for
(
int32_t
p
=
0
;
p
<
4
;
++
p
)
{
const
bfloat16_t
*
bp
=
b_ptr
[
p
]
+
tail_offset
+
kt
*
TILE_COLS
;
const
float
b0
=
vcvtah_f32_bf16
(
bp
[
0
]);
const
float
b1
=
vcvtah_f32_bf16
(
bp
[
1
]);
const
float32x2_t
b_pair
=
vset_lane_f32
(
b1
,
vdup_n_f32
(
b0
),
1
);
b_vecs
[
p
]
=
vcombine_f32
(
b_pair
,
b_pair
);
}
#define TAIL_RP(rp) \
if constexpr (RP > rp) { \
const bfloat16_t* ap = A_packed_rp[rp] + a_tail_offset; \
float a_row0 = vcvtah_f32_bf16(ap[kt]); \
float a_row1 = \
(m_rows_rp[rp] == 2) ? vcvtah_f32_bf16(ap[kt + TILE_K]) : 0.0f; \
const float32x4_t a_vec = \
vcombine_f32(vdup_n_f32(a_row0), vdup_n_f32(a_row1)); \
for (int32_t p = 0; p < 4; ++p) { \
acc[rp][p] = vmlaq_f32(acc[rp][p], a_vec, b_vecs[p]); \
} \
}
TAIL_RP
(
0
);
TAIL_RP
(
1
);
TAIL_RP
(
2
);
TAIL_RP
(
3
);
#undef TAIL_RP
}
}
}
// Store results
#define STORE_RP(rp) \
if constexpr (RP > rp) { \
STORE_ACC_ROWPAIR_4(acc[rp][0], acc[rp][1], acc[rp][2], acc[rp][3], \
C_al + (rp * 2) * ldc, ldc, m_rows_rp[rp]); \
}
STORE_RP
(
0
);
STORE_RP
(
1
);
STORE_RP
(
2
);
STORE_RP
(
3
);
#undef STORE_RP
}
// Meso-kernel: packs a small MBxK slice of A, then tiles over N and calls the
// micro-kernel for each OUTPUT_COLS_PER_BLOCK chunk. K_static < 0 enables
// runtime K (PV only).
template
<
int32_t
MB
,
int32_t
N
,
int32_t
K_static
,
AttentionGemmPhase
phase
>
FORCE_INLINE
void
gemm_packA_compute_MB_xN
(
const
c10
::
BFloat16
*
__restrict
A
,
const
c10
::
BFloat16
*
__restrict
B
,
float
*
__restrict
C
,
int32_t
K_runtime
,
int64_t
lda
,
int64_t
ldc
,
int64_t
b_layout_stride
,
int64_t
b_reduction_stride
,
bool
accumulate
)
{
static_assert
(
MB
>=
1
&&
MB
<=
8
,
"MB must be in [1,8]"
);
static_assert
(
N
%
OUTPUT_COLS_PER_BLOCK
==
0
,
"N must be a multiple of OUTPUT_COLS_PER_BLOCK"
);
static_assert
(
K_static
<
0
||
K_static
%
TILE_K
==
0
,
"K must be divisible by TILE_K"
);
static_assert
(
K_static
>=
0
||
phase
==
AttentionGemmPhase
::
PV
,
"Runtime K only supported for PV"
);
constexpr
bool
runtime_k
=
(
K_static
<
0
);
const
int32_t
K_val
=
runtime_k
?
K_runtime
:
K_static
;
// Keep small packs on-stack to avoid heap churn
constexpr
int32_t
STACK_PACK_STRIDE
=
(
1024
/
TILE_K
)
*
PACK_ELEMENTS_PER_K_CHUNK
;
constexpr
int32_t
ROW_PAIRS
=
(
MB
+
1
)
/
TILE_ROWS
;
const
int32_t
pack_stride
=
runtime_k
?
((
K_val
+
TILE_K
-
1
)
/
TILE_K
)
*
PACK_ELEMENTS_PER_K_CHUNK
:
(
K_static
/
TILE_K
)
*
PACK_ELEMENTS_PER_K_CHUNK
;
alignas
(
64
)
c10
::
BFloat16
A_packed_stack
[
ROW_PAIRS
*
STACK_PACK_STRIDE
];
std
::
vector
<
c10
::
BFloat16
>
A_packed_heap
;
c10
::
BFloat16
*
A_packed
=
(
pack_stride
<=
STACK_PACK_STRIDE
)
?
A_packed_stack
:
(
A_packed_heap
.
resize
(
ROW_PAIRS
*
pack_stride
),
A_packed_heap
.
data
());
for
(
int32_t
rp
=
0
;
rp
<
ROW_PAIRS
;
++
rp
)
{
const
int32_t
m
=
rp
*
TILE_ROWS
;
const
int32_t
m_rows
=
(
m
+
1
<
MB
)
?
TILE_ROWS
:
1
;
const
c10
::
BFloat16
*
A0
=
A
+
m
*
lda
;
const
c10
::
BFloat16
*
A1
=
(
m_rows
==
TILE_ROWS
)
?
(
A
+
(
m
+
1
)
*
lda
)
:
A0
;
reshape_Q_2xK_for_bfmmla
(
A0
,
A1
,
A_packed
+
rp
*
pack_stride
,
K_val
);
}
for
(
int32_t
n
=
0
;
n
<
N
;
n
+=
OUTPUT_COLS_PER_BLOCK
)
{
const
c10
::
BFloat16
*
B_blk_c10
=
(
phase
==
AttentionGemmPhase
::
PV
)
?
(
B
+
(
n
/
TILE_COLS
)
*
b_layout_stride
)
:
(
B
+
(
n
/
OUTPUT_COLS_PER_BLOCK
)
*
b_layout_stride
);
const
bfloat16_t
*
B_blk
=
reinterpret_cast
<
const
bfloat16_t
*>
(
B_blk_c10
);
// Process row-pairs in groups of 4, 2, then 1
int32_t
row_pair_idx
=
0
;
#define PROCESS_RP_GROUP(group_size) \
for (; row_pair_idx + (group_size - 1) < ROW_PAIRS; \
row_pair_idx += group_size) { \
const bfloat16_t* Ap[group_size]; \
int32_t mr[group_size]; \
for (int32_t i = 0; i < group_size; ++i) { \
Ap[i] = reinterpret_cast<const bfloat16_t*>( \
A_packed + (row_pair_idx + i) * pack_stride); \
mr[i] = (((row_pair_idx + i) * TILE_ROWS + 1) < MB) ? TILE_ROWS : 1; \
} \
float* C_blk = C + (row_pair_idx * TILE_ROWS) * ldc + n; \
if constexpr (runtime_k) { \
gemm_rowpairs_x8_bfmmla_neon<group_size, -1, phase>( \
Ap, mr, B_blk, C_blk, ldc, accumulate, b_layout_stride, K_val); \
} else { \
gemm_rowpairs_x8_bfmmla_neon<group_size, K_static, phase>( \
Ap, mr, B_blk, C_blk, ldc, accumulate, \
(phase == AttentionGemmPhase::PV) ? b_layout_stride \
: b_reduction_stride); \
} \
}
PROCESS_RP_GROUP
(
4
);
PROCESS_RP_GROUP
(
2
);
PROCESS_RP_GROUP
(
1
);
#undef PROCESS_RP_GROUP
}
}
// Macro-kernel: iterates over M in MB={8,4,2,1} chunks.
// Supports compile-time K specialization when K >= 0; otherwise uses runtime K
// (runtime K path is only supported for PV).
template
<
AttentionGemmPhase
phase
,
int32_t
N
,
int32_t
K
=
-
1
>
FORCE_INLINE
void
gemm_macro_neon_bfmmla
(
const
c10
::
BFloat16
*
__restrict
A
,
const
c10
::
BFloat16
*
__restrict
B
,
float
*
__restrict
C
,
int32_t
M
,
int32_t
K_runtime
,
int64_t
lda
,
int64_t
ldc
,
int64_t
b_layout_stride
,
int64_t
b_reduction_stride
,
bool
accumulate
)
{
static_assert
(
N
%
OUTPUT_COLS_PER_BLOCK
==
0
,
"N must be a multiple of OUTPUT_COLS_PER_BLOCK"
);
if
constexpr
(
K
>=
0
)
{
static_assert
(
K
%
TILE_K
==
0
,
"K must be divisible by TILE_K"
);
for
(
int32_t
m
=
0
;
m
<
M
;)
{
const
int32_t
rem
=
M
-
m
;
const
c10
::
BFloat16
*
A_blk
=
A
+
m
*
lda
;
float
*
C_blk
=
C
+
m
*
ldc
;
#define DISPATCH_MB(mb) \
gemm_packA_compute_MB_xN<mb, N, K, phase>(A_blk, B, C_blk, 0, lda, ldc, \
b_layout_stride, \
b_reduction_stride, accumulate)
if
(
rem
>=
8
)
{
DISPATCH_MB
(
8
);
m
+=
8
;
}
else
if
(
rem
>=
4
)
{
DISPATCH_MB
(
4
);
m
+=
4
;
}
else
if
(
rem
>=
2
)
{
DISPATCH_MB
(
2
);
m
+=
2
;
}
else
{
DISPATCH_MB
(
1
);
m
+=
1
;
}
#undef DISPATCH_MB
}
}
else
{
static_assert
(
phase
==
AttentionGemmPhase
::
PV
,
"Runtime K specialization only supported for PV."
);
const
int32_t
K_val
=
K_runtime
;
for
(
int32_t
m
=
0
;
m
<
M
;)
{
const
int32_t
rem
=
M
-
m
;
const
c10
::
BFloat16
*
A_blk
=
A
+
m
*
lda
;
float
*
C_blk
=
C
+
m
*
ldc
;
#define DISPATCH_MB_RUNTIME(mb) \
gemm_packA_compute_MB_xN<mb, N, -1, phase>(A_blk, B, C_blk, K_val, lda, ldc, \
b_layout_stride, \
b_reduction_stride, accumulate)
if
(
rem
>=
8
)
{
DISPATCH_MB_RUNTIME
(
8
);
m
+=
8
;
}
else
if
(
rem
>=
4
)
{
DISPATCH_MB_RUNTIME
(
4
);
m
+=
4
;
}
else
if
(
rem
>=
2
)
{
DISPATCH_MB_RUNTIME
(
2
);
m
+=
2
;
}
else
{
DISPATCH_MB_RUNTIME
(
1
);
m
+=
1
;
}
#undef DISPATCH_MB_RUNTIME
}
}
}
#undef INIT_ACC_ROWPAIR_4
#undef STORE_ACC_ROWPAIR_4
#undef BFMMLA_COMPUTE_4
}
// namespace
// TileGemm Adapter for Attention
template
<
typename
kv_cache_t
,
int32_t
BlockTokens
,
int32_t
HeadDim
>
class
TileGemmNEONBFMMLA
{
public:
template
<
AttentionGemmPhase
phase
,
int32_t
head_dim_ct
>
FORCE_INLINE
static
void
gemm
(
const
int32_t
m_size
,
void
*
__restrict__
a_tile
,
kv_cache_t
*
__restrict__
b_tile
,
float
*
__restrict__
c_tile
,
const
int64_t
lda
,
[[
maybe_unused
]]
const
int64_t
ldb
,
const
int64_t
ldc
,
[[
maybe_unused
]]
const
int32_t
block_size
,
[[
maybe_unused
]]
const
int32_t
dynamic_k_size
,
const
bool
accum_c
)
{
static_assert
(
BlockTokens
%
OUTPUT_COLS_PER_BLOCK
==
0
);
// BFMMLA kernels require compile-time head_dim; keep head_dim_ct only for
// API parity with other tile_gemm implementations.
if
constexpr
(
head_dim_ct
>=
0
)
{
static_assert
(
head_dim_ct
==
HeadDim
,
"BFMMLA expects head_dim_ct to match HeadDim; PV passes "
"-1 for API parity."
);
}
if
constexpr
(
phase
==
AttentionGemmPhase
::
QK
)
{
const
int64_t
b_reduction_stride
=
K_INNER_STRIDE
;
const
int64_t
b_token_block_stride
=
(
HeadDim
/
TILE_K
)
*
K_INNER_STRIDE
;
gemm_macro_neon_bfmmla
<
AttentionGemmPhase
::
QK
,
BlockTokens
,
HeadDim
>
(
reinterpret_cast
<
const
c10
::
BFloat16
*>
(
a_tile
),
b_tile
,
c_tile
,
m_size
,
0
,
lda
,
ldc
,
b_token_block_stride
,
b_reduction_stride
,
accum_c
);
}
else
{
const
int64_t
b_pair_stride
=
(
block_size
/
V_TOKENS_PER_ROW_BLOCK
)
*
V_INNER_STRIDE
;
// PV gemm with runtime K specialization
switch
(
dynamic_k_size
)
{
case
32
:
gemm_macro_neon_bfmmla
<
AttentionGemmPhase
::
PV
,
HeadDim
,
32
>
(
reinterpret_cast
<
const
c10
::
BFloat16
*>
(
a_tile
),
b_tile
,
c_tile
,
m_size
,
32
,
lda
,
ldc
,
b_pair_stride
,
0
,
accum_c
);
break
;
case
128
:
gemm_macro_neon_bfmmla
<
AttentionGemmPhase
::
PV
,
HeadDim
,
128
>
(
reinterpret_cast
<
const
c10
::
BFloat16
*>
(
a_tile
),
b_tile
,
c_tile
,
m_size
,
128
,
lda
,
ldc
,
b_pair_stride
,
0
,
accum_c
);
break
;
case
256
:
gemm_macro_neon_bfmmla
<
AttentionGemmPhase
::
PV
,
HeadDim
,
256
>
(
reinterpret_cast
<
const
c10
::
BFloat16
*>
(
a_tile
),
b_tile
,
c_tile
,
m_size
,
256
,
lda
,
ldc
,
b_pair_stride
,
0
,
accum_c
);
break
;
default:
gemm_macro_neon_bfmmla
<
AttentionGemmPhase
::
PV
,
HeadDim
>
(
reinterpret_cast
<
const
c10
::
BFloat16
*>
(
a_tile
),
b_tile
,
c_tile
,
m_size
,
dynamic_k_size
,
lda
,
ldc
,
b_pair_stride
,
0
,
accum_c
);
break
;
}
}
}
};
// Shared ASIMD BFMMLA implementation (BF16 only). The block size alignment and
// ISA tag are template parameters so we can reuse the same kernels for
// different NEON configurations.
template
<
int64_t
block_size_alignment
,
ISA
isa_type
,
int64_t
head_dim
>
class
AttentionImplNEONBFMMLA
{
public:
using
query_t
=
c10
::
BFloat16
;
using
q_buffer_t
=
c10
::
BFloat16
;
using
kv_cache_t
=
c10
::
BFloat16
;
using
logits_buffer_t
=
float
;
using
partial_output_buffer_t
=
float
;
using
prob_buffer_t
=
c10
::
BFloat16
;
static
constexpr
int64_t
BlockSizeAlignment
=
block_size_alignment
;
// HeadDimAlignment equals head_dim so that the PV phase processes
// the full head dimension in a single gemm call.
static
constexpr
int64_t
HeadDimAlignment
=
head_dim
;
static
constexpr
int64_t
MaxQHeadNumPerIteration
=
16
;
static
constexpr
int64_t
HeadDim
=
head_dim
;
static
constexpr
ISA
ISAType
=
isa_type
;
static
constexpr
bool
scale_on_logits
=
false
;
static_assert
(
HeadDim
%
OUTPUT_COLS_PER_BLOCK
==
0
);
static_assert
(
BlockSizeAlignment
%
OUTPUT_COLS_PER_BLOCK
==
0
);
static_assert
(
HeadDim
%
TILE_K
==
0
,
"HeadDim must be a multiple of TILE_K"
);
public:
template
<
template
<
typename
tile_gemm_t
>
typename
attention
>
FORCE_INLINE
void
execute_attention
(
DEFINE_CPU_ATTENTION_PARAMS
)
{
attention
<
TileGemmNEONBFMMLA
<
kv_cache_t
,
static_cast
<
int32_t
>
(
BlockSizeAlignment
),
static_cast
<
int32_t
>
(
HeadDim
)
>>
attention_iteration
;
attention_iteration
(
CPU_ATTENTION_PARAMS
);
}
// Key cache stride per token group (TokenColumn layout; QK)
static
constexpr
int64_t
k_cache_token_group_stride
(
[[
maybe_unused
]]
const
int32_t
block_size
)
{
static_assert
(
BlockSizeAlignment
%
K_TOKENS_PER_GROUP
==
0
);
return
(
BlockSizeAlignment
/
K_TOKENS_PER_GROUP
)
*
((
head_dim
/
TILE_K
)
*
K_INNER_STRIDE
);
}
// Value cache stride per token group (TokenRow layout; PV)
static
constexpr
int64_t
v_cache_token_group_stride
(
[[
maybe_unused
]]
const
int32_t
block_size
)
{
static_assert
(
BlockSizeAlignment
%
V_TOKENS_PER_ROW_BLOCK
==
0
);
return
(
BlockSizeAlignment
/
V_TOKENS_PER_ROW_BLOCK
)
*
V_INNER_STRIDE
;
}
// The stride to move to the "next" head_dim group
// is the full V cache size per head, since HeadDimAlignment == head_dim.
// Hence, the stride is not used in this case
static
constexpr
int64_t
v_cache_head_group_stride
(
[[
maybe_unused
]]
const
int32_t
block_size
)
{
return
head_dim
*
block_size
;
}
// Convert Q heads to BF16 and apply scale factor using native BF16 intrinsics
static
void
copy_q_heads_tile
(
c10
::
BFloat16
*
__restrict__
src
,
c10
::
BFloat16
*
__restrict__
q_buffer
,
const
int32_t
q_num
,
const
int32_t
q_heads_per_kv
,
const
int64_t
q_num_stride
,
const
int64_t
q_head_stride
,
float
scale
)
{
constexpr
int32_t
dim
=
static_cast
<
int32_t
>
(
head_dim
);
const
float32x4_t
scale_vec
=
vdupq_n_f32
(
scale
);
for
(
int32_t
qi
=
0
;
qi
<
q_num
;
++
qi
)
{
for
(
int32_t
hi
=
0
;
hi
<
q_heads_per_kv
;
++
hi
)
{
c10
::
BFloat16
*
__restrict__
curr_q
=
src
+
qi
*
q_num_stride
+
hi
*
q_head_stride
;
c10
::
BFloat16
*
__restrict__
dst
=
q_buffer
+
qi
*
q_heads_per_kv
*
head_dim
+
hi
*
head_dim
;
for
(
int32_t
i
=
0
;
i
<
dim
;
i
+=
OUTPUT_COLS_PER_BLOCK
)
{
bfloat16x8_t
in8
=
vld1q_bf16
(
reinterpret_cast
<
const
bfloat16_t
*>
(
curr_q
+
i
));
float32x4_t
lo
=
vmulq_f32
(
vcvtq_low_f32_bf16
(
in8
),
scale_vec
);
float32x4_t
hi
=
vmulq_f32
(
vcvtq_high_f32_bf16
(
in8
),
scale_vec
);
bfloat16x4_t
lo_b
=
vcvt_bf16_f32
(
lo
);
bfloat16x4_t
hi_b
=
vcvt_bf16_f32
(
hi
);
bfloat16x8_t
out
=
vcombine_bf16
(
lo_b
,
hi_b
);
vst1q_bf16
(
reinterpret_cast
<
bfloat16_t
*>
(
dst
+
i
),
out
);
}
}
}
}
public:
// Reshape and cache K/V into BFMMLA-optimized layouts
// K cache:
// [block_size/K_TOKENS_PER_GROUP][head_dim/TILE_K][K_INNER_STRIDE]
// - TokenColumn
// V cache:
// [head_dim/TILE_COLS][block_size/V_TOKENS_PER_ROW_BLOCK][V_INNER_STRIDE]
// - TokenRows
static
void
reshape_and_cache
(
const
c10
::
BFloat16
*
__restrict__
key
,
const
c10
::
BFloat16
*
__restrict__
value
,
c10
::
BFloat16
*
__restrict__
key_cache
,
c10
::
BFloat16
*
__restrict__
value_cache
,
const
int64_t
*
__restrict__
slot_mapping
,
const
int64_t
token_num
,
const
int64_t
key_token_num_stride
,
const
int64_t
value_token_num_stride
,
const
int64_t
head_num
,
const
int64_t
key_head_num_stride
,
const
int64_t
value_head_num_stride
,
[[
maybe_unused
]]
const
int64_t
num_blocks
,
const
int64_t
num_blocks_stride
,
const
int64_t
cache_head_num_stride
,
const
int64_t
block_size
,
[[
maybe_unused
]]
const
int64_t
block_size_stride
)
{
const
int64_t
k_block_stride
=
(
head_dim
/
TILE_K
)
*
K_INNER_STRIDE
;
const
int64_t
v_pair_stride
=
(
block_size
/
V_TOKENS_PER_ROW_BLOCK
)
*
V_INNER_STRIDE
;
#pragma omp parallel for
for
(
int64_t
head_idx
=
0
;
head_idx
<
head_num
;
++
head_idx
)
{
for
(
int64_t
token_idx
=
0
;
token_idx
<
token_num
;
++
token_idx
)
{
const
int64_t
pos
=
slot_mapping
[
token_idx
];
if
(
pos
<
0
)
continue
;
const
int64_t
block_idx
=
pos
/
block_size
;
const
int64_t
block_offset
=
pos
%
block_size
;
// Key cache: TokenColumn QK
{
const
c10
::
BFloat16
*
__restrict
key_src
=
key
+
token_idx
*
key_token_num_stride
+
head_idx
*
key_head_num_stride
;
c10
::
BFloat16
*
__restrict
key_base
=
key_cache
+
block_idx
*
num_blocks_stride
+
head_idx
*
cache_head_num_stride
;
const
int64_t
block_in_block
=
block_offset
/
K_TOKENS_PER_GROUP
;
const
int64_t
pair_in_block
=
(
block_offset
%
K_TOKENS_PER_GROUP
)
/
TILE_COLS
;
const
int64_t
lane_base
=
(
block_offset
&
1
)
?
TILE_K
:
0
;
c10
::
BFloat16
*
__restrict
block_base
=
key_base
+
block_in_block
*
k_block_stride
;
for
(
int64_t
hd4
=
0
;
hd4
<
head_dim
/
TILE_K
;
++
hd4
)
{
uint16_t
*
dst_u16
=
reinterpret_cast
<
uint16_t
*>
(
block_base
+
hd4
*
K_INNER_STRIDE
+
pair_in_block
*
V_INNER_STRIDE
+
lane_base
);
const
uint16_t
*
src_u16
=
reinterpret_cast
<
const
uint16_t
*>
(
key_src
+
hd4
*
TILE_K
);
vst1_u16
(
dst_u16
,
vld1_u16
(
src_u16
));
}
}
// Value cache: TokenRow PV
{
const
c10
::
BFloat16
*
__restrict
value_src
=
value
+
token_idx
*
value_token_num_stride
+
head_idx
*
value_head_num_stride
;
c10
::
BFloat16
*
__restrict
value_base
=
value_cache
+
block_idx
*
num_blocks_stride
+
head_idx
*
cache_head_num_stride
;
const
int64_t
row_block
=
block_offset
/
V_TOKENS_PER_ROW_BLOCK
;
const
int64_t
lane
=
block_offset
&
(
V_TOKENS_PER_ROW_BLOCK
-
1
);
c10
::
BFloat16
*
__restrict
row_block_base
=
value_base
+
row_block
*
V_INNER_STRIDE
;
for
(
int64_t
hd2
=
0
;
hd2
<
head_dim
/
TILE_COLS
;
++
hd2
)
{
c10
::
BFloat16
*
__restrict
dst_val
=
row_block_base
+
hd2
*
v_pair_stride
;
const
uint16_t
*
src_u16
=
reinterpret_cast
<
const
uint16_t
*>
(
value_src
);
uint16_t
*
dst_u16
=
reinterpret_cast
<
uint16_t
*>
(
dst_val
);
dst_u16
[
lane
]
=
src_u16
[
hd2
*
TILE_COLS
+
0
];
dst_u16
[
lane
+
V_TOKENS_PER_ROW_BLOCK
]
=
src_u16
[
hd2
*
TILE_COLS
+
1
];
}
}
}
}
}
};
}
// namespace cpu_attention
#endif // CPU_ATTN_ASIMD_BFMMLA_HPP
vllm/v1/attention/backends/cpu_attn.py
View file @
1363e3d6
...
@@ -487,10 +487,12 @@ def _get_attn_isa(
...
@@ -487,10 +487,12 @@ def _get_attn_isa(
if
head_size
is
not
None
and
head_size
%
32
!=
0
and
head_size
%
16
==
0
:
if
head_size
is
not
None
and
head_size
%
32
!=
0
and
head_size
%
16
==
0
:
return
"vec16"
return
"vec16"
supports_amx
=
torch
.
_C
.
_cpu
.
_is_amx_tile_supported
()
supports_amx
=
torch
.
_C
.
_cpu
.
_is_amx_tile_supported
()
supports_arm
=
current_platform
.
get_cpu_architecture
()
==
CpuArchEnum
.
ARM
if
supports_amx
and
dtype
in
(
torch
.
bfloat16
,)
and
block_size
%
32
==
0
:
if
supports_amx
and
dtype
in
(
torch
.
bfloat16
,)
and
block_size
%
32
==
0
:
return
"amx"
return
"amx"
elif
block_size
%
32
==
0
:
elif
block_size
%
32
==
0
:
if
current_platform
.
get_cpu_architecture
()
==
CpuArchEnum
.
ARM
:
if
supports_arm
:
# support ARM NEON FMLA and BFMMLA (bf16) for block size 32
return
"neon"
return
"neon"
else
:
else
:
return
"vec"
return
"vec"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment