Unverified Commit 1363e3d6 authored by Gassan Salama, committed by GitHub
Browse files

[cpu][performance] CPU Paged Attention NEON BFMMLA BF16 Implementation (#32263)


Signed-off-by: Gassan <gassan.salama@arm.com>
parent 96552566
......@@ -1107,7 +1107,8 @@ class AttentionMainLoop {
if (sliding_window_left != -1) {
pos = std::max(pos, curr_token_pos - sliding_window_left);
}
return pos;
// Clamp to tile end to avoid OOB when window starts past the tile
return std::min(pos, kv_tile_end_pos);
}();
int32_t right_kv_pos = [&]() {
......
......@@ -4,6 +4,9 @@
#include "cpu_attn_impl.hpp"
#include <arm_neon.h>
#include <type_traits>
#ifdef ARM_BF16_SUPPORT
#include "cpu_attn_neon_bfmmla.hpp"
#endif
namespace cpu_attention {
namespace {
......@@ -57,7 +60,7 @@ FORCE_INLINE void load_row8_B_as_f32<c10::BFloat16>(const c10::BFloat16* p,
#endif
}
// Mx8, with 1 <= M <= 8 , K streamed, unroll-by-4 with NEON FMLAs
// Mx8, with 1 <= M <= 8 , K streamed, unroll-by-4 with ASIMD FMLAs
// #Loads = (K // 4) * (M + 4 * sizeof(kv_cache_t) / 2)
// #FMLAs = (K // 4) * (4 * 2 * M)
// We have (4 * 2 * M) FMLAs for (M + 4 * sizeof(kv_cache_t) / 2) loads
......@@ -381,6 +384,18 @@ class AttentionImpl<ISA::NEON, scalar_t, head_dim> {
}
}
};
#ifdef ARM_BF16_SUPPORT
// Partial specialization of AttentionImpl for ISA::NEON with c10::BFloat16
// elements; head_dim remains a free template parameter.
// For BF16 on Arm, reuse the BFMMLA kernels with 32-token alignment.
// The class body is intentionally empty: everything is inherited publicly
// from AttentionImplNEONBFMMLA, instantiated with BLOCK_SIZE_ALIGNMENT,
// ISA::NEON and head_dim.
// NOTE(review): BLOCK_SIZE_ALIGNMENT is a macro defined outside this view;
// presumably it is the 32-token alignment mentioned above -- confirm.
// Only compiled when ARM_BF16_SUPPORT is defined (see the enclosing #ifdef).
template <int64_t head_dim>
class AttentionImpl<ISA::NEON, c10::BFloat16, head_dim>
: public AttentionImplNEONBFMMLA<BLOCK_SIZE_ALIGNMENT, ISA::NEON,
head_dim> {};
#endif
} // namespace cpu_attention
#endif // #ifndef CPU_ATTN_NEON_HPP
#undef BLOCK_SIZE_ALIGNMENT
#undef HEAD_SIZE_ALIGNMENT
#undef MAX_Q_HEAD_NUM_PER_ITER
#endif // #ifndef CPU_ATTN_ASIMD_HPP
This diff is collapsed.
......@@ -487,10 +487,12 @@ def _get_attn_isa(
if head_size is not None and head_size % 32 != 0 and head_size % 16 == 0:
return "vec16"
supports_amx = torch._C._cpu._is_amx_tile_supported()
supports_arm = current_platform.get_cpu_architecture() == CpuArchEnum.ARM
if supports_amx and dtype in (torch.bfloat16,) and block_size % 32 == 0:
return "amx"
elif block_size % 32 == 0:
if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
if supports_arm:
# support ARM NEON FMLA and BFMMLA (bf16) for block size 32
return "neon"
else:
return "vec"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment