Unverified commit 1363e3d6, authored by Gassan Salama, committed by GitHub
Browse files

[cpu][performance] CPU Paged Attention NEON BFMMLA BF16 Implementation (#32263)


Signed-off-by: Gassan <gassan.salama@arm.com>
parent 96552566
...@@ -1107,7 +1107,8 @@ class AttentionMainLoop { ...@@ -1107,7 +1107,8 @@ class AttentionMainLoop {
if (sliding_window_left != -1) { if (sliding_window_left != -1) {
pos = std::max(pos, curr_token_pos - sliding_window_left); pos = std::max(pos, curr_token_pos - sliding_window_left);
} }
return pos; // Clamp to tile end to avoid OOB when window starts past the tile
return std::min(pos, kv_tile_end_pos);
}(); }();
int32_t right_kv_pos = [&]() { int32_t right_kv_pos = [&]() {
......
...@@ -4,6 +4,9 @@ ...@@ -4,6 +4,9 @@
#include "cpu_attn_impl.hpp" #include "cpu_attn_impl.hpp"
#include <arm_neon.h> #include <arm_neon.h>
#include <type_traits> #include <type_traits>
#ifdef ARM_BF16_SUPPORT
#include "cpu_attn_neon_bfmmla.hpp"
#endif
namespace cpu_attention { namespace cpu_attention {
namespace { namespace {
...@@ -57,7 +60,7 @@ FORCE_INLINE void load_row8_B_as_f32<c10::BFloat16>(const c10::BFloat16* p, ...@@ -57,7 +60,7 @@ FORCE_INLINE void load_row8_B_as_f32<c10::BFloat16>(const c10::BFloat16* p,
#endif #endif
} }
// Mx8, with 1 <= M <= 8 , K streamed, unroll-by-4 with NEON FMLAs // Mx8, with 1 <= M <= 8 , K streamed, unroll-by-4 with ASIMD FMLAs
// #Loads = (K // 4) * (M + 4 * sizeof(kv_cache_t) / 2) // #Loads = (K // 4) * (M + 4 * sizeof(kv_cache_t) / 2)
// #FMLAs = (K // 4) * (4 * 2 * M) // #FMLAs = (K // 4) * (4 * 2 * M)
// We have (4 * 2 * M) FMLAs for (M + 4 * sizeof(kv_cache_t) / 2) loads // We have (4 * 2 * M) FMLAs for (M + 4 * sizeof(kv_cache_t) / 2) loads
...@@ -381,6 +384,18 @@ class AttentionImpl<ISA::NEON, scalar_t, head_dim> { ...@@ -381,6 +384,18 @@ class AttentionImpl<ISA::NEON, scalar_t, head_dim> {
} }
} }
}; };
#ifdef ARM_BF16_SUPPORT
// For BF16 on Arm, reuse the BFMMLA kernels with 32-token alignment.
//
// Partial specialization of AttentionImpl for ISA::NEON with
// c10::BFloat16, left generic over head_dim. It is compiled only when
// ARM_BF16_SUPPORT is defined. The body is empty: all behavior is inherited
// from AttentionImplNEONBFMMLA, instantiated with BLOCK_SIZE_ALIGNMENT,
// ISA::NEON and head_dim.
// NOTE(review): AttentionImplNEONBFMMLA is presumably declared in
// "cpu_attn_neon_bfmmla.hpp", which is included under the same guard --
// confirm. The "32-token" figure assumes BLOCK_SIZE_ALIGNMENT == 32; the
// macro's value is not visible here -- TODO confirm.
template <int64_t head_dim>
class AttentionImpl<ISA::NEON, c10::BFloat16, head_dim>
: public AttentionImplNEONBFMMLA<BLOCK_SIZE_ALIGNMENT, ISA::NEON,
head_dim> {};
#endif
} // namespace cpu_attention } // namespace cpu_attention
#endif // #ifndef CPU_ATTN_NEON_HPP #undef BLOCK_SIZE_ALIGNMENT
#undef HEAD_SIZE_ALIGNMENT
#undef MAX_Q_HEAD_NUM_PER_ITER
#endif // #ifndef CPU_ATTN_ASIMD_HPP
This diff is collapsed.
...@@ -487,10 +487,12 @@ def _get_attn_isa( ...@@ -487,10 +487,12 @@ def _get_attn_isa(
if head_size is not None and head_size % 32 != 0 and head_size % 16 == 0: if head_size is not None and head_size % 32 != 0 and head_size % 16 == 0:
return "vec16" return "vec16"
supports_amx = torch._C._cpu._is_amx_tile_supported() supports_amx = torch._C._cpu._is_amx_tile_supported()
supports_arm = current_platform.get_cpu_architecture() == CpuArchEnum.ARM
if supports_amx and dtype in (torch.bfloat16,) and block_size % 32 == 0: if supports_amx and dtype in (torch.bfloat16,) and block_size % 32 == 0:
return "amx" return "amx"
elif block_size % 32 == 0: elif block_size % 32 == 0:
if current_platform.get_cpu_architecture() == CpuArchEnum.ARM: if supports_arm:
# support ARM NEON FMLA and BFMMLA (bf16) for block size 32
return "neon" return "neon"
else: else:
return "vec" return "vec"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment