Unverified Commit e69c990c authored by Radu Salavat's avatar Radu Salavat Committed by GitHub
Browse files

[Feature][CPU Backend]: Optimize ARM vectorization backend (#30329)


Signed-off-by: default avatarRadu Salavat <radu.salavat@arm.com>
parent 5eac9a1b
......@@ -816,14 +816,10 @@ struct VecTypeTrait<float> {
using vec_t = vec_op::FP32Vec16;
};
// ARM only supports BF16 with ARMv8.6-A extension
#if (defined(__aarch64__) && !defined(ARM_BF16_SUPPORT))
#else
template <>
struct VecTypeTrait<c10::BFloat16> {
using vec_t = vec_op::BF16Vec16;
};
#endif
#if !defined(__powerpc__) && !defined(__s390x__)
template <>
......@@ -1585,17 +1581,10 @@ class AttentionMainLoop {
if (use_sink) {
alignas(64) float s_aux_fp32[16];
#if defined(__aarch64__) && !defined(ARM_BF16_SUPPORT)
// ARM without native BF16 support: manual conversion
for (int i = 0; i < 16; ++i) {
s_aux_fp32[i] = static_cast<float>(curr_s_aux[i]);
}
#else
// All other platforms have BF16Vec16 available
vec_op::BF16Vec16 vec_bf16(curr_s_aux);
vec_op::FP32Vec16 vec_fp32(vec_bf16);
vec_fp32.save(s_aux_fp32);
#endif
float* __restrict__ curr_sum_buffer = sum_buffer;
float* __restrict__ curr_max_buffer = max_buffer;
......
#include <cmath>
#include <type_traits>
#include <arm_neon.h>
#include <torch/all.h>
#include <cmath>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#if defined(__APPLE__)
#include "omp.h"
#endif
using namespace at::vec;
namespace vec_op {
#ifdef ARM_BF16_SUPPORT
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#else
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
#endif
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
......@@ -45,667 +46,632 @@ constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
template <typename T, T count, typename F,
typename = std::enable_if_t<std::is_invocable_v<F, T>>>
constexpr void unroll_loop(F&& f) {
inline constexpr void unroll_loop(F&& f) {
unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
}
template <typename T>
struct Vec {
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; };
};
template <typename T, typename... Ts>
struct is_one_of : std::bool_constant<(std::is_same_v<T, Ts> || ...)> {};
struct FP32Vec8;
struct FP32Vec16;
template <typename T, typename... Ts>
inline constexpr bool is_one_of_v = is_one_of<T, Ts...>::value;
struct FP16Vec8 : public Vec<FP16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
struct uninit_t {
explicit constexpr uninit_t() = default;
};
inline constexpr uninit_t uninit{};
float16x8_t reg;
template <typename NxVectorizedTVecReg, typename T, int VEC_ELEM_NUM>
union AliasReg {
NxVectorizedTVecReg reg;
T values[VEC_ELEM_NUM];
};
explicit FP16Vec8(const void* ptr)
: reg(vld1q_f16(static_cast<const __fp16*>(ptr))) {};
// Template over at::vec::Vectorized<T> to support
// multiple vectorised registers into 1 of length VEC_REG_NUM val
template <int N, typename T>
struct NxVectorizedTVecReg {
using value_t = T;
using VectorizedT = Vectorized<T>;
explicit FP16Vec8(const FP32Vec8&);
VectorizedT val[N];
void save(void* ptr) const { vst1q_f16(static_cast<__fp16*>(ptr), reg); }
};
NxVectorizedTVecReg() = default;
NxVectorizedTVecReg(const NxVectorizedTVecReg&) = default;
NxVectorizedTVecReg(NxVectorizedTVecReg&&) = default;
NxVectorizedTVecReg& operator=(const NxVectorizedTVecReg&) = default;
NxVectorizedTVecReg& operator=(NxVectorizedTVecReg&&) = default;
struct FP16Vec16 : public Vec<FP16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
explicit NxVectorizedTVecReg(uninit_t) noexcept {};
float16x8x2_t reg;
FORCE_INLINE explicit NxVectorizedTVecReg(const VectorizedT& vec_t) {
unroll_loop<int, N>([&](int i) { val[i] = vec_t; });
};
explicit FP16Vec16(const void* ptr) {
reg.val[0] = vld1q_f16(reinterpret_cast<const __fp16*>(ptr));
reg.val[1] = vld1q_f16(reinterpret_cast<const __fp16*>(ptr) + 8);
FORCE_INLINE explicit NxVectorizedTVecReg(T v) noexcept {
VectorizedT vv(v);
unroll_loop<int, N>([&](int i) { val[i] = vv; });
}
// ASIMD does not support non-temporal loads
explicit FP16Vec16(bool, const void* ptr) : FP16Vec16(ptr) {}
FORCE_INLINE explicit NxVectorizedTVecReg(const void* ptr) { load(ptr); }
explicit NxVectorizedTVecReg(const void* ptr, const int elem_num) {
load(ptr, elem_num);
}
explicit FP16Vec16(const FP32Vec16& vec);
void save(void* ptr) const {
vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]);
static constexpr int size() noexcept { return N * VectorizedT::size(); }
FORCE_INLINE void save(void* ptr) const {
value_t* base = reinterpret_cast<value_t*>(ptr);
unroll_loop<int, N>(
[&](int i) { val[i].store(base + i * VectorizedT::size()); });
}
FORCE_INLINE void load(const void* ptr) {
const value_t* base = reinterpret_cast<const value_t*>(ptr);
unroll_loop<int, N>([&](int i) {
val[i] = VectorizedT::loadu(base + i * VectorizedT::size());
});
}
void save(void* ptr, const int elem_num) const {
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
FORCE_INLINE void save(void* ptr, const int elem_num) const {
value_t* base = reinterpret_cast<value_t*>(ptr);
save_partial(base, elem_num);
}
if (full_blocks > 0) {
vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
if (full_blocks > 1) {
vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]);
}
}
FORCE_INLINE void load(const void* ptr, const int elem_num) {
const value_t* base = reinterpret_cast<const value_t*>(ptr);
load_partial(base, elem_num);
}
// Note: below is the unrolled version of the following code:
//
// for (int i = 0; i < remainder; ++i) {
// reinterpret_cast<__fp16*>(ptr)[full_blocks * 8 + i] =
// vgetq_lane_f16(temp, i);
// }
//
// For macOS build (Clang), the arm/neon intrinsics function
// `vgetq_lane_f16` needs the parameter `i` to be constant at compile
// time.
FORCE_INLINE void save_partial(value_t* base, int elem_num) const {
const int w = VectorizedT::size();
int full = elem_num / w;
int rem = elem_num % w;
for (int i = 0; i < full; i++) val[i].store(base + i * w);
if (rem) val[full].store(base + full * w, rem);
}
if (remainder > 0) {
float16x8_t temp = reg.val[full_blocks];
__fp16* fp16_ptr = reinterpret_cast<__fp16*>(ptr);
switch (remainder) {
case 1:
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
break;
case 2:
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
break;
case 3:
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
break;
case 4:
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
break;
case 5:
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4);
break;
case 6:
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4);
fp16_ptr[full_blocks * 8 + 5] = vgetq_lane_f16(temp, 5);
break;
case 7:
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4);
fp16_ptr[full_blocks * 8 + 5] = vgetq_lane_f16(temp, 5);
fp16_ptr[full_blocks * 8 + 6] = vgetq_lane_f16(temp, 6);
break;
FORCE_INLINE void load_partial(const value_t* base, int elem_num) {
const int w = VectorizedT::size();
int full = elem_num / w;
int rem = elem_num % w;
for (int i = 0; i < full; i++) val[i] = VectorizedT::loadu(base + i * w);
if (rem) val[full] = VectorizedT::loadu(base + full * w, rem);
}
default:
break;
template <VectorizedT (VectorizedT::*torch_vec_func)() const,
value_t (*std_func)(value_t)>
FORCE_INLINE NxVectorizedTVecReg opt_vec_func_impl() const {
NxVectorizedTVecReg result;
if constexpr (torch_vec_func != nullptr) {
unroll_loop<int, N>(
[&](int i) { result.val[i] = (val[i].*torch_vec_func)(); });
} else {
for (int i = 0; i < N; i++) {
alignas(64) value_t buf[VectorizedT::size()];
val[i].store(buf);
for (int j = 0; j < VectorizedT::size(); ++j) {
buf[j] = std_func(buf[j]);
}
result.val[i] = VectorizedT::loadu(buf);
}
}
return result;
}
};
#ifdef ARM_BF16_SUPPORT
struct BF16Vec8 : public Vec<BF16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
template <typename DerivedClassT, int N, typename T>
struct VectorizedRegWrapper {
using ScalarT = T;
using VectorizedT = Vectorized<T>;
using NxVectorizedTArray = NxVectorizedTVecReg<N, T>;
constexpr static int VEC_REG_NUM = N;
constexpr static int VEC_ELEM_NUM = VEC_REG_NUM * VectorizedT::size();
constexpr static int get_elem_num() { return VEC_ELEM_NUM; };
NxVectorizedTArray reg;
VectorizedRegWrapper() noexcept = default;
explicit VectorizedRegWrapper(uninit_t) noexcept : reg{uninit} {};
explicit VectorizedRegWrapper(T v) : reg(v) {};
explicit VectorizedRegWrapper(const void* ptr) : reg(ptr) {};
explicit VectorizedRegWrapper(const void* ptr, const int elem_num)
: reg(ptr, elem_num) {};
explicit VectorizedRegWrapper(const VectorizedT& r) : reg(r) {};
explicit VectorizedRegWrapper(const NxVectorizedTArray& r) : reg(r) {};
VectorizedRegWrapper(const VectorizedRegWrapper&) = default;
VectorizedRegWrapper(VectorizedRegWrapper&&) = default;
VectorizedRegWrapper& operator=(VectorizedRegWrapper&&) = default;
VectorizedRegWrapper& operator=(const VectorizedRegWrapper&) = default;
FORCE_INLINE void save(void* ptr) const { reg.save(ptr); }
void save(void* ptr, const int elem_num) const { reg.save(ptr, elem_num); }
// Define optimized functions using at::vec::Vectorized<T> where possible
// Fallback to std:: functions when not available
#define OPT_TORCH_IMPL(FUNC_NAME, STD_FUNC_NAME, TORCH_FUNC_NAME, ...) \
FORCE_INLINE DerivedClassT FUNC_NAME() const { \
if constexpr (is_one_of_v<T, __VA_ARGS__>) { \
return DerivedClassT{ \
reg.template opt_vec_func_impl<&VectorizedT::TORCH_FUNC_NAME, \
std::STD_FUNC_NAME>()}; \
} else { \
return DerivedClassT{reg.template opt_vec_func_impl< \
nullptr, static_cast<ScalarT (*)(ScalarT)>(&std::STD_FUNC_NAME)>()}; \
} \
}
bfloat16x8_t reg;
// Define optimized functions for datatypes passed in __VA_ARGS__
OPT_TORCH_IMPL(abs, abs, abs, c10::Half, float)
OPT_TORCH_IMPL(er, erf, erf, float)
OPT_TORCH_IMPL(exp, exp, fexp_u20, float)
OPT_TORCH_IMPL(exp_u20, exp, exp_u20, float)
OPT_TORCH_IMPL(sin, sin, sin, float)
OPT_TORCH_IMPL(sinh, sinh, sinh, float)
OPT_TORCH_IMPL(cos, cos, cos, float)
OPT_TORCH_IMPL(cosh, cosh, cosh, float)
OPT_TORCH_IMPL(log, log, log, float)
OPT_TORCH_IMPL(log10, log10, log10, float)
OPT_TORCH_IMPL(sqrt, sqrt, sqrt, c10::Half, float)
OPT_TORCH_IMPL(tan, tan, tan, float)
OPT_TORCH_IMPL(tanh, tanh, tanh, float)
#undef OPT_TORCH_IMPL
};
explicit BF16Vec8(const void* ptr)
: reg(*reinterpret_cast<const bfloat16x8_t*>(ptr)) {};
// forward declare vectorised dtypes
struct FP32Vec8;
struct FP32Vec16;
struct FP16Vec8;
struct FP16Vec16;
struct BF16Vec8;
struct BF16Vec16;
explicit BF16Vec8(bfloat16x8_t data) : reg(data) {};
struct INT8Vec16;
struct INT32Vec16;
explicit BF16Vec8(const FP32Vec8&);
template <typename T>
struct VecType {
using vec_type = void;
};
explicit BF16Vec8(float32x4x2_t v)
: reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1])) {};
template <typename T>
using vec_t = typename VecType<T>::vec_type;
void save(void* ptr) const { *reinterpret_cast<bfloat16x8_t*>(ptr) = reg; }
template <>
struct VecType<float> {
using vec_type = FP32Vec8;
};
struct BF16Vec16 : public Vec<BF16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
bfloat16x8x2_t reg;
template <>
struct VecType<c10::Half> {
using vec_type = FP16Vec8;
};
explicit BF16Vec16(const void* ptr)
: reg(*reinterpret_cast<const bfloat16x8x2_t*>(ptr)) {};
template <>
struct VecType<c10::BFloat16> {
using vec_type = BF16Vec8;
};
// ASIMD does not support non-temporal loads
explicit BF16Vec16(bool, const void* ptr) : BF16Vec16(ptr) {}
struct FP16Vec8 : public VectorizedRegWrapper<FP16Vec8, 1, c10::Half> {
using Base = VectorizedRegWrapper<FP16Vec8, 1, c10::Half>;
using Base::Base;
using Base::get_elem_num;
using Base::VEC_ELEM_NUM;
explicit BF16Vec16(bfloat16x8x2_t data) : reg(data) {};
explicit FP16Vec8(const FP32Vec8&);
};
explicit BF16Vec16(const FP32Vec16&);
struct FP16Vec16 : public VectorizedRegWrapper<FP16Vec16, 2, c10::Half> {
using Base = VectorizedRegWrapper<FP16Vec16, 2, c10::Half>;
using Base::Base;
using Base::get_elem_num;
using Base::VEC_ELEM_NUM;
explicit BF16Vec16(float32x4x4_t v)
: reg({vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1]),
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3])}) {};
// ASIMD does not support non-temporal loads
explicit FP16Vec16(bool, const void* ptr) : Base(ptr) {}
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x2_t*>(ptr) = reg; };
void save(void* ptr, const int elem_num) const {
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
for (int i = 0; i < full_blocks; i++)
vst1q_bf16(
reinterpret_cast<__bf16*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
reg.val[i]);
if (remainder > 0) {
bfloat16x8_t temp = reg.val[full_blocks];
bfloat16_t* base = reinterpret_cast<bfloat16_t*>(ptr) + full_blocks * 8;
if (remainder > 0) base[0] = vgetq_lane_bf16(temp, 0);
if (remainder > 1) base[1] = vgetq_lane_bf16(temp, 1);
if (remainder > 2) base[2] = vgetq_lane_bf16(temp, 2);
if (remainder > 3) base[3] = vgetq_lane_bf16(temp, 3);
if (remainder > 4) base[4] = vgetq_lane_bf16(temp, 4);
if (remainder > 5) base[5] = vgetq_lane_bf16(temp, 5);
if (remainder > 6) base[6] = vgetq_lane_bf16(temp, 6);
}
};
explicit FP16Vec16(const FP32Vec16& vec);
};
struct BF16Vec32 : public Vec<BF16Vec32> {
constexpr static int VEC_ELEM_NUM = 32;
struct BF16Vec8 : public VectorizedRegWrapper<BF16Vec8, 1, c10::BFloat16> {
using Base = VectorizedRegWrapper<BF16Vec8, 1, c10::BFloat16>;
using VectorizedT = typename Base::VectorizedT;
using Base::Base;
using Base::get_elem_num;
using Base::VEC_ELEM_NUM;
bfloat16x8x4_t reg;
explicit BF16Vec8(at_bfloat16x8_t data) : Base(VectorizedT(data)) {};
explicit BF16Vec32(const void* ptr)
: reg(*reinterpret_cast<const bfloat16x8x4_t*>(ptr)) {};
explicit BF16Vec8(float32x4x2_t v) {
reg.val[0] = convert_float_bfloat16(v.val[0], v.val[1]);
};
explicit BF16Vec32(bfloat16x8x4_t data) : reg(data) {};
explicit BF16Vec8(const FP32Vec8&);
};
explicit BF16Vec32(const BF16Vec8& vec8_data)
: reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {};
struct BF16Vec16 : public VectorizedRegWrapper<BF16Vec16, 2, c10::BFloat16> {
using Base = VectorizedRegWrapper<BF16Vec16, 2, c10::BFloat16>;
using VectorizedT = typename Base::VectorizedT;
using Base::Base;
using Base::get_elem_num;
using Base::VEC_ELEM_NUM;
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x4_t*>(ptr) = reg; };
void save(void* ptr, const int elem_num) const {
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
for (int i = 0; i < full_blocks; i++)
vst1q_bf16(
reinterpret_cast<__bf16*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
reg.val[i]);
if (remainder > 0) {
bfloat16x8_t temp = reg.val[full_blocks];
bfloat16_t* base = reinterpret_cast<bfloat16_t*>(ptr) + full_blocks * 8;
base[0] = vgetq_lane_bf16(temp, 0);
if (remainder > 1) base[1] = vgetq_lane_bf16(temp, 1);
if (remainder > 2) base[2] = vgetq_lane_bf16(temp, 2);
if (remainder > 3) base[3] = vgetq_lane_bf16(temp, 3);
if (remainder > 4) base[4] = vgetq_lane_bf16(temp, 4);
if (remainder > 5) base[5] = vgetq_lane_bf16(temp, 5);
if (remainder > 6) base[6] = vgetq_lane_bf16(temp, 6);
}
// ASIMD does not support non-temporal loads
explicit BF16Vec16(bool, const void* ptr) : Base(ptr) {}
explicit BF16Vec16(float32x4x4_t v) {
reg.val[0] = convert_float_bfloat16(v.val[0], v.val[1]);
reg.val[1] = convert_float_bfloat16(v.val[2], v.val[3]);
};
};
#endif
struct FP32Vec4 : public Vec<FP32Vec4> {
constexpr static int VEC_ELEM_NUM = 4;
explicit BF16Vec16(const FP32Vec16&);
};
union AliasReg {
float32x4_t reg;
float values[VEC_ELEM_NUM];
struct BF16Vec32 : public VectorizedRegWrapper<BF16Vec32, 4, c10::BFloat16> {
using Base = VectorizedRegWrapper<BF16Vec32, 4, c10::BFloat16>;
using Base::Base;
using Base::get_elem_num;
using Base::VEC_ELEM_NUM;
explicit BF16Vec32(const BF16Vec8& vec8_data) {
reg.val[0] = vec8_data.reg.val[0];
reg.val[1] = vec8_data.reg.val[0];
reg.val[2] = vec8_data.reg.val[0];
reg.val[3] = vec8_data.reg.val[0];
};
};
float32x4_t reg;
explicit FP32Vec4(float v) : reg(vdupq_n_f32(v)) {};
struct FP32Vec4 : public VectorizedRegWrapper<FP32Vec4, 1, float> {
using Base = VectorizedRegWrapper<FP32Vec4, 1, float>;
using Base::Base;
using Base::get_elem_num;
using Base::VEC_ELEM_NUM;
explicit FP32Vec4() : reg(vdupq_n_f32(0.0f)) {};
using VectorizedT = typename Base::VectorizedT;
using Vectorized1x4f = typename Base::NxVectorizedTArray;
explicit FP32Vec4(const float* ptr) : reg(vld1q_f32(ptr)) {};
FP32Vec4() : Base() {};
explicit FP32Vec4(float v) : Base(v) {};
explicit FP32Vec4(float32x4_t data) : reg(data) {};
explicit FP32Vec4(float32x4_t data) : Base(VectorizedT(data)) {};
explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {};
explicit FP32Vec4(const FP32Vec4& data) : Base(data) {};
};
struct FP32Vec8 : public Vec<FP32Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
union AliasReg {
float32x4x2_t reg;
float values[VEC_ELEM_NUM];
};
float32x4x2_t reg;
struct FP32Vec8 : public VectorizedRegWrapper<FP32Vec8, 2, float> {
using Base = VectorizedRegWrapper<FP32Vec8, 2, float>;
using Base::Base;
using Base::get_elem_num;
using Base::VEC_ELEM_NUM;
using Base::VEC_REG_NUM;
explicit FP32Vec8(float v) : reg({vmovq_n_f32(v), vmovq_n_f32(v)}) {};
using VectorizedT = typename Base::VectorizedT;
using Vectorized2x4f = typename Base::NxVectorizedTArray;
explicit FP32Vec8() : reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0)}) {};
FP32Vec8() : Base() {};
FP32Vec8(const FP32Vec8& data) : Base(data) {};
explicit FP32Vec8(float v) : Base(v) {};
explicit FP32Vec8(const float* ptr)
: reg({vld1q_f32(ptr), vld1q_f32(ptr + 4)}) {};
: Base(reinterpret_cast<const void*>(ptr)) {};
explicit FP32Vec8(const float* ptr, const int elem_num)
: Base(reinterpret_cast<const void*>(ptr), elem_num) {};
explicit FP32Vec8(float32x4x2_t data) : reg(data) {};
explicit FP32Vec8(const FP32Vec8& data) : reg(data.reg) {};
explicit FP32Vec8(const Vectorized2x4f& data) {
reg.val[0] = data.val[0];
reg.val[1] = data.val[1];
};
explicit FP32Vec8(const BF16Vec8& v) {
std::tie(reg.val[0], reg.val[1]) = convert_bfloat16_float(v.reg.val[0]);
};
explicit FP32Vec8(const FP16Vec8& v) {
reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg));
reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg));
reg.val[0] = Vectorized<float>(vcvt_f32_f16(vget_low_f16(v.reg.val[0])));
reg.val[1] = Vectorized<float>(vcvt_f32_f16(vget_high_f16(v.reg.val[0])));
};
explicit FP32Vec8(float16x8_t v) {
reg.val[0] = Vectorized<float>(vcvt_f32_f16(vget_low_f16(v)));
reg.val[1] = Vectorized<float>(vcvt_f32_f16(vget_high_f16(v)));
};
explicit FP32Vec8(at_bfloat16x8_t v) {
std::tie(reg.val[0], reg.val[1]) =
convert_bfloat16_float(Vectorized<c10::BFloat16>(v));
};
explicit FP32Vec8(float32x4x2_t data) {
reg.val[0] = Vectorized<float>(data.val[0]);
reg.val[1] = Vectorized<float>(data.val[1]);
}
explicit FP32Vec8(float16x8_t v)
: reg({vcvt_f32_f16(vget_low_f16(v)), vcvt_f32_f16(vget_high_f16(v))}) {};
#ifdef ARM_BF16_SUPPORT
explicit FP32Vec8(bfloat16x8_t v)
: reg({vcvtq_low_f32_bf16(v), vcvtq_high_f32_bf16(v)}) {};
explicit FP32Vec8(const BF16Vec8& v)
: reg({vcvtq_low_f32_bf16(v.reg), vcvtq_high_f32_bf16(v.reg)}) {};
#endif
float reduce_sum() const {
AliasReg ar;
ar.reg = reg;
FORCE_INLINE float reduce_sum() const noexcept {
float answer = 0;
unroll_loop<int, VEC_ELEM_NUM>(
[&answer, &ar](int i) { answer += ar.values[i]; });
std::plus<VectorizedT> add;
unroll_loop<int, VEC_REG_NUM>([&](int i) {
answer += at::vec::vec_reduce_all<float, std::plus<VectorizedT>>(
add, reg.val[i]);
});
return answer;
}
FP32Vec8 exp() const {
AliasReg ar;
ar.reg = reg;
float32x2_t exp_vec0 = {expf(ar.values[0]), expf(ar.values[1])};
float32x2_t exp_vec1 = {expf(ar.values[2]), expf(ar.values[3])};
float32x2_t exp_vec2 = {expf(ar.values[4]), expf(ar.values[5])};
float32x2_t exp_vec3 = {expf(ar.values[6]), expf(ar.values[7])};
float32x4_t result0 = vcombine_f32(exp_vec0, exp_vec1);
float32x4_t result1 = vcombine_f32(exp_vec2, exp_vec3);
float32x4x2_t result;
result.val[0] = result0;
result.val[1] = result1;
return FP32Vec8(result);
FORCE_INLINE FP32Vec8 operator+(const FP32Vec8& b) const noexcept {
FP32Vec8 r(uninit);
r.reg.val[0] = reg.val[0] + b.reg.val[0];
r.reg.val[1] = reg.val[1] + b.reg.val[1];
return r;
}
FP32Vec8 tanh() const {
AliasReg ar;
ar.reg = reg;
float32x2_t tanh_vec0 = {tanhf(ar.values[0]), tanhf(ar.values[1])};
float32x2_t tanh_vec1 = {tanhf(ar.values[2]), tanhf(ar.values[3])};
float32x2_t tanh_vec2 = {tanhf(ar.values[4]), tanhf(ar.values[5])};
float32x2_t tanh_vec3 = {tanhf(ar.values[6]), tanhf(ar.values[7])};
float32x4_t result0 = vcombine_f32(tanh_vec0, tanh_vec1);
float32x4_t result1 = vcombine_f32(tanh_vec2, tanh_vec3);
float32x4x2_t result;
result.val[0] = result0;
result.val[1] = result1;
return FP32Vec8(result);
FORCE_INLINE FP32Vec8 operator-(const FP32Vec8& b) const noexcept {
FP32Vec8 r(uninit);
r.reg.val[0] = reg.val[0] - b.reg.val[0];
r.reg.val[1] = reg.val[1] - b.reg.val[1];
return r;
}
FP32Vec8 er() const {
AliasReg ar;
ar.reg = reg;
float32x2_t er_vec0 = {static_cast<float32_t>(erf(ar.values[0])),
static_cast<float32_t>(erf(ar.values[1]))};
float32x2_t er_vec1 = {static_cast<float32_t>(erf(ar.values[2])),
static_cast<float32_t>(erf(ar.values[3]))};
float32x2_t er_vec2 = {static_cast<float32_t>(erf(ar.values[4])),
static_cast<float32_t>(erf(ar.values[5]))};
float32x2_t er_vec3 = {static_cast<float32_t>(erf(ar.values[6])),
static_cast<float32_t>(erf(ar.values[7]))};
float32x4_t result0 = vcombine_f32(er_vec0, er_vec1);
float32x4_t result1 = vcombine_f32(er_vec2, er_vec3);
float32x4x2_t result;
result.val[0] = result0;
result.val[1] = result1;
return FP32Vec8(result);
}
FP32Vec8 operator*(const FP32Vec8& b) const {
return FP32Vec8(float32x4x2_t({vmulq_f32(reg.val[0], b.reg.val[0]),
vmulq_f32(reg.val[1], b.reg.val[1])}));
}
FP32Vec8 operator+(const FP32Vec8& b) const {
return FP32Vec8(float32x4x2_t({vaddq_f32(reg.val[0], b.reg.val[0]),
vaddq_f32(reg.val[1], b.reg.val[1])}));
}
FP32Vec8 operator-(const FP32Vec8& b) const {
return FP32Vec8(float32x4x2_t({vsubq_f32(reg.val[0], b.reg.val[0]),
vsubq_f32(reg.val[1], b.reg.val[1])}));
}
FP32Vec8 operator/(const FP32Vec8& b) const {
return FP32Vec8(float32x4x2_t({vdivq_f32(reg.val[0], b.reg.val[0]),
vdivq_f32(reg.val[1], b.reg.val[1])}));
FORCE_INLINE FP32Vec8 operator*(const FP32Vec8& b) const noexcept {
FP32Vec8 r(uninit);
r.reg.val[0] = reg.val[0] * b.reg.val[0];
r.reg.val[1] = reg.val[1] * b.reg.val[1];
return r;
}
void save(float* ptr) const {
vst1q_f32(ptr, reg.val[0]);
vst1q_f32(ptr + 4, reg.val[1]);
FORCE_INLINE FP32Vec8 operator/(const FP32Vec8& b) const noexcept {
FP32Vec8 r(uninit);
r.reg.val[0] = reg.val[0] / b.reg.val[0];
r.reg.val[1] = reg.val[1] / b.reg.val[1];
return r;
}
};
struct INT32Vec16 : public Vec<INT32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
int32x4x4_t reg;
int32_t values[VEC_ELEM_NUM];
};
int32x4x4_t reg;
struct FP32Vec16 : public VectorizedRegWrapper<FP32Vec16, 4, float> {
using Base = VectorizedRegWrapper<FP32Vec16, 4, float>;
using Base::Base;
using Base::get_elem_num;
using Base::VEC_ELEM_NUM;
explicit INT32Vec16(const void* ptr) {
reg.val[0] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr));
reg.val[1] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 4);
reg.val[2] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 8);
reg.val[3] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 12);
}
using ScalarT = typename Base::ScalarT;
using VectorizedT = typename Base::VectorizedT;
using Vectorized4x4f = typename Base::NxVectorizedTArray;
void save(int32_t* ptr) const {
vst1q_s32(ptr, reg.val[0]);
vst1q_s32(ptr + 4, reg.val[1]);
vst1q_s32(ptr + 8, reg.val[2]);
vst1q_s32(ptr + 12, reg.val[3]);
FP32Vec16() : Base() {};
FP32Vec16(const FP32Vec16& data) : Base(data) {};
explicit FP32Vec16(float v) : Base(v) {};
explicit FP32Vec16(const float* ptr)
: Base(reinterpret_cast<const void*>(ptr)) {};
explicit FP32Vec16(const float* ptr, const int elem_num)
: Base(reinterpret_cast<const void*>(ptr), elem_num) {};
explicit FP32Vec16(const Vectorized4x4f& data) {
reg.val[0] = data.val[0];
reg.val[1] = data.val[1];
reg.val[2] = data.val[2];
reg.val[3] = data.val[3];
};
void save(int32_t* ptr, const int elem_num) const {
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
for (int i = 0; i < full_blocks; i++)
vst1q_s32(
reinterpret_cast<__int32_t*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
reg.val[i]);
if (remainder > 0) {
int32x4_t temp = reg.val[full_blocks];
int32_t* base = reinterpret_cast<int32_t*>(ptr) + full_blocks * 4;
if (remainder > 0) base[0] = vgetq_lane_s32(temp, 0);
if (remainder > 1) base[1] = vgetq_lane_s32(temp, 1);
if (remainder > 2) base[2] = vgetq_lane_s32(temp, 2);
if (remainder > 3) base[3] = vgetq_lane_s32(temp, 3);
}
}
};
// ASIMD does not support non-temporal loads
explicit FP32Vec16(bool, const float* ptr) : Base(ptr) {}
struct FP32Vec16 : public Vec<FP32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
float32x4x4_t reg;
float values[VEC_ELEM_NUM];
explicit FP32Vec16(float32x4x4_t data) {
reg.val[0] = data.val[0];
reg.val[1] = data.val[1];
reg.val[2] = data.val[2];
reg.val[3] = data.val[3];
};
float32x4x4_t reg;
explicit FP32Vec16(float v)
: reg({vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v)}) {}
explicit FP32Vec16()
: reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0), vmovq_n_f32(0.0),
vmovq_n_f32(0.0)}) {}
explicit FP32Vec16(const float* ptr)
: reg({vld1q_f32(ptr), vld1q_f32(ptr + 4), vld1q_f32(ptr + 8),
vld1q_f32(ptr + 12)}) {}
// ASIMD does not support non-temporal loads
explicit FP32Vec16(bool, const float* ptr) : FP32Vec16(ptr) {}
explicit FP32Vec16(float32x4x4_t data) : reg(data) {}
explicit FP32Vec16(const FP32Vec4& data) {
reg.val[0] = data.reg.val[0];
reg.val[1] = data.reg.val[0];
reg.val[2] = data.reg.val[0];
reg.val[3] = data.reg.val[0];
};
explicit FP32Vec16(const FP32Vec8& data) {
reg.val[0] = data.reg.val[0];
reg.val[1] = data.reg.val[1];
reg.val[2] = data.reg.val[0];
reg.val[3] = data.reg.val[1];
}
explicit FP32Vec16(const FP32Vec16& data) : reg(data.reg) {}
explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v.reg)) {}
#ifdef ARM_BF16_SUPPORT
explicit FP32Vec16(bfloat16x8x2_t v)
: reg({vcvtq_low_f32_bf16(v.val[0]), vcvtq_high_f32_bf16(v.val[0]),
vcvtq_low_f32_bf16(v.val[1]), vcvtq_high_f32_bf16(v.val[1])}) {};
#endif
explicit FP32Vec16(const FP32Vec4& data) {
reg.val[0] = data.reg;
reg.val[1] = data.reg;
reg.val[2] = data.reg;
reg.val[3] = data.reg;
};
#ifdef ARM_BF16_SUPPORT
explicit FP32Vec16(const BF16Vec16& v)
: reg({vcvtq_low_f32_bf16(v.reg.val[0]),
vcvtq_high_f32_bf16(v.reg.val[0]),
vcvtq_low_f32_bf16(v.reg.val[1]),
vcvtq_high_f32_bf16(v.reg.val[1])}) {};
explicit FP32Vec16(const BF16Vec16& v) {
std::tie(reg.val[0], reg.val[1]) = convert_bfloat16_float(v.reg.val[0]);
std::tie(reg.val[2], reg.val[3]) = convert_bfloat16_float(v.reg.val[1]);
};
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {};
#endif
explicit FP32Vec16(const FP16Vec16& v) {
reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg.val[0]));
reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg.val[0]));
reg.val[2] = vcvt_f32_f16(vget_low_f16(v.reg.val[1]));
reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1]));
};
explicit FP32Vec16(const INT32Vec16& v) {
reg.val[0] = vcvtq_f32_s32(v.reg.val[0]);
reg.val[1] = vcvtq_f32_s32(v.reg.val[1]);
reg.val[2] = vcvtq_f32_s32(v.reg.val[2]);
reg.val[3] = vcvtq_f32_s32(v.reg.val[3]);
};
FP32Vec16 operator+(const FP32Vec16& b) const {
return FP32Vec16(float32x4x4_t({vaddq_f32(reg.val[0], b.reg.val[0]),
vaddq_f32(reg.val[1], b.reg.val[1]),
vaddq_f32(reg.val[2], b.reg.val[2]),
vaddq_f32(reg.val[3], b.reg.val[3])}));
reg.val[0] = Vectorized<float>(vcvt_f32_f16(vget_low_f16(v.reg.val[0])));
reg.val[1] = Vectorized<float>(vcvt_f32_f16(vget_high_f16(v.reg.val[0])));
reg.val[2] = Vectorized<float>(vcvt_f32_f16(vget_low_f16(v.reg.val[1])));
reg.val[3] = Vectorized<float>(vcvt_f32_f16(vget_high_f16(v.reg.val[1])));
};
FP32Vec16 operator*(const FP32Vec16& b) const {
return FP32Vec16(float32x4x4_t({vmulq_f32(reg.val[0], b.reg.val[0]),
vmulq_f32(reg.val[1], b.reg.val[1]),
vmulq_f32(reg.val[2], b.reg.val[2]),
vmulq_f32(reg.val[3], b.reg.val[3])}));
};
FORCE_INLINE FP32Vec16 operator+(const FP32Vec16& b) const noexcept {
FP32Vec16 r(uninit);
r.reg.val[0] = reg.val[0] + b.reg.val[0];
r.reg.val[1] = reg.val[1] + b.reg.val[1];
r.reg.val[2] = reg.val[2] + b.reg.val[2];
r.reg.val[3] = reg.val[3] + b.reg.val[3];
return r;
}
FP32Vec16 operator-(const FP32Vec16& b) const {
return FP32Vec16(float32x4x4_t({vsubq_f32(reg.val[0], b.reg.val[0]),
vsubq_f32(reg.val[1], b.reg.val[1]),
vsubq_f32(reg.val[2], b.reg.val[2]),
vsubq_f32(reg.val[3], b.reg.val[3])}));
};
FORCE_INLINE FP32Vec16 operator-(const FP32Vec16& b) const noexcept {
FP32Vec16 r(uninit);
r.reg.val[0] = reg.val[0] - b.reg.val[0];
r.reg.val[1] = reg.val[1] - b.reg.val[1];
r.reg.val[2] = reg.val[2] - b.reg.val[2];
r.reg.val[3] = reg.val[3] - b.reg.val[3];
return r;
}
FP32Vec16 operator/(const FP32Vec16& b) const {
return FP32Vec16(float32x4x4_t({vdivq_f32(reg.val[0], b.reg.val[0]),
vdivq_f32(reg.val[1], b.reg.val[1]),
vdivq_f32(reg.val[2], b.reg.val[2]),
vdivq_f32(reg.val[3], b.reg.val[3])}));
FORCE_INLINE FP32Vec16 operator*(const FP32Vec16& b) const noexcept {
FP32Vec16 r(uninit);
r.reg.val[0] = reg.val[0] * b.reg.val[0];
r.reg.val[1] = reg.val[1] * b.reg.val[1];
r.reg.val[2] = reg.val[2] * b.reg.val[2];
r.reg.val[3] = reg.val[3] * b.reg.val[3];
return r;
}
FORCE_INLINE FP32Vec16 operator/(const FP32Vec16& b) const noexcept {
FP32Vec16 r(uninit);
r.reg.val[0] = reg.val[0] / b.reg.val[0];
r.reg.val[1] = reg.val[1] / b.reg.val[1];
r.reg.val[2] = reg.val[2] / b.reg.val[2];
r.reg.val[3] = reg.val[3] / b.reg.val[3];
return r;
}
FORCE_INLINE FP32Vec16 clamp(const FP32Vec16& min,
const FP32Vec16& max) const {
FP32Vec16 r(uninit);
r.reg.val[0] = at::vec::clamp(reg.val[0], min.reg.val[0], max.reg.val[0]);
r.reg.val[1] = at::vec::clamp(reg.val[1], min.reg.val[1], max.reg.val[1]);
r.reg.val[2] = at::vec::clamp(reg.val[2], min.reg.val[2], max.reg.val[2]);
r.reg.val[3] = at::vec::clamp(reg.val[3], min.reg.val[3], max.reg.val[3]);
return r;
};
FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const {
return FP32Vec16(float32x4x4_t(
{vminq_f32(max.reg.val[0], vmaxq_f32(min.reg.val[0], reg.val[0])),
vminq_f32(max.reg.val[1], vmaxq_f32(min.reg.val[1], reg.val[1])),
vminq_f32(max.reg.val[2], vmaxq_f32(min.reg.val[2], reg.val[2])),
vminq_f32(max.reg.val[3], vmaxq_f32(min.reg.val[3], reg.val[3]))}));
FORCE_INLINE FP32Vec16 min(const FP32Vec16& b) const {
FP32Vec16 r(uninit);
r.reg.val[0] = minimum(b.reg.val[0], reg.val[0]),
r.reg.val[1] = minimum(b.reg.val[1], reg.val[1]);
r.reg.val[2] = minimum(b.reg.val[2], reg.val[2]);
r.reg.val[3] = minimum(b.reg.val[3], reg.val[3]);
return r;
};
FP32Vec16 max(const FP32Vec16& b) const {
return FP32Vec16(float32x4x4_t({vmaxq_f32(b.reg.val[0], reg.val[0]),
vmaxq_f32(b.reg.val[1], reg.val[1]),
vmaxq_f32(b.reg.val[2], reg.val[2]),
vmaxq_f32(b.reg.val[3], reg.val[3])}));
FORCE_INLINE FP32Vec16 max(const FP32Vec16& b) const {
FP32Vec16 r(uninit);
r.reg.val[0] = maximum(b.reg.val[0], reg.val[0]);
r.reg.val[1] = maximum(b.reg.val[1], reg.val[1]);
r.reg.val[2] = maximum(b.reg.val[2], reg.val[2]);
r.reg.val[3] = maximum(b.reg.val[3], reg.val[3]);
return r;
};
FP32Vec16 max(const FP32Vec16& b, const int elem_num) const {
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
float32x4x4_t temp;
FP32Vec16 min(const FP32Vec16& b, const int elem_num) const {
size_t num_elements = reg.val[0].size();
if (elem_num == VEC_ELEM_NUM) {
return FP32Vec16::min(b);
}
int full_blocks = elem_num / num_elements;
const int remainder = elem_num % num_elements;
FP32Vec16 res(uninit);
for (int i = 0; i < full_blocks; i++)
temp.val[i] = vmaxq_f32(b.reg.val[i], reg.val[i]);
res.reg.val[i] = minimum(b.reg.val[i], reg.val[i]);
if (remainder > 0) {
float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 0),
float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 0),
vgetq_lane_f32(b.reg.val[full_blocks], 0));
temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 0);
res.reg.val[full_blocks] =
vsetq_lane_f32(min_v, res.reg.val[full_blocks], 0);
}
if (remainder > 1) {
float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 1),
float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 1),
vgetq_lane_f32(b.reg.val[full_blocks], 1));
temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 1);
res.reg.val[full_blocks] =
vsetq_lane_f32(min_v, res.reg.val[full_blocks], 1);
}
if (remainder > 2) {
float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 2),
float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 2),
vgetq_lane_f32(b.reg.val[full_blocks], 2));
temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 2);
res.reg.val[full_blocks] =
vsetq_lane_f32(min_v, res.reg.val[full_blocks], 2);
}
return FP32Vec16(temp);
};
FP32Vec16 min(const FP32Vec16& b) const {
return FP32Vec16(float32x4x4_t({
vminq_f32(b.reg.val[0], reg.val[0]),
vminq_f32(b.reg.val[1], reg.val[1]),
vminq_f32(b.reg.val[2], reg.val[2]),
vminq_f32(b.reg.val[3], reg.val[3]),
}));
return res;
};
FP32Vec16 min(const FP32Vec16& b, const int elem_num) const {
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
const int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
float32x4x4_t temp;
FP32Vec16 max(const FP32Vec16& b, const int elem_num) const {
size_t num_elements = reg.val[0].size();
if (elem_num == VEC_ELEM_NUM) {
return FP32Vec16::max(b);
}
int full_blocks = elem_num / num_elements;
int remainder = elem_num % num_elements;
FP32Vec16 res(uninit);
for (int i = 0; i < full_blocks; i++)
temp.val[i] = vminq_f32(b.reg.val[i], reg.val[i]);
res.reg.val[i] = maximum(b.reg.val[i], reg.val[i]);
if (remainder > 0) {
float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 0),
float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 0),
vgetq_lane_f32(b.reg.val[full_blocks], 0));
temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 0);
res.reg.val[full_blocks] =
vsetq_lane_f32(max_v, res.reg.val[full_blocks], 0);
}
if (remainder > 1) {
float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 1),
float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 1),
vgetq_lane_f32(b.reg.val[full_blocks], 1));
temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 1);
res.reg.val[full_blocks] =
vsetq_lane_f32(max_v, res.reg.val[full_blocks], 1);
}
if (remainder > 2) {
float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 2),
float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 2),
vgetq_lane_f32(b.reg.val[full_blocks], 2));
temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 2);
res.reg.val[full_blocks] =
vsetq_lane_f32(max_v, res.reg.val[full_blocks], 2);
}
return FP32Vec16(temp);
};
FP32Vec16 abs() const {
return FP32Vec16(
float32x4x4_t({vabsq_f32(reg.val[0]), vabsq_f32(reg.val[1]),
vabsq_f32(reg.val[2]), vabsq_f32(reg.val[3])}));
}
float reduce_sum() const {
AliasReg ar;
ar.reg = reg;
float answer = 0;
unroll_loop<int, VEC_ELEM_NUM>(
[&answer, &ar](int i) { answer += ar.values[i]; });
return answer;
return res;
};
float reduce_max() const {
AliasReg ar;
ar.reg = reg;
float max_v = std::numeric_limits<float>::lowest();
unroll_loop<int, VEC_ELEM_NUM>(
[&max_v, &ar](int i) { max_v = std::max(max_v, ar.values[i]); });
return max_v;
VectorizedT max_vec = reg.val[0];
unroll_loop<int, VEC_REG_NUM>([&](int i) {
if (i > 0) max_vec = maximum(max_vec, reg.val[i]);
});
return vmaxvq_f32(max_vec);
}
float reduce_min() const {
AliasReg ar;
ar.reg = reg;
float min_v = std::numeric_limits<float>::max();
unroll_loop<int, VEC_ELEM_NUM>(
[&min_v, &ar](int i) { min_v = std::min(min_v, ar.values[i]); });
return min_v;
VectorizedT min_vec = reg.val[0];
unroll_loop<int, VEC_REG_NUM>([&](int i) {
if (i > 0) min_vec = minimum(min_vec, reg.val[i]);
});
return vminvq_f32(min_vec);
}
template <int group_size>
float reduce_sub_sum(int idx) {
static_assert(VEC_ELEM_NUM % group_size == 0);
AliasReg ar;
ar.reg = reg;
AliasReg<NxVectorizedTArray, ScalarT, VEC_ELEM_NUM> ar{reg};
float answer = 0;
const int start = idx * group_size;
unroll_loop<int, group_size>(
[&answer, &start, ar](int i) { answer += ar.values[start + i]; });
[&](int i) { answer += ar.values[start + i]; });
return answer;
};
void save(float* ptr) const {
vst1q_f32(ptr, reg.val[0]);
vst1q_f32(ptr + 4, reg.val[1]);
vst1q_f32(ptr + 8, reg.val[2]);
vst1q_f32(ptr + 12, reg.val[3]);
};
void save(float* ptr, const int elem_num) const {
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
for (int i = 0; i < full_blocks; i++)
vst1q_f32(
reinterpret_cast<float32_t*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
reg.val[i]);
float reduce_sum() const {
float answer = 0;
std::plus<VectorizedT> add;
unroll_loop<int, VEC_REG_NUM>([&](int i) {
answer += at::vec::vec_reduce_all<float>(add, reg.val[i]);
});
if (remainder > 0) {
float32x4_t temp = reg.val[full_blocks];
float* base = reinterpret_cast<float32_t*>(ptr) +
full_blocks * NUM_ELEMENTS_REG(reg.val[0]);
if (remainder > 0) base[0] = vgetq_lane_f32(temp, 0);
if (remainder > 1) base[1] = vgetq_lane_f32(temp, 1);
if (remainder > 2) base[2] = vgetq_lane_f32(temp, 2);
}
return answer;
}
};
// Only used for int types for now could be replaced when
// int8/32 vectorised ops are added in ATen
template <typename T>
struct Vec {
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; };
};
struct INT8Vec16 : public Vec<INT8Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
......@@ -854,30 +820,47 @@ struct INT8Vec64 : public Vec<INT8Vec64> {
void nt_save(int8_t* ptr) const { save(ptr); }
}; // INT8Vec64
template <typename T>
struct VecType {
using vec_type = void;
};
struct INT32Vec16 : public Vec<INT32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
int32x4x4_t reg;
int32_t values[VEC_ELEM_NUM];
};
int32x4x4_t reg;
template <typename T>
using vec_t = typename VecType<T>::vec_type;
explicit INT32Vec16(const void* ptr) {
reg.val[0] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr));
reg.val[1] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 4);
reg.val[2] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 8);
reg.val[3] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 12);
}
template <>
struct VecType<float> {
using vec_type = FP32Vec8;
};
void save(int32_t* ptr) const {
vst1q_s32(ptr, reg.val[0]);
vst1q_s32(ptr + 4, reg.val[1]);
vst1q_s32(ptr + 8, reg.val[2]);
vst1q_s32(ptr + 12, reg.val[3]);
};
template <>
struct VecType<c10::Half> {
using vec_type = FP16Vec8;
};
void save(int32_t* ptr, const int elem_num) const {
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
#ifdef ARM_BF16_SUPPORT
template <>
struct VecType<c10::BFloat16> {
using vec_type = BF16Vec8;
for (int i = 0; i < full_blocks; i++)
vst1q_s32(
reinterpret_cast<__int32_t*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
reg.val[i]);
if (remainder > 0) {
int32x4_t temp = reg.val[full_blocks];
int32_t* base = reinterpret_cast<int32_t*>(ptr) + full_blocks * 4;
if (remainder > 0) base[0] = vgetq_lane_s32(temp, 0);
if (remainder > 1) base[1] = vgetq_lane_s32(temp, 1);
if (remainder > 2) base[2] = vgetq_lane_s32(temp, 2);
if (remainder > 3) base[3] = vgetq_lane_s32(temp, 3);
}
}
};
#endif
template <typename T>
void storeFP32(float v, T* ptr) {
......@@ -889,66 +872,55 @@ inline void storeFP32<c10::Half>(float v, c10::Half* ptr) {
*reinterpret_cast<__fp16*>(ptr) = v;
}
inline FP16Vec16::FP16Vec16(const FP32Vec16& v) {
float16x4_t low_0 = vcvt_f16_f32(v.reg.val[0]);
float16x4_t high_0 = vcvt_f16_f32(v.reg.val[1]);
float16x4_t low_1 = vcvt_f16_f32(v.reg.val[2]);
float16x4_t high_1 = vcvt_f16_f32(v.reg.val[3]);
inline FP16Vec8::FP16Vec8(const FP32Vec8& v) {
reg.val[0] = convert_float_half(v.reg.val[0], v.reg.val[1]);
};
reg.val[0] = vcombine_f16(low_0, high_0);
reg.val[1] = vcombine_f16(low_1, high_1);
inline FP16Vec16::FP16Vec16(const FP32Vec16& v) {
reg.val[0] = convert_float_half(v.reg.val[0], v.reg.val[1]);
reg.val[1] = convert_float_half(v.reg.val[2], v.reg.val[3]);
};
inline FP16Vec8 ::FP16Vec8(const FP32Vec8& v) {
float16x4_t lower_half = vcvt_f16_f32(v.reg.val[0]);
float16x4_t upper_half = vcvt_f16_f32(v.reg.val[1]);
inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
fmadd(acc.reg.val[0], a.reg.val[0], b.reg.val[0]);
fmadd(acc.reg.val[1], a.reg.val[1], b.reg.val[1]);
fmadd(acc.reg.val[2], a.reg.val[2], b.reg.val[2]);
fmadd(acc.reg.val[3], a.reg.val[3], b.reg.val[3]);
};
reg = vcombine_f16(lower_half, upper_half);
inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
reg.val[0] = convert_float_bfloat16(v.reg.val[0], v.reg.val[1]);
};
inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
acc.reg.val[0] = vfmaq_f32(acc.reg.val[0], a.reg.val[0], b.reg.val[0]);
acc.reg.val[1] = vfmaq_f32(acc.reg.val[1], a.reg.val[1], b.reg.val[1]);
acc.reg.val[2] = vfmaq_f32(acc.reg.val[2], a.reg.val[2], b.reg.val[2]);
acc.reg.val[3] = vfmaq_f32(acc.reg.val[3], a.reg.val[3], b.reg.val[3]);
inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
reg.val[0] = convert_float_bfloat16(v.reg.val[0], v.reg.val[1]);
reg.val[1] = convert_float_bfloat16(v.reg.val[2], v.reg.val[3]);
};
#ifdef ARM_BF16_SUPPORT
inline void fma(FP32Vec16& acc, BF16Vec32& a, BF16Vec32& b) {
float32x4_t a0_low = vcvt_f32_bf16(vget_low_bf16(a.reg.val[0]));
float32x4_t a0_high = vcvt_f32_bf16(vget_high_bf16(a.reg.val[0]));
float32x4_t a1_low = vcvt_f32_bf16(vget_low_bf16(a.reg.val[1]));
float32x4_t a1_high = vcvt_f32_bf16(vget_high_bf16(a.reg.val[1]));
float32x4_t b0_low = vcvt_f32_bf16(vget_low_bf16(b.reg.val[0]));
float32x4_t b0_high = vcvt_f32_bf16(vget_high_bf16(b.reg.val[0]));
float32x4_t b1_low = vcvt_f32_bf16(vget_low_bf16(b.reg.val[1]));
float32x4_t b1_high = vcvt_f32_bf16(vget_high_bf16(b.reg.val[1]));
acc.reg.val[0] = vfmaq_f32(acc.reg.val[0], a0_low, b0_low);
acc.reg.val[1] = vfmaq_f32(acc.reg.val[1], a0_high, b0_high);
acc.reg.val[2] = vfmaq_f32(acc.reg.val[2], a1_low, b1_low);
acc.reg.val[3] = vfmaq_f32(acc.reg.val[3], a1_high, b1_high);
Vectorized<float> a0_low, a0_high, a1_low, a1_high, b0_low, b0_high, b1_low,
b1_high;
std::tie(a0_low, a0_high) = convert_bfloat16_float(a.reg.val[0]);
std::tie(a1_low, a1_high) = convert_bfloat16_float(a.reg.val[1]);
std::tie(b0_low, b0_high) = convert_bfloat16_float(b.reg.val[0]);
std::tie(b1_low, b1_high) = convert_bfloat16_float(b.reg.val[1]);
fmadd(acc.reg.val[0], a0_low, b0_low);
fmadd(acc.reg.val[1], a0_high, b0_high);
fmadd(acc.reg.val[2], a1_low, b1_low);
fmadd(acc.reg.val[3], a1_high, b1_high);
};
#endif
template <>
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
#ifdef ARM_BF16_SUPPORT
inline BF16Vec8::BF16Vec8(const FP32Vec8& v)
: reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1])) {
};
inline BF16Vec16::BF16Vec16(const FP32Vec16& v)
: reg({vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1]),
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[2]),
v.reg.val[3])}) {};
*reinterpret_cast<__bf16*>(ptr) = vcvth_bf16_f32(v);
#else
*ptr = static_cast<c10::BFloat16>(v);
#endif
};
inline void prefetch(const void* addr) { __builtin_prefetch(addr, 0, 1); };
#ifdef ARM_BF16_SUPPORT
template <>
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
*reinterpret_cast<__bf16*>(ptr) = vcvth_bf16_f32(v);
};
#endif
}; // namespace vec_op
\ No newline at end of file
......@@ -14,13 +14,11 @@ struct KernelVecType<float> {
using cvt_vec_type = vec_op::FP32Vec16;
};
#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT)
template <>
struct KernelVecType<c10::BFloat16> {
using load_vec_type = vec_op::BF16Vec16;
using cvt_vec_type = vec_op::FP32Vec16;
};
#endif
template <>
struct KernelVecType<c10::Half> {
......
......@@ -38,9 +38,7 @@ struct KernelVecType<c10::BFloat16> {
using qk_vec_type = vec_op::BF16Vec32;
using v_load_vec_type = vec_op::BF16Vec16;
};
#elif defined(__aarch64__) && !defined(ARM_BF16_SUPPORT)
// pass
#else
#elif defined(__aarch64__)
template <>
struct KernelVecType<c10::BFloat16> {
using qk_load_vec_type = vec_op::BF16Vec16;
......
......@@ -30,12 +30,10 @@ struct VecTypeTrait<float> {
using vec_t = vec_op::FP32Vec16;
};
#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT)
template <>
struct VecTypeTrait<c10::BFloat16> {
using vec_t = vec_op::BF16Vec16;
};
#endif
#if !defined(__powerpc__)
template <>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment