Unverified commit e69c990c, authored by Radu Salavat, committed by GitHub
Browse files

[Feature][CPU Backend]: Optimize ARM vectorization backend (#30329)


Signed-off-by: Radu Salavat <radu.salavat@arm.com>
parent 5eac9a1b
......@@ -816,14 +816,10 @@ struct VecTypeTrait<float> {
using vec_t = vec_op::FP32Vec16;
};
// ARM only supports BF16 with ARMv8.6-A extension
#if (defined(__aarch64__) && !defined(ARM_BF16_SUPPORT))
#else
// Explicit specialization of VecTypeTrait for c10::BFloat16: exposes
// vec_op::BF16Vec16 as the associated vector type `vec_t`.
template <>
struct VecTypeTrait<c10::BFloat16> {
using vec_t = vec_op::BF16Vec16;
};
#endif
#if !defined(__powerpc__) && !defined(__s390x__)
template <>
......@@ -1585,17 +1581,10 @@ class AttentionMainLoop {
if (use_sink) {
alignas(64) float s_aux_fp32[16];
#if defined(__aarch64__) && !defined(ARM_BF16_SUPPORT)
// ARM without native BF16 support: manual conversion
for (int i = 0; i < 16; ++i) {
s_aux_fp32[i] = static_cast<float>(curr_s_aux[i]);
}
#else
// All other platforms have BF16Vec16 available
vec_op::BF16Vec16 vec_bf16(curr_s_aux);
vec_op::FP32Vec16 vec_fp32(vec_bf16);
vec_fp32.save(s_aux_fp32);
#endif
float* __restrict__ curr_sum_buffer = sum_buffer;
float* __restrict__ curr_max_buffer = max_buffer;
......
This diff is collapsed.
......@@ -14,13 +14,11 @@ struct KernelVecType<float> {
using cvt_vec_type = vec_op::FP32Vec16;
};
#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT)
// Explicit specialization of KernelVecType for c10::BFloat16.
// NOTE(review): judging by the alias names, load_vec_type is the type used to
// load BF16 data and cvt_vec_type is the FP32 type it is converted to —
// confirm against the kernels that consume this trait.
template <>
struct KernelVecType<c10::BFloat16> {
using load_vec_type = vec_op::BF16Vec16;
using cvt_vec_type = vec_op::FP32Vec16;
};
#endif
template <>
struct KernelVecType<c10::Half> {
......
......@@ -38,9 +38,7 @@ struct KernelVecType<c10::BFloat16> {
using qk_vec_type = vec_op::BF16Vec32;
using v_load_vec_type = vec_op::BF16Vec16;
};
#elif defined(__aarch64__) && !defined(ARM_BF16_SUPPORT)
// pass
#else
#elif defined(__aarch64__)
template <>
struct KernelVecType<c10::BFloat16> {
using qk_load_vec_type = vec_op::BF16Vec16;
......
......@@ -30,12 +30,10 @@ struct VecTypeTrait<float> {
using vec_t = vec_op::FP32Vec16;
};
#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT)
// Explicit specialization of VecTypeTrait for c10::BFloat16: exposes
// vec_op::BF16Vec16 as the associated vector type `vec_t`.
template <>
struct VecTypeTrait<c10::BFloat16> {
using vec_t = vec_op::BF16Vec16;
};
#endif
#if !defined(__powerpc__)
template <>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment