Commit fbeb8a6f authored by raojy

raw_vllm

parent 2ca8867f
#ifndef CPU_TYPES_HPP
#define CPU_TYPES_HPP
#if defined(__x86_64__)
// x86 implementation
#include "cpu_types_x86.hpp"
#elif defined(__POWER9_VECTOR__)
// ppc implementation
#include "cpu_types_vsx.hpp"
#elif defined(__s390x__)
// s390 implementation
#include "cpu_types_vxe.hpp"
#elif defined(__aarch64__)
// arm implementation
#include "cpu_types_arm.hpp"
#else
#warning "Unsupported CPU architecture for vLLM; compiling with the scalar fallback"
#include "cpu_types_scalar.hpp"
#endif
#ifdef _OPENMP
#include <omp.h>
#endif
#endif

// ---------------------------------------------------------------------------
// cpu_types_arm.hpp: ARM NEON/ASIMD implementation built on at::vec::Vectorized
// ---------------------------------------------------------------------------
#include <cmath>
#include <type_traits>
#include <arm_neon.h>
#include <torch/all.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#if defined(__APPLE__)
#include "omp.h"
#endif
using namespace at::vec;
namespace vec_op {
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
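// Typical use in a kernel entry point (sketch; the kernel name and lambda
// body are illustrative):
//   VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "my_kernel", [&] {
//     my_kernel_impl<scalar_t>(...);
//   });
// AT_DISPATCH_SWITCH picks the case matching the runtime dtype, and each
// AT_DISPATCH_CASE binds scalar_t to the corresponding C++ type inside the
// lambda.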
#ifndef CPU_OP_GUARD
#define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME)
#else
#define CPU_KERNEL_GUARD_IN(NAME) \
std::cout << #NAME << " invoked." << std::endl;
#define CPU_KERNEL_GUARD_OUT(NAME) \
std::cout << #NAME << " exit." << std::endl;
#endif
#define FORCE_INLINE __attribute__((always_inline)) inline
// Number of elements in a single ASIMD vector of the given data type
#define NUM_ELEMENTS_REG(vec) (sizeof(vec) / sizeof(vec[0]))
namespace {
template <typename T, T... indexes, typename F>
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
(f(std::integral_constant<T, indexes>{}), ...);
};
}; // namespace
template <typename T, T count, typename F,
typename = std::enable_if_t<std::is_invocable_v<F, T>>>
inline constexpr void unroll_loop(F&& f) {
unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
}
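// Example: unroll_loop expands its functor once per index at compile time
// through the fold expression above, so
//   unroll_loop<int, 4>([&](int i) { dst[i] = src[i] * 2.0f; });
// emits four copies of the body with i = 0..3 and no runtime loop counter
// (dst/src are illustrative buffers).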
template <typename T, typename... Ts>
struct is_one_of : std::bool_constant<(std::is_same_v<T, Ts> || ...)> {};
template <typename T, typename... Ts>
inline constexpr bool is_one_of_v = is_one_of<T, Ts...>::value;
struct uninit_t {
explicit constexpr uninit_t() = default;
};
inline constexpr uninit_t uninit{};
template <typename NxVectorizedTVecReg, typename T, int VEC_ELEM_NUM>
union AliasReg {
NxVectorizedTVecReg reg;
T values[VEC_ELEM_NUM];
};
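// AliasReg gives lane-wise access to a register group via union type
// punning (a GCC/Clang extension), e.g. as used in
// FP32Vec16::reduce_sub_sum below:
//   AliasReg<NxVectorizedTVecReg<4, float>, float, 16> ar{reg};
//   float lane3 = ar.values[3];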
// Templated over at::vec::Vectorized<T>: packs N vectorised registers into
// one logical register of N * Vectorized<T>::size() elements.
template <int N, typename T>
struct NxVectorizedTVecReg {
using value_t = T;
using VectorizedT = Vectorized<T>;
VectorizedT val[N];
NxVectorizedTVecReg() = default;
NxVectorizedTVecReg(const NxVectorizedTVecReg&) = default;
NxVectorizedTVecReg(NxVectorizedTVecReg&&) = default;
NxVectorizedTVecReg& operator=(const NxVectorizedTVecReg&) = default;
NxVectorizedTVecReg& operator=(NxVectorizedTVecReg&&) = default;
explicit NxVectorizedTVecReg(uninit_t) noexcept {};
FORCE_INLINE explicit NxVectorizedTVecReg(const VectorizedT& vec_t) {
unroll_loop<int, N>([&](int i) { val[i] = vec_t; });
};
FORCE_INLINE explicit NxVectorizedTVecReg(T v) noexcept {
VectorizedT vv(v);
unroll_loop<int, N>([&](int i) { val[i] = vv; });
}
FORCE_INLINE explicit NxVectorizedTVecReg(const void* ptr) { load(ptr); }
explicit NxVectorizedTVecReg(const void* ptr, const int elem_num) {
load(ptr, elem_num);
}
static constexpr int size() noexcept { return N * VectorizedT::size(); }
FORCE_INLINE void save(void* ptr) const {
value_t* base = reinterpret_cast<value_t*>(ptr);
unroll_loop<int, N>(
[&](int i) { val[i].store(base + i * VectorizedT::size()); });
}
FORCE_INLINE void load(const void* ptr) {
const value_t* base = reinterpret_cast<const value_t*>(ptr);
unroll_loop<int, N>([&](int i) {
val[i] = VectorizedT::loadu(base + i * VectorizedT::size());
});
}
FORCE_INLINE void save(void* ptr, const int elem_num) const {
value_t* base = reinterpret_cast<value_t*>(ptr);
save_partial(base, elem_num);
}
FORCE_INLINE void load(const void* ptr, const int elem_num) {
const value_t* base = reinterpret_cast<const value_t*>(ptr);
load_partial(base, elem_num);
}
FORCE_INLINE void save_partial(value_t* base, int elem_num) const {
const int w = VectorizedT::size();
int full = elem_num / w;
int rem = elem_num % w;
for (int i = 0; i < full; i++) val[i].store(base + i * w);
if (rem) val[full].store(base + full * w, rem);
}
FORCE_INLINE void load_partial(const value_t* base, int elem_num) {
const int w = VectorizedT::size();
int full = elem_num / w;
int rem = elem_num % w;
for (int i = 0; i < full; i++) val[i] = VectorizedT::loadu(base + i * w);
if (rem) val[full] = VectorizedT::loadu(base + full * w, rem);
}
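  // Note: ATen's loadu(ptr, count) zero-fills the lanes past count, and
  // store(ptr, count) writes only the first count lanes, so the ragged tail
  // of a partial load/save never touches memory beyond base + elem_num.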
template <VectorizedT (VectorizedT::*torch_vec_func)() const,
value_t (*std_func)(value_t)>
FORCE_INLINE NxVectorizedTVecReg opt_vec_func_impl() const {
NxVectorizedTVecReg result;
if constexpr (torch_vec_func != nullptr) {
unroll_loop<int, N>(
[&](int i) { result.val[i] = (val[i].*torch_vec_func)(); });
} else {
for (int i = 0; i < N; i++) {
alignas(64) value_t buf[VectorizedT::size()];
val[i].store(buf);
for (int j = 0; j < VectorizedT::size(); ++j) {
buf[j] = std_func(buf[j]);
}
result.val[i] = VectorizedT::loadu(buf);
}
}
return result;
}
};
template <typename DerivedClassT, int N, typename T>
struct VectorizedRegWrapper {
using ScalarT = T;
using VectorizedT = Vectorized<T>;
using NxVectorizedTArray = NxVectorizedTVecReg<N, T>;
constexpr static int VEC_REG_NUM = N;
constexpr static int VEC_ELEM_NUM = VEC_REG_NUM * VectorizedT::size();
constexpr static int get_elem_num() { return VEC_ELEM_NUM; };
NxVectorizedTArray reg;
VectorizedRegWrapper() noexcept = default;
explicit VectorizedRegWrapper(uninit_t) noexcept : reg{uninit} {};
explicit VectorizedRegWrapper(T v) : reg(v) {};
explicit VectorizedRegWrapper(const void* ptr) : reg(ptr) {};
explicit VectorizedRegWrapper(const void* ptr, const int elem_num)
: reg(ptr, elem_num) {};
explicit VectorizedRegWrapper(const VectorizedT& r) : reg(r) {};
explicit VectorizedRegWrapper(const NxVectorizedTArray& r) : reg(r) {};
VectorizedRegWrapper(const VectorizedRegWrapper&) = default;
VectorizedRegWrapper(VectorizedRegWrapper&&) = default;
VectorizedRegWrapper& operator=(VectorizedRegWrapper&&) = default;
VectorizedRegWrapper& operator=(const VectorizedRegWrapper&) = default;
FORCE_INLINE void save(void* ptr) const { reg.save(ptr); }
void save(void* ptr, const int elem_num) const { reg.save(ptr, elem_num); }
// Define optimized functions using at::vec::Vectorized<T> where possible
// Fallback to std:: functions when not available
#define OPT_TORCH_IMPL(FUNC_NAME, STD_FUNC_NAME, TORCH_FUNC_NAME, ...) \
FORCE_INLINE DerivedClassT FUNC_NAME() const { \
if constexpr (is_one_of_v<T, __VA_ARGS__>) { \
return DerivedClassT{ \
reg.template opt_vec_func_impl<&VectorizedT::TORCH_FUNC_NAME, \
std::STD_FUNC_NAME>()}; \
} else { \
return DerivedClassT{reg.template opt_vec_func_impl< \
nullptr, static_cast<ScalarT (*)(ScalarT)>(&std::STD_FUNC_NAME)>()}; \
} \
}
// Define optimized functions for datatypes passed in __VA_ARGS__
OPT_TORCH_IMPL(abs, abs, abs, c10::Half, float)
OPT_TORCH_IMPL(er, erf, erf, float)
OPT_TORCH_IMPL(exp, exp, fexp_u20, float)
OPT_TORCH_IMPL(exp_u20, exp, exp_u20, float)
OPT_TORCH_IMPL(sin, sin, sin, float)
OPT_TORCH_IMPL(sinh, sinh, sinh, float)
OPT_TORCH_IMPL(cos, cos, cos, float)
OPT_TORCH_IMPL(cosh, cosh, cosh, float)
OPT_TORCH_IMPL(log, log, log, float)
OPT_TORCH_IMPL(log10, log10, log10, float)
OPT_TORCH_IMPL(sqrt, sqrt, sqrt, c10::Half, float)
OPT_TORCH_IMPL(tan, tan, tan, float)
OPT_TORCH_IMPL(tanh, tanh, tanh, float)
#undef OPT_TORCH_IMPL
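  // For example, OPT_TORCH_IMPL(exp, exp, fexp_u20, float) above defines a
  // member exp() on every derived wrapper: for float element types it applies
  // Vectorized<float>::fexp_u20() register by register, while for other
  // element types it spills each register to a stack buffer and applies
  // std::exp lane by lane (see opt_vec_func_impl).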
};
// forward declare vectorised dtypes
struct FP32Vec8;
struct FP32Vec16;
struct FP16Vec8;
struct FP16Vec16;
struct BF16Vec8;
struct BF16Vec16;
struct INT8Vec16;
struct INT32Vec16;
template <typename T>
struct VecType {
using vec_type = void;
};
template <typename T>
using vec_t = typename VecType<T>::vec_type;
template <>
struct VecType<float> {
using vec_type = FP32Vec8;
};
template <>
struct VecType<c10::Half> {
using vec_type = FP16Vec8;
};
template <>
struct VecType<c10::BFloat16> {
using vec_type = BF16Vec8;
};
struct FP16Vec8 : public VectorizedRegWrapper<FP16Vec8, 1, c10::Half> {
using Base = VectorizedRegWrapper<FP16Vec8, 1, c10::Half>;
using Base::Base;
using Base::get_elem_num;
using Base::VEC_ELEM_NUM;
explicit FP16Vec8(const FP32Vec8&);
};
struct FP16Vec16 : public VectorizedRegWrapper<FP16Vec16, 2, c10::Half> {
using Base = VectorizedRegWrapper<FP16Vec16, 2, c10::Half>;
using Base::Base;
using Base::get_elem_num;
using Base::VEC_ELEM_NUM;
// ASIMD does not support non-temporal loads
explicit FP16Vec16(bool, const void* ptr) : Base(ptr) {}
explicit FP16Vec16(const FP32Vec16& vec);
};
struct BF16Vec8 : public VectorizedRegWrapper<BF16Vec8, 1, c10::BFloat16> {
using Base = VectorizedRegWrapper<BF16Vec8, 1, c10::BFloat16>;
using VectorizedT = typename Base::VectorizedT;
using Base::Base;
using Base::get_elem_num;
using Base::VEC_ELEM_NUM;
explicit BF16Vec8(at_bfloat16x8_t data) : Base(VectorizedT(data)) {};
explicit BF16Vec8(float32x4x2_t v) {
reg.val[0] = convert_float_bfloat16(v.val[0], v.val[1]);
};
explicit BF16Vec8(const FP32Vec8&);
};
struct BF16Vec16 : public VectorizedRegWrapper<BF16Vec16, 2, c10::BFloat16> {
using Base = VectorizedRegWrapper<BF16Vec16, 2, c10::BFloat16>;
using VectorizedT = typename Base::VectorizedT;
using Base::Base;
using Base::get_elem_num;
using Base::VEC_ELEM_NUM;
// ASIMD does not support non-temporal loads
explicit BF16Vec16(bool, const void* ptr) : Base(ptr) {}
explicit BF16Vec16(float32x4x4_t v) {
reg.val[0] = convert_float_bfloat16(v.val[0], v.val[1]);
reg.val[1] = convert_float_bfloat16(v.val[2], v.val[3]);
};
explicit BF16Vec16(const FP32Vec16&);
};
struct BF16Vec32 : public VectorizedRegWrapper<BF16Vec32, 4, c10::BFloat16> {
using Base = VectorizedRegWrapper<BF16Vec32, 4, c10::BFloat16>;
using Base::Base;
using Base::get_elem_num;
using Base::VEC_ELEM_NUM;
explicit BF16Vec32(const BF16Vec8& vec8_data) {
reg.val[0] = vec8_data.reg.val[0];
reg.val[1] = vec8_data.reg.val[0];
reg.val[2] = vec8_data.reg.val[0];
reg.val[3] = vec8_data.reg.val[0];
};
};
struct FP32Vec4 : public VectorizedRegWrapper<FP32Vec4, 1, float> {
using Base = VectorizedRegWrapper<FP32Vec4, 1, float>;
using Base::Base;
using Base::get_elem_num;
using Base::VEC_ELEM_NUM;
using VectorizedT = typename Base::VectorizedT;
using Vectorized1x4f = typename Base::NxVectorizedTArray;
FP32Vec4() : Base() {};
explicit FP32Vec4(float v) : Base(v) {};
explicit FP32Vec4(float32x4_t data) : Base(VectorizedT(data)) {};
explicit FP32Vec4(const FP32Vec4& data) : Base(data) {};
};
struct FP32Vec8 : public VectorizedRegWrapper<FP32Vec8, 2, float> {
using Base = VectorizedRegWrapper<FP32Vec8, 2, float>;
using Base::Base;
using Base::get_elem_num;
using Base::VEC_ELEM_NUM;
using Base::VEC_REG_NUM;
using VectorizedT = typename Base::VectorizedT;
using Vectorized2x4f = typename Base::NxVectorizedTArray;
FP32Vec8() : Base() {};
FP32Vec8(const FP32Vec8& data) : Base(data) {};
explicit FP32Vec8(float v) : Base(v) {};
explicit FP32Vec8(const float* ptr)
: Base(reinterpret_cast<const void*>(ptr)) {};
explicit FP32Vec8(const float* ptr, const int elem_num)
: Base(reinterpret_cast<const void*>(ptr), elem_num) {};
explicit FP32Vec8(const Vectorized2x4f& data) {
reg.val[0] = data.val[0];
reg.val[1] = data.val[1];
};
explicit FP32Vec8(const BF16Vec8& v) {
std::tie(reg.val[0], reg.val[1]) = convert_bfloat16_float(v.reg.val[0]);
};
explicit FP32Vec8(const FP16Vec8& v) {
reg.val[0] = Vectorized<float>(vcvt_f32_f16(vget_low_f16(v.reg.val[0])));
reg.val[1] = Vectorized<float>(vcvt_f32_f16(vget_high_f16(v.reg.val[0])));
};
explicit FP32Vec8(float16x8_t v) {
reg.val[0] = Vectorized<float>(vcvt_f32_f16(vget_low_f16(v)));
reg.val[1] = Vectorized<float>(vcvt_f32_f16(vget_high_f16(v)));
};
explicit FP32Vec8(at_bfloat16x8_t v) {
std::tie(reg.val[0], reg.val[1]) =
convert_bfloat16_float(Vectorized<c10::BFloat16>(v));
};
explicit FP32Vec8(float32x4x2_t data) {
reg.val[0] = Vectorized<float>(data.val[0]);
reg.val[1] = Vectorized<float>(data.val[1]);
}
FORCE_INLINE float reduce_sum() const noexcept {
float answer = 0;
std::plus<VectorizedT> add;
unroll_loop<int, VEC_REG_NUM>([&](int i) {
answer += at::vec::vec_reduce_all<float, std::plus<VectorizedT>>(
add, reg.val[i]);
});
return answer;
}
FORCE_INLINE FP32Vec8 operator+(const FP32Vec8& b) const noexcept {
FP32Vec8 r(uninit);
r.reg.val[0] = reg.val[0] + b.reg.val[0];
r.reg.val[1] = reg.val[1] + b.reg.val[1];
return r;
}
FORCE_INLINE FP32Vec8 operator-(const FP32Vec8& b) const noexcept {
FP32Vec8 r(uninit);
r.reg.val[0] = reg.val[0] - b.reg.val[0];
r.reg.val[1] = reg.val[1] - b.reg.val[1];
return r;
}
FORCE_INLINE FP32Vec8 operator*(const FP32Vec8& b) const noexcept {
FP32Vec8 r(uninit);
r.reg.val[0] = reg.val[0] * b.reg.val[0];
r.reg.val[1] = reg.val[1] * b.reg.val[1];
return r;
}
FORCE_INLINE FP32Vec8 operator/(const FP32Vec8& b) const noexcept {
FP32Vec8 r(uninit);
r.reg.val[0] = reg.val[0] / b.reg.val[0];
r.reg.val[1] = reg.val[1] / b.reg.val[1];
return r;
}
};
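// Usage sketch (src/dst are illustrative float buffers):
//   FP32Vec8 a(src), b(src + 8);
//   FP32Vec8 c = (a + b) * a;  // elementwise over two NEON q-registers each
//   c.save(dst);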
struct FP32Vec16 : public VectorizedRegWrapper<FP32Vec16, 4, float> {
using Base = VectorizedRegWrapper<FP32Vec16, 4, float>;
using Base::Base;
using Base::get_elem_num;
using Base::VEC_ELEM_NUM;
using ScalarT = typename Base::ScalarT;
using VectorizedT = typename Base::VectorizedT;
using Vectorized4x4f = typename Base::NxVectorizedTArray;
FP32Vec16() : Base() {};
FP32Vec16(const FP32Vec16& data) : Base(data) {};
explicit FP32Vec16(float v) : Base(v) {};
explicit FP32Vec16(const float* ptr)
: Base(reinterpret_cast<const void*>(ptr)) {};
explicit FP32Vec16(const float* ptr, const int elem_num)
: Base(reinterpret_cast<const void*>(ptr), elem_num) {};
explicit FP32Vec16(const Vectorized4x4f& data) {
reg.val[0] = data.val[0];
reg.val[1] = data.val[1];
reg.val[2] = data.val[2];
reg.val[3] = data.val[3];
};
// ASIMD does not support non-temporal loads
explicit FP32Vec16(bool, const float* ptr) : Base(ptr) {}
explicit FP32Vec16(float32x4x4_t data) {
reg.val[0] = data.val[0];
reg.val[1] = data.val[1];
reg.val[2] = data.val[2];
reg.val[3] = data.val[3];
};
explicit FP32Vec16(const FP32Vec4& data) {
reg.val[0] = data.reg.val[0];
reg.val[1] = data.reg.val[0];
reg.val[2] = data.reg.val[0];
reg.val[3] = data.reg.val[0];
};
explicit FP32Vec16(const FP32Vec8& data) {
reg.val[0] = data.reg.val[0];
reg.val[1] = data.reg.val[1];
reg.val[2] = data.reg.val[0];
reg.val[3] = data.reg.val[1];
};
explicit FP32Vec16(const BF16Vec16& v) {
std::tie(reg.val[0], reg.val[1]) = convert_bfloat16_float(v.reg.val[0]);
std::tie(reg.val[2], reg.val[3]) = convert_bfloat16_float(v.reg.val[1]);
};
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {};
explicit FP32Vec16(const FP16Vec16& v) {
reg.val[0] = Vectorized<float>(vcvt_f32_f16(vget_low_f16(v.reg.val[0])));
reg.val[1] = Vectorized<float>(vcvt_f32_f16(vget_high_f16(v.reg.val[0])));
reg.val[2] = Vectorized<float>(vcvt_f32_f16(vget_low_f16(v.reg.val[1])));
reg.val[3] = Vectorized<float>(vcvt_f32_f16(vget_high_f16(v.reg.val[1])));
};
FORCE_INLINE FP32Vec16 operator+(const FP32Vec16& b) const noexcept {
FP32Vec16 r(uninit);
r.reg.val[0] = reg.val[0] + b.reg.val[0];
r.reg.val[1] = reg.val[1] + b.reg.val[1];
r.reg.val[2] = reg.val[2] + b.reg.val[2];
r.reg.val[3] = reg.val[3] + b.reg.val[3];
return r;
}
FORCE_INLINE FP32Vec16 operator-(const FP32Vec16& b) const noexcept {
FP32Vec16 r(uninit);
r.reg.val[0] = reg.val[0] - b.reg.val[0];
r.reg.val[1] = reg.val[1] - b.reg.val[1];
r.reg.val[2] = reg.val[2] - b.reg.val[2];
r.reg.val[3] = reg.val[3] - b.reg.val[3];
return r;
}
FORCE_INLINE FP32Vec16 operator*(const FP32Vec16& b) const noexcept {
FP32Vec16 r(uninit);
r.reg.val[0] = reg.val[0] * b.reg.val[0];
r.reg.val[1] = reg.val[1] * b.reg.val[1];
r.reg.val[2] = reg.val[2] * b.reg.val[2];
r.reg.val[3] = reg.val[3] * b.reg.val[3];
return r;
}
FORCE_INLINE FP32Vec16 operator/(const FP32Vec16& b) const noexcept {
FP32Vec16 r(uninit);
r.reg.val[0] = reg.val[0] / b.reg.val[0];
r.reg.val[1] = reg.val[1] / b.reg.val[1];
r.reg.val[2] = reg.val[2] / b.reg.val[2];
r.reg.val[3] = reg.val[3] / b.reg.val[3];
return r;
}
FORCE_INLINE FP32Vec16 clamp(const FP32Vec16& min,
const FP32Vec16& max) const {
FP32Vec16 r(uninit);
r.reg.val[0] = at::vec::clamp(reg.val[0], min.reg.val[0], max.reg.val[0]);
r.reg.val[1] = at::vec::clamp(reg.val[1], min.reg.val[1], max.reg.val[1]);
r.reg.val[2] = at::vec::clamp(reg.val[2], min.reg.val[2], max.reg.val[2]);
r.reg.val[3] = at::vec::clamp(reg.val[3], min.reg.val[3], max.reg.val[3]);
return r;
};
FORCE_INLINE FP32Vec16 min(const FP32Vec16& b) const {
FP32Vec16 r(uninit);
    r.reg.val[0] = minimum(b.reg.val[0], reg.val[0]);
r.reg.val[1] = minimum(b.reg.val[1], reg.val[1]);
r.reg.val[2] = minimum(b.reg.val[2], reg.val[2]);
r.reg.val[3] = minimum(b.reg.val[3], reg.val[3]);
return r;
};
FORCE_INLINE FP32Vec16 max(const FP32Vec16& b) const {
FP32Vec16 r(uninit);
r.reg.val[0] = maximum(b.reg.val[0], reg.val[0]);
r.reg.val[1] = maximum(b.reg.val[1], reg.val[1]);
r.reg.val[2] = maximum(b.reg.val[2], reg.val[2]);
r.reg.val[3] = maximum(b.reg.val[3], reg.val[3]);
return r;
};
FP32Vec16 min(const FP32Vec16& b, const int elem_num) const {
size_t num_elements = reg.val[0].size();
if (elem_num == VEC_ELEM_NUM) {
return FP32Vec16::min(b);
}
int full_blocks = elem_num / num_elements;
const int remainder = elem_num % num_elements;
FP32Vec16 res(uninit);
for (int i = 0; i < full_blocks; i++)
res.reg.val[i] = minimum(b.reg.val[i], reg.val[i]);
if (remainder > 0) {
float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 0),
vgetq_lane_f32(b.reg.val[full_blocks], 0));
res.reg.val[full_blocks] =
vsetq_lane_f32(min_v, res.reg.val[full_blocks], 0);
}
if (remainder > 1) {
float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 1),
vgetq_lane_f32(b.reg.val[full_blocks], 1));
res.reg.val[full_blocks] =
vsetq_lane_f32(min_v, res.reg.val[full_blocks], 1);
}
if (remainder > 2) {
float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 2),
vgetq_lane_f32(b.reg.val[full_blocks], 2));
res.reg.val[full_blocks] =
vsetq_lane_f32(min_v, res.reg.val[full_blocks], 2);
}
return res;
};
FP32Vec16 max(const FP32Vec16& b, const int elem_num) const {
size_t num_elements = reg.val[0].size();
if (elem_num == VEC_ELEM_NUM) {
return FP32Vec16::max(b);
}
int full_blocks = elem_num / num_elements;
int remainder = elem_num % num_elements;
FP32Vec16 res(uninit);
for (int i = 0; i < full_blocks; i++)
res.reg.val[i] = maximum(b.reg.val[i], reg.val[i]);
if (remainder > 0) {
float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 0),
vgetq_lane_f32(b.reg.val[full_blocks], 0));
res.reg.val[full_blocks] =
vsetq_lane_f32(max_v, res.reg.val[full_blocks], 0);
}
if (remainder > 1) {
float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 1),
vgetq_lane_f32(b.reg.val[full_blocks], 1));
res.reg.val[full_blocks] =
vsetq_lane_f32(max_v, res.reg.val[full_blocks], 1);
}
if (remainder > 2) {
float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 2),
vgetq_lane_f32(b.reg.val[full_blocks], 2));
res.reg.val[full_blocks] =
vsetq_lane_f32(max_v, res.reg.val[full_blocks], 2);
}
return res;
};
float reduce_max() const {
VectorizedT max_vec = reg.val[0];
unroll_loop<int, VEC_REG_NUM>([&](int i) {
if (i > 0) max_vec = maximum(max_vec, reg.val[i]);
});
return vmaxvq_f32(max_vec);
}
float reduce_min() const {
VectorizedT min_vec = reg.val[0];
unroll_loop<int, VEC_REG_NUM>([&](int i) {
if (i > 0) min_vec = minimum(min_vec, reg.val[i]);
});
return vminvq_f32(min_vec);
}
template <int group_size>
float reduce_sub_sum(int idx) {
static_assert(VEC_ELEM_NUM % group_size == 0);
AliasReg<NxVectorizedTArray, ScalarT, VEC_ELEM_NUM> ar{reg};
float answer = 0;
const int start = idx * group_size;
unroll_loop<int, group_size>(
[&](int i) { answer += ar.values[start + i]; });
return answer;
};
float reduce_sum() const {
float answer = 0;
std::plus<VectorizedT> add;
unroll_loop<int, VEC_REG_NUM>([&](int i) {
answer += at::vec::vec_reduce_all<float>(add, reg.val[i]);
});
return answer;
}
};
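// Usage sketch (ptr is an illustrative float buffer; softmax-style
// normalisation):
//   FP32Vec16 v(ptr);
//   FP32Vec16 e = (v - FP32Vec16(v.reduce_max())).exp();
//   float denom = e.reduce_sum();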
// Only used for int types for now; could be replaced once
// int8/int32 vectorised ops are added in ATen.
template <typename T>
struct Vec {
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; };
};
struct INT8Vec16 : public Vec<INT8Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
int8x16_t reg;
int8_t values[VEC_ELEM_NUM];
};
int8x16_t reg;
explicit INT8Vec16(const FP32Vec16& vec) {
// Convert each 128-bit float32 vector to int32
int32x4_t part0 =
vcvtq_s32_f32(vec.reg.val[0]); // Convert first 128-bit block
int32x4_t part1 =
vcvtq_s32_f32(vec.reg.val[1]); // Convert second 128-bit block
int32x4_t part2 =
vcvtq_s32_f32(vec.reg.val[2]); // Convert third 128-bit block
int32x4_t part3 =
vcvtq_s32_f32(vec.reg.val[3]); // Convert fourth 128-bit block
// Narrow each 32-bit vector to 8 bits and combine
int8x8_t lower =
vqmovn_s16(vcombine_s16(vqmovn_s32(part0), vqmovn_s32(part1)));
int8x8_t upper =
vqmovn_s16(vcombine_s16(vqmovn_s32(part2), vqmovn_s32(part3)));
reg = vcombine_s8(lower, upper); // Combine to form a single 128-bit vector
}
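  // Note: vcvtq_s32_f32 truncates toward zero and the vqmovn_* narrowing
  // steps saturate, so inputs outside [-128, 127] clamp to the int8 range
  // (e.g. 300.0f becomes 127) instead of wrapping.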
void save(int8_t* ptr) const { vst1q_s8(ptr, reg); };
void save(int8_t* ptr, const int elem_num) const {
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg);
int remainder = elem_num % NUM_ELEMENTS_REG(reg);
for (int i = 0; i < full_blocks; i++)
vst1q_s8(reinterpret_cast<int8_t*>(ptr) + NUM_ELEMENTS_REG(reg) * i, reg);
if (remainder > 0) {
int8x16_t temp = reg;
int8_t* base =
reinterpret_cast<int8_t*>(ptr) + full_blocks * NUM_ELEMENTS_REG(reg);
if (remainder > 0) base[0] = vgetq_lane_s8(temp, 0);
if (remainder > 1) base[1] = vgetq_lane_s8(temp, 1);
if (remainder > 2) base[2] = vgetq_lane_s8(temp, 2);
if (remainder > 3) base[3] = vgetq_lane_s8(temp, 3);
if (remainder > 4) base[4] = vgetq_lane_s8(temp, 4);
if (remainder > 5) base[5] = vgetq_lane_s8(temp, 5);
if (remainder > 6) base[6] = vgetq_lane_s8(temp, 6);
if (remainder > 7) base[7] = vgetq_lane_s8(temp, 7);
if (remainder > 8) base[8] = vgetq_lane_s8(temp, 8);
if (remainder > 9) base[9] = vgetq_lane_s8(temp, 9);
if (remainder > 10) base[10] = vgetq_lane_s8(temp, 10);
if (remainder > 11) base[11] = vgetq_lane_s8(temp, 11);
if (remainder > 12) base[12] = vgetq_lane_s8(temp, 12);
if (remainder > 13) base[13] = vgetq_lane_s8(temp, 13);
if (remainder > 14) base[14] = vgetq_lane_s8(temp, 14);
}
};
};
struct INT8Vec64 : public Vec<INT8Vec64> {
constexpr static int VEC_ELEM_NUM = 64;
union AliasReg {
int8x16x4_t reg;
int8_t values[VEC_ELEM_NUM];
};
int8x16x4_t reg;
explicit INT8Vec64(const int8_t* ptr) { reg = vld1q_s8_x4(ptr); }
// ASIMD does not support non-temporal loads
explicit INT8Vec64(bool, const int8_t* ptr) : INT8Vec64(ptr) {}
void save(int8_t* ptr) const { vst1q_s8_x4(ptr, reg); }
// masked store
void save(int8_t* p, int elem_num) const {
TORCH_CHECK(elem_num <= VEC_ELEM_NUM && elem_num > 0);
if (elem_num == VEC_ELEM_NUM) {
vst1q_s8_x4(p, reg);
return;
}
const int full_quadwords = elem_num / 16;
const int remaining_bytes = elem_num % 16;
for (int i = 0; i < full_quadwords; ++i) {
vst1q_s8(p + 16 * i, reg.val[i]);
}
if (remaining_bytes) {
const int8x16_t v = reg.val[full_quadwords];
int8_t* tail = p + 16 * full_quadwords;
switch (remaining_bytes) {
case 15:
tail[14] = vgetq_lane_s8(v, 14);
[[fallthrough]];
case 14:
tail[13] = vgetq_lane_s8(v, 13);
[[fallthrough]];
case 13:
tail[12] = vgetq_lane_s8(v, 12);
[[fallthrough]];
case 12:
tail[11] = vgetq_lane_s8(v, 11);
[[fallthrough]];
case 11:
tail[10] = vgetq_lane_s8(v, 10);
[[fallthrough]];
case 10:
tail[9] = vgetq_lane_s8(v, 9);
[[fallthrough]];
case 9:
tail[8] = vgetq_lane_s8(v, 8);
[[fallthrough]];
case 8:
tail[7] = vgetq_lane_s8(v, 7);
[[fallthrough]];
case 7:
tail[6] = vgetq_lane_s8(v, 6);
[[fallthrough]];
case 6:
tail[5] = vgetq_lane_s8(v, 5);
[[fallthrough]];
case 5:
tail[4] = vgetq_lane_s8(v, 4);
[[fallthrough]];
case 4:
tail[3] = vgetq_lane_s8(v, 3);
[[fallthrough]];
case 3:
tail[2] = vgetq_lane_s8(v, 2);
[[fallthrough]];
case 2:
tail[1] = vgetq_lane_s8(v, 1);
[[fallthrough]];
case 1:
tail[0] = vgetq_lane_s8(v, 0);
break;
default:
break;
}
}
}
// ASIMD does not support non-temporal stores
void nt_save(int8_t* ptr) const { save(ptr); }
}; // INT8Vec64
struct INT32Vec16 : public Vec<INT32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
int32x4x4_t reg;
int32_t values[VEC_ELEM_NUM];
};
int32x4x4_t reg;
explicit INT32Vec16(const void* ptr) {
reg.val[0] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr));
reg.val[1] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 4);
reg.val[2] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 8);
reg.val[3] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 12);
}
void save(int32_t* ptr) const {
vst1q_s32(ptr, reg.val[0]);
vst1q_s32(ptr + 4, reg.val[1]);
vst1q_s32(ptr + 8, reg.val[2]);
vst1q_s32(ptr + 12, reg.val[3]);
};
void save(int32_t* ptr, const int elem_num) const {
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
for (int i = 0; i < full_blocks; i++)
      vst1q_s32(
          reinterpret_cast<int32_t*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
          reg.val[i]);
if (remainder > 0) {
int32x4_t temp = reg.val[full_blocks];
int32_t* base = reinterpret_cast<int32_t*>(ptr) + full_blocks * 4;
if (remainder > 0) base[0] = vgetq_lane_s32(temp, 0);
if (remainder > 1) base[1] = vgetq_lane_s32(temp, 1);
if (remainder > 2) base[2] = vgetq_lane_s32(temp, 2);
if (remainder > 3) base[3] = vgetq_lane_s32(temp, 3);
}
}
};
template <typename T>
void storeFP32(float v, T* ptr) {
*ptr = v;
}
template <>
inline void storeFP32<c10::Half>(float v, c10::Half* ptr) {
*reinterpret_cast<__fp16*>(ptr) = v;
}
inline FP16Vec8::FP16Vec8(const FP32Vec8& v) {
reg.val[0] = convert_float_half(v.reg.val[0], v.reg.val[1]);
};
inline FP16Vec16::FP16Vec16(const FP32Vec16& v) {
reg.val[0] = convert_float_half(v.reg.val[0], v.reg.val[1]);
reg.val[1] = convert_float_half(v.reg.val[2], v.reg.val[3]);
};
inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
  // ATen's fmadd is pure and returns a * b + c, so the result must be
  // written back into the accumulator registers.
  acc.reg.val[0] = fmadd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
  acc.reg.val[1] = fmadd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
  acc.reg.val[2] = fmadd(a.reg.val[2], b.reg.val[2], acc.reg.val[2]);
  acc.reg.val[3] = fmadd(a.reg.val[3], b.reg.val[3], acc.reg.val[3]);
};
inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
reg.val[0] = convert_float_bfloat16(v.reg.val[0], v.reg.val[1]);
};
inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
reg.val[0] = convert_float_bfloat16(v.reg.val[0], v.reg.val[1]);
reg.val[1] = convert_float_bfloat16(v.reg.val[2], v.reg.val[3]);
};
inline void fma(FP32Vec16& acc, BF16Vec32& a, BF16Vec32& b) {
Vectorized<float> a0_low, a0_high, a1_low, a1_high, b0_low, b0_high, b1_low,
b1_high;
std::tie(a0_low, a0_high) = convert_bfloat16_float(a.reg.val[0]);
std::tie(a1_low, a1_high) = convert_bfloat16_float(a.reg.val[1]);
std::tie(b0_low, b0_high) = convert_bfloat16_float(b.reg.val[0]);
std::tie(b1_low, b1_high) = convert_bfloat16_float(b.reg.val[1]);
  acc.reg.val[0] = fmadd(a0_low, b0_low, acc.reg.val[0]);
  acc.reg.val[1] = fmadd(a0_high, b0_high, acc.reg.val[1]);
  acc.reg.val[2] = fmadd(a1_low, b1_low, acc.reg.val[2]);
  acc.reg.val[3] = fmadd(a1_high, b1_high, acc.reg.val[3]);
};
template <>
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
#ifdef ARM_BF16_SUPPORT
*reinterpret_cast<__bf16*>(ptr) = vcvth_bf16_f32(v);
#else
*ptr = static_cast<c10::BFloat16>(v);
#endif
};
inline void prefetch(const void* addr) { __builtin_prefetch(addr, 0, 1); };
}; // namespace vec_op

// ---------------------------------------------------------------------------
// cpu_types_scalar.hpp: portable scalar fallback implementation
// ---------------------------------------------------------------------------
#include <cmath>
#include <cstdint>
#include <cstring>
#include <torch/all.h>
#include "float_convert.hpp"
namespace vec_op {
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#ifndef CPU_OP_GUARD
#define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME)
#else
#define CPU_KERNEL_GUARD_IN(NAME) \
std::cout << #NAME << " invoked." << std::endl;
#define CPU_KERNEL_GUARD_OUT(NAME) \
std::cout << #NAME << " exit." << std::endl;
#endif
#define FORCE_INLINE __attribute__((always_inline)) inline
typedef struct f16x8_t {
uint16_t val[8];
} f16x8_t;
typedef struct f16x16_t {
uint16_t val[16];
} f16x16_t;
typedef struct f16x32_t {
uint16_t val[32];
} f16x32_t;
typedef struct f32x4_t {
float val[4];
} f32x4_t;
typedef struct f32x8_t {
float val[8];
} f32x8_t;
typedef struct f32x16_t {
float val[16];
} f32x16_t;
namespace {
template <typename T, T... indexes, typename F>
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
(f(std::integral_constant<T, indexes>{}), ...);
};
}; // namespace
template <typename T, T count, typename F,
typename = std::enable_if_t<std::is_invocable_v<F, T> > >
constexpr void unroll_loop(F&& f) {
unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
}
template <typename T>
struct Vec {
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
};
struct FP32Vec8;
struct FP32Vec16;
struct FP16Vec8 : public Vec<FP16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
f16x8_t reg;
explicit FP16Vec8(const void* ptr)
: reg(*reinterpret_cast<const f16x8_t*>(ptr)) {};
explicit FP16Vec8(const FP32Vec8&);
void save(void* ptr) const { *reinterpret_cast<f16x8_t*>(ptr) = reg; }
};
struct FP16Vec16 : public Vec<FP16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
f16x16_t reg;
explicit FP16Vec16(const void* ptr)
: reg(*reinterpret_cast<const f16x16_t*>(ptr)) {};
explicit FP16Vec16(const FP32Vec16&);
void save(void* ptr) const { *reinterpret_cast<f16x16_t*>(ptr) = reg; }
void save(void* ptr, const int elem_num) const {
int num = std::min(elem_num, VEC_ELEM_NUM);
std::memcpy(ptr, &(reg.val[0]), num * sizeof(uint16_t));
}
};
struct BF16Vec8 : public Vec<BF16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
f16x8_t reg;
explicit BF16Vec8(const void* ptr)
: reg(*reinterpret_cast<const f16x8_t*>(ptr)) {};
explicit BF16Vec8(const FP32Vec8&);
void save(void* ptr) const { *reinterpret_cast<f16x8_t*>(ptr) = reg; }
};
struct BF16Vec16 : public Vec<BF16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
f16x16_t reg;
explicit BF16Vec16(const void* ptr)
: reg(*reinterpret_cast<const f16x16_t*>(ptr)) {};
explicit BF16Vec16(const FP32Vec16&);
void save(void* ptr) const { *reinterpret_cast<f16x16_t*>(ptr) = reg; }
void save(void* ptr, const int elem_num) const {
int num = std::min(elem_num, VEC_ELEM_NUM);
std::memcpy(ptr, &(reg.val[0]), num * sizeof(uint16_t));
}
};
struct BF16Vec32 : public Vec<BF16Vec32> {
constexpr static int VEC_ELEM_NUM = 32;
f16x32_t reg;
explicit BF16Vec32(const void* ptr)
: reg(*reinterpret_cast<const f16x32_t*>(ptr)) {};
explicit BF16Vec32(f16x32_t data) : reg(data) {};
explicit BF16Vec32(BF16Vec8& vec8_data) {
unroll_loop<int, VEC_ELEM_NUM>([&vec8_data, this](int i) {
reg.val[i] = vec8_data.reg.val[i % BF16Vec8::VEC_ELEM_NUM];
});
}
void save(void* ptr) const { *reinterpret_cast<f16x32_t*>(ptr) = reg; }
};
struct FP32Vec4 : public Vec<FP32Vec4> {
constexpr static int VEC_ELEM_NUM = 4;
f32x4_t reg;
explicit FP32Vec4(float v) {
unroll_loop<int, VEC_ELEM_NUM>([&v, this](int i) { reg.val[i] = v; });
}
explicit FP32Vec4() {
unroll_loop<int, VEC_ELEM_NUM>([this](int i) { reg.val[i] = 0.0f; });
}
explicit FP32Vec4(const float* ptr)
: reg(*reinterpret_cast<const f32x4_t*>(ptr)) {};
explicit FP32Vec4(f32x4_t data) : reg(data) {};
explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {};
};
struct FP32Vec8 : public Vec<FP32Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
f32x8_t reg;
explicit FP32Vec8(float v) {
unroll_loop<int, VEC_ELEM_NUM>([&v, this](int i) { reg.val[i] = v; });
}
explicit FP32Vec8() {
unroll_loop<int, VEC_ELEM_NUM>([this](int i) { reg.val[i] = 0.0f; });
}
explicit FP32Vec8(const float* ptr)
: reg(*reinterpret_cast<const f32x8_t*>(ptr)) {};
explicit FP32Vec8(f32x8_t data) : reg(data) {};
explicit FP32Vec8(const FP32Vec8& data) : reg(data.reg) {};
explicit FP32Vec8(const FP16Vec8& v) {
unroll_loop<int, VEC_ELEM_NUM>(
[&v, this](int i) { reg.val[i] = fp16_to_float(v.reg.val[i]); });
}
FP32Vec8(const BF16Vec8& v) {
unroll_loop<int, VEC_ELEM_NUM>(
[&v, this](int i) { reg.val[i] = bf16_to_float(v.reg.val[i]); });
}
float reduce_sum() const {
float result = 0;
unroll_loop<int, VEC_ELEM_NUM>(
[&result, this](int i) { result += reg.val[i]; });
return result;
}
FP32Vec8 exp() const {
f32x8_t ret;
unroll_loop<int, VEC_ELEM_NUM>(
[&ret, this](int i) { ret.val[i] = expf(reg.val[i]); });
return FP32Vec8(ret);
}
FP32Vec8 tanh() const {
f32x8_t ret;
unroll_loop<int, VEC_ELEM_NUM>(
[&ret, this](int i) { ret.val[i] = tanhf(reg.val[i]); });
return FP32Vec8(ret);
}
FP32Vec8 er() const {
f32x8_t ret;
unroll_loop<int, VEC_ELEM_NUM>(
[&ret, this](int i) { ret.val[i] = erf(reg.val[i]); });
return FP32Vec8(ret);
}
FP32Vec8 operator*(const FP32Vec8& b) const {
f32x8_t ret;
unroll_loop<int, VEC_ELEM_NUM>(
[&ret, &b, this](int i) { ret.val[i] = reg.val[i] * b.reg.val[i]; });
return FP32Vec8(ret);
}
FP32Vec8 operator+(const FP32Vec8& b) const {
f32x8_t ret;
unroll_loop<int, VEC_ELEM_NUM>(
[&ret, &b, this](int i) { ret.val[i] = reg.val[i] + b.reg.val[i]; });
return FP32Vec8(ret);
}
FP32Vec8 operator-(const FP32Vec8& b) const {
f32x8_t ret;
unroll_loop<int, VEC_ELEM_NUM>(
[&ret, &b, this](int i) { ret.val[i] = reg.val[i] - b.reg.val[i]; });
return FP32Vec8(ret);
}
FP32Vec8 operator/(const FP32Vec8& b) const {
f32x8_t ret;
unroll_loop<int, VEC_ELEM_NUM>(
[&ret, &b, this](int i) { ret.val[i] = reg.val[i] / b.reg.val[i]; });
return FP32Vec8(ret);
}
void save(void* ptr) const { *reinterpret_cast<f32x8_t*>(ptr) = reg; }
};
struct FP32Vec16 : public Vec<FP32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
f32x16_t reg;
explicit FP32Vec16(float v) {
unroll_loop<int, VEC_ELEM_NUM>([&v, this](int i) { reg.val[i] = v; });
}
explicit FP32Vec16() {
unroll_loop<int, VEC_ELEM_NUM>([this](int i) { reg.val[i] = 0.0f; });
}
explicit FP32Vec16(const float* ptr)
: reg(*reinterpret_cast<const f32x16_t*>(ptr)) {};
explicit FP32Vec16(f32x16_t data) : reg(data) {};
FP32Vec16(const FP32Vec4& data) {
unroll_loop<int, VEC_ELEM_NUM>([&data, this](int i) {
reg.val[i] = data.reg.val[i % FP32Vec4::VEC_ELEM_NUM];
});
}
FP32Vec16(const FP32Vec8& data) {
unroll_loop<int, VEC_ELEM_NUM>([&data, this](int i) {
reg.val[i] = data.reg.val[i % FP32Vec8::VEC_ELEM_NUM];
});
}
FP32Vec16(const FP32Vec16& data) : reg(data.reg) {};
explicit FP32Vec16(const FP16Vec16& v) {
unroll_loop<int, VEC_ELEM_NUM>(
[&v, this](int i) { reg.val[i] = fp16_to_float(v.reg.val[i]); });
}
explicit FP32Vec16(const BF16Vec16& v) {
unroll_loop<int, VEC_ELEM_NUM>(
[&v, this](int i) { reg.val[i] = bf16_to_float(v.reg.val[i]); });
}
explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v)) {};
FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {};
FP32Vec16 operator*(const FP32Vec16& b) const {
f32x16_t ret;
unroll_loop<int, VEC_ELEM_NUM>(
[&ret, &b, this](int i) { ret.val[i] = reg.val[i] * b.reg.val[i]; });
return FP32Vec16(ret);
}
FP32Vec16 operator+(const FP32Vec16& b) const {
f32x16_t ret;
unroll_loop<int, VEC_ELEM_NUM>(
[&ret, &b, this](int i) { ret.val[i] = reg.val[i] + b.reg.val[i]; });
return FP32Vec16(ret);
}
FP32Vec16 operator-(const FP32Vec16& b) const {
f32x16_t ret;
unroll_loop<int, VEC_ELEM_NUM>(
[&ret, &b, this](int i) { ret.val[i] = reg.val[i] - b.reg.val[i]; });
return FP32Vec16(ret);
}
FP32Vec16 operator/(const FP32Vec16& b) const {
f32x16_t ret;
unroll_loop<int, VEC_ELEM_NUM>(
[&ret, &b, this](int i) { ret.val[i] = reg.val[i] / b.reg.val[i]; });
return FP32Vec16(ret);
}
FP32Vec16 max(const FP32Vec16& b) const {
f32x16_t ret;
unroll_loop<int, VEC_ELEM_NUM>([&ret, &b, this](int i) {
ret.val[i] = std::max(reg.val[i], b.reg.val[i]);
});
return FP32Vec16(ret);
}
FP32Vec16 min(const FP32Vec16& b) const {
f32x16_t ret;
unroll_loop<int, VEC_ELEM_NUM>([&ret, &b, this](int i) {
ret.val[i] = std::min(reg.val[i], b.reg.val[i]);
});
return FP32Vec16(ret);
}
FP32Vec16 abs() const {
f32x16_t ret;
unroll_loop<int, VEC_ELEM_NUM>(
[&ret, this](int i) { ret.val[i] = std::abs(reg.val[i]); });
return FP32Vec16(ret);
}
float reduce_sum() const {
float result = 0.0f;
unroll_loop<int, VEC_ELEM_NUM>(
[&result, this](int i) { result += reg.val[i]; });
return result;
}
float reduce_max() const {
float result = std::numeric_limits<float>::lowest();
unroll_loop<int, VEC_ELEM_NUM>(
[&result, this](int i) { result = std::max(reg.val[i], result); });
return result;
}
float reduce_min() const {
float result = std::numeric_limits<float>::max();
unroll_loop<int, VEC_ELEM_NUM>(
[&result, this](int i) { result = std::min(reg.val[i], result); });
return result;
}
template <int group_size>
float reduce_sub_sum(int idx) {
static_assert(VEC_ELEM_NUM % group_size == 0);
float sum = 0.0;
const int start = idx * group_size;
unroll_loop<int, group_size>(
[&sum, &start, this](int i) { sum += reg.val[start + i]; });
return sum;
}
void save(void* ptr) const { *reinterpret_cast<f32x16_t*>(ptr) = reg; }
};
template <typename T>
struct VecType {
using vec_type = void;
};
template <typename T>
using vec_t = typename VecType<T>::vec_type;
template <>
struct VecType<float> {
using vec_type = FP32Vec8;
};
template <>
struct VecType<c10::Half> {
using vec_type = FP16Vec8;
};
template <>
struct VecType<c10::BFloat16> {
using vec_type = BF16Vec8;
};
template <typename T>
void storeFP32(float v, T* ptr) {
*ptr = v;
}
/*
template <> inline void storeFP32<c10::Half>(float v, c10::Half *ptr) {
c10::Half __attribute__((__may_alias__)) *v_ptr =
reinterpret_cast<c10::Half *>(&v);
*ptr = *(v_ptr + 1);
}
*/
template <>
inline void storeFP32<c10::Half>(float v, c10::Half* ptr) {
uint16_t fp16 = float_to_fp16(v);
*reinterpret_cast<uint16_t*>(ptr) = fp16;
}
template <>
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
c10::BFloat16 __attribute__((__may_alias__))* v_ptr =
reinterpret_cast<c10::BFloat16*>(&v);
*ptr = *(v_ptr + 1);
}
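// The alias trick above relies on little-endian layout: bytes 2..3 of the
// float v are its upper half, which is exactly the truncated bfloat16 bit
// pattern. Note this truncates rather than rounding to nearest even.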
inline FP16Vec16::FP16Vec16(const FP32Vec16& v) {
unroll_loop<int, FP16Vec16::VEC_ELEM_NUM>(
[&v, this](int i) { reg.val[i] = float_to_fp16(v.reg.val[i]); });
}
inline FP16Vec8::FP16Vec8(const FP32Vec8& v) {
unroll_loop<int, FP16Vec8::VEC_ELEM_NUM>(
[&v, this](int i) { reg.val[i] = float_to_fp16(v.reg.val[i]); });
}
inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
acc = acc + a * b;
}
inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
unroll_loop<int, BF16Vec8::VEC_ELEM_NUM>(
[&v, this](int i) { reg.val[i] = float_to_bf16(v.reg.val[i]); });
}
inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
unroll_loop<int, BF16Vec16::VEC_ELEM_NUM>(
[&v, this](int i) { reg.val[i] = float_to_bf16(v.reg.val[i]); });
}
inline void prefetch(const void* addr) { __builtin_prefetch(addr, 0, 3); }
};  // namespace vec_op

// ---------------------------------------------------------------------------
// cpu_types_vsx.hpp: PowerPC VSX implementation
// ---------------------------------------------------------------------------
#ifndef CPU_TYPES_VSX_HPP
#define CPU_TYPES_VSX_HPP
#include <altivec.h>
#include <cmath>
#include <algorithm>
#include <torch/all.h>
namespace vec_op {
// FIXME: FP16 is not fully supported in Torch-CPU
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#ifndef CPU_OP_GUARD
#define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME)
#else
#define CPU_KERNEL_GUARD_IN(NAME) \
std::cout << #NAME << " invoked." << std::endl;
#define CPU_KERNEL_GUARD_OUT(NAME) \
std::cout << #NAME << " exit." << std::endl;
#endif
#define FORCE_INLINE __attribute__((always_inline)) inline
namespace {
template <typename T, T... indexes, typename F>
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
(f(std::integral_constant<T, indexes>{}), ...);
}
}; // namespace
template <typename T, T count, typename F,
typename = std::enable_if_t<std::is_invocable_v<F, T>>>
constexpr void unroll_loop(F&& f) {
unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
}
template <typename T>
struct Vec {
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
};
typedef struct ss16x8x2_t {
__vector signed short val[2];
} ss16x8x2_t;
typedef struct ss16x8x4_t {
__vector signed short val[4];
} ss16x8x4_t;
typedef struct f32x4x2_t {
__vector float val[2];
} f32x4x2_t;
typedef struct f32x4x4_t {
__vector float val[4];
} f32x4x4_t;
typedef struct i32x4x4_t {
__vector int32_t val[4];
} i32x4x4_t;
struct FP32Vec8;
struct FP32Vec16;
struct BF16Vec8 : public Vec<BF16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
__vector signed short reg;
explicit BF16Vec8(const void* ptr)
: reg((__vector signed short)vec_xl(0, (__vector signed short*)ptr)) {}
explicit BF16Vec8(const FP32Vec8&);
void save(void* ptr) const {
*reinterpret_cast<__vector signed short*>(ptr) = reg;
}
};
struct BF16Vec16 : public Vec<BF16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
ss16x8x2_t reg;
explicit BF16Vec16(const void* ptr) {
// Load 256 bits in two parts
reg.val[0] = (__vector signed short)vec_xl(0, (signed short*)ptr);
reg.val[1] = (__vector signed short)vec_xl(16, (signed short*)ptr);
}
explicit BF16Vec16(const FP32Vec16&);
void save(void* ptr) const {
// Save 256 bits in two parts
vec_xst(reg.val[0], 0, (signed short*)ptr);
vec_xst(reg.val[1], 16, (signed short*)ptr);
}
void save(void* ptr, const int elem_num) const {
const int clamped_elem = std::max(0, std::min(elem_num, 16));
// Calculate elements to store in each 128-bit part (8 elements each)
const int elements_val0 = std::min(clamped_elem, 8);
const int elements_val1 = std::max(clamped_elem - 8, 0);
// Convert elements to bytes (2 bytes per element)
const size_t bytes_val0 = elements_val0 * sizeof(signed short);
const size_t bytes_val1 = elements_val1 * sizeof(signed short);
signed short* dest = static_cast<signed short*>(ptr);
// Store the first part using vec_xst_len
if (bytes_val0 > 0) {
vec_xst_len(reg.val[0], dest, bytes_val0);
}
// Store the second part if needed
if (bytes_val1 > 0) {
vec_xst_len(reg.val[1], dest + elements_val0, bytes_val1);
}
}
};
const static __vector signed short zero = vec_splats((signed short)0);
struct BF16Vec32 : public Vec<BF16Vec32> {
constexpr static int VEC_ELEM_NUM = 32;
ss16x8x4_t reg;
explicit BF16Vec32(const void* ptr)
: reg(*reinterpret_cast<const ss16x8x4_t*>(ptr)) {}
explicit BF16Vec32(ss16x8x4_t data) : reg(data) {}
explicit BF16Vec32(const BF16Vec8& vec8_data)
: reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {}
void save(void* ptr) const { *reinterpret_cast<ss16x8x4_t*>(ptr) = reg; }
};
struct FP32Vec4 : public Vec<FP32Vec4> {
constexpr static int VEC_ELEM_NUM = 4;
union AliasReg {
__vector float reg;
float values[VEC_ELEM_NUM];
};
__vector float reg;
explicit FP32Vec4(float v) : reg(vec_splats(v)) {}
explicit FP32Vec4() : reg(vec_splats(0.0f)) {}
explicit FP32Vec4(const float* ptr) : reg(vec_xl(0, ptr)) {}
explicit FP32Vec4(__vector float data) : reg(data) {}
explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {}
};
struct FP32Vec8 : public Vec<FP32Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
union AliasReg {
f32x4x2_t reg;
float values[VEC_ELEM_NUM];
};
f32x4x2_t reg;
explicit FP32Vec8(float v) {
reg.val[0] = vec_splats(v);
reg.val[1] = vec_splats(v);
}
explicit FP32Vec8() {
reg.val[0] = vec_splats(0.0f);
reg.val[1] = vec_splats(0.0f);
}
explicit FP32Vec8(const float* ptr) {
reg.val[0] = vec_xl(0, ptr);
reg.val[1] = vec_xl(16, ptr);
}
explicit FP32Vec8(f32x4x2_t data) : reg(data) {}
explicit FP32Vec8(const FP32Vec8& data) {
reg.val[0] = data.reg.val[0];
reg.val[1] = data.reg.val[1];
}
explicit FP32Vec8(const BF16Vec8& v) {
reg.val[0] = (__vector float)vec_mergeh(zero, v.reg);
reg.val[1] = (__vector float)vec_mergel(zero, v.reg);
}
float reduce_sum() const {
AliasReg ar;
ar.reg = reg;
float result = 0;
unroll_loop<int, VEC_ELEM_NUM>(
[&result, &ar](int i) { result += ar.values[i]; });
return result;
}
FP32Vec8 exp() const {
// TODO: Vectorize this
AliasReg ar;
ar.reg = reg;
f32x4x4_t ret;
ret.val[0][0] = std::exp(ar.values[0]);
ret.val[0][1] = std::exp(ar.values[1]);
ret.val[0][2] = std::exp(ar.values[2]);
ret.val[0][3] = std::exp(ar.values[3]);
ret.val[1][0] = std::exp(ar.values[4]);
ret.val[1][1] = std::exp(ar.values[5]);
ret.val[1][2] = std::exp(ar.values[6]);
ret.val[1][3] = std::exp(ar.values[7]);
return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
}
FP32Vec8 tanh() const {
// TODO: Vectorize this
AliasReg ar;
ar.reg = reg;
f32x4x4_t ret;
ret.val[0][0] = std::tanh(ar.values[0]);
ret.val[0][1] = std::tanh(ar.values[1]);
ret.val[0][2] = std::tanh(ar.values[2]);
ret.val[0][3] = std::tanh(ar.values[3]);
ret.val[1][0] = std::tanh(ar.values[4]);
ret.val[1][1] = std::tanh(ar.values[5]);
ret.val[1][2] = std::tanh(ar.values[6]);
ret.val[1][3] = std::tanh(ar.values[7]);
return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
}
FP32Vec8 er() const {
// TODO: Vectorize this
AliasReg ar;
ar.reg = reg;
f32x4x4_t ret;
ret.val[0][0] = std::erf(ar.values[0]);
ret.val[0][1] = std::erf(ar.values[1]);
ret.val[0][2] = std::erf(ar.values[2]);
ret.val[0][3] = std::erf(ar.values[3]);
ret.val[1][0] = std::erf(ar.values[4]);
ret.val[1][1] = std::erf(ar.values[5]);
ret.val[1][2] = std::erf(ar.values[6]);
ret.val[1][3] = std::erf(ar.values[7]);
return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
}
FP32Vec8 operator*(const FP32Vec8& b) const {
return FP32Vec8(
{vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])});
}
FP32Vec8 operator+(const FP32Vec8& b) const {
return FP32Vec8(
{vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])});
}
FP32Vec8 operator-(const FP32Vec8& b) const {
return FP32Vec8(
{vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])});
}
FP32Vec8 operator/(const FP32Vec8& b) const {
return FP32Vec8(
{vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])});
}
void save(float* ptr) const {
vec_xst(reg.val[0], 0, ptr);
vec_xst(reg.val[1], 16, ptr);
}
};
struct INT32Vec16 : public Vec<INT32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
i32x4x4_t reg;
int32_t values[VEC_ELEM_NUM];
};
i32x4x4_t reg;
explicit INT32Vec16(const void* data_ptr) {
reg.val[0] = vec_xl(0, reinterpret_cast<const __vector int32_t*>(data_ptr));
reg.val[1] =
vec_xl(16, reinterpret_cast<const __vector int32_t*>(data_ptr));
reg.val[2] =
vec_xl(32, reinterpret_cast<const __vector int32_t*>(data_ptr));
reg.val[3] =
vec_xl(48, reinterpret_cast<const __vector int32_t*>(data_ptr));
}
void save(int32_t* ptr) const {
vec_xst(reg.val[0], 0, reinterpret_cast<__vector int32_t*>(ptr));
vec_xst(reg.val[1], 16, reinterpret_cast<__vector int32_t*>(ptr));
vec_xst(reg.val[2], 32, reinterpret_cast<__vector int32_t*>(ptr));
vec_xst(reg.val[3], 48, reinterpret_cast<__vector int32_t*>(ptr));
}
void save(int32_t* ptr, const int elem_num) const {
const int elements_in_chunk1 =
(elem_num >= 0) ? ((elem_num >= 4) ? 4 : elem_num) : 0;
const int elements_in_chunk2 =
(elem_num > 4) ? ((elem_num >= 8) ? 4 : elem_num - 4) : 0;
const int elements_in_chunk3 =
(elem_num > 8) ? ((elem_num >= 12) ? 4 : elem_num - 8) : 0;
const int elements_in_chunk4 =
(elem_num > 12) ? ((elem_num >= 16) ? 4 : elem_num - 12) : 0;
const size_t bytes_chunk1 =
static_cast<size_t>(elements_in_chunk1 * sizeof(int32_t));
const size_t bytes_chunk2 =
static_cast<size_t>(elements_in_chunk2 * sizeof(int32_t));
const size_t bytes_chunk3 =
static_cast<size_t>(elements_in_chunk3 * sizeof(int32_t));
const size_t bytes_chunk4 =
static_cast<size_t>(elements_in_chunk4 * sizeof(int32_t));
vec_xst_len(reg.val[0], reinterpret_cast<int32_t*>(ptr), bytes_chunk1);
vec_xst_len(reg.val[1],
reinterpret_cast<int32_t*>(reinterpret_cast<char*>(ptr) + 16),
bytes_chunk2);
vec_xst_len(reg.val[2],
reinterpret_cast<int32_t*>(reinterpret_cast<char*>(ptr) + 32),
bytes_chunk3);
vec_xst_len(reg.val[3],
reinterpret_cast<int32_t*>(reinterpret_cast<char*>(ptr) + 48),
bytes_chunk4);
}
};
struct FP32Vec16 : public Vec<FP32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
f32x4x4_t reg;
float values[VEC_ELEM_NUM];
};
f32x4x4_t reg;
explicit FP32Vec16(float v) {
reg.val[0] = vec_splats(v);
reg.val[1] = vec_splats(v);
reg.val[2] = vec_splats(v);
reg.val[3] = vec_splats(v);
}
explicit FP32Vec16() {
reg.val[0] = vec_splats(0.0f);
reg.val[1] = vec_splats(0.0f);
reg.val[2] = vec_splats(0.0f);
reg.val[3] = vec_splats(0.0f);
}
explicit FP32Vec16(const float* ptr) {
reg.val[0] = vec_xl(0, ptr);
reg.val[1] = vec_xl(16, ptr);
reg.val[2] = vec_xl(32, ptr);
reg.val[3] = vec_xl(48, ptr);
}
explicit FP32Vec16(f32x4x4_t data) : reg(data) {}
explicit FP32Vec16(const FP32Vec16& data) {
reg.val[0] = data.reg.val[0];
reg.val[1] = data.reg.val[1];
reg.val[2] = data.reg.val[2];
reg.val[3] = data.reg.val[3];
}
explicit FP32Vec16(const FP32Vec4& data) {
reg.val[0] = data.reg;
reg.val[1] = data.reg;
reg.val[2] = data.reg;
reg.val[3] = data.reg;
}
explicit FP32Vec16(const FP32Vec8& data) {
reg.val[0] = data.reg.val[0];
reg.val[1] = data.reg.val[1];
reg.val[2] = data.reg.val[0];
reg.val[3] = data.reg.val[1];
}
explicit FP32Vec16(const BF16Vec16& v) {
reg.val[0] = (__vector float)vec_mergeh(zero, v.reg.val[0]);
reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]);
reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]);
reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]);
}
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
explicit FP32Vec16(const INT32Vec16& v) {
reg.val[0] = vec_ctf(v.reg.val[0], 0);
reg.val[1] = vec_ctf(v.reg.val[1], 0);
reg.val[2] = vec_ctf(v.reg.val[2], 0);
reg.val[3] = vec_ctf(v.reg.val[3], 0);
}
FP32Vec16 operator*(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]),
vec_mul(reg.val[1], b.reg.val[1]),
vec_mul(reg.val[2], b.reg.val[2]),
vec_mul(reg.val[3], b.reg.val[3])}));
}
FP32Vec16 operator+(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_add(reg.val[0], b.reg.val[0]),
vec_add(reg.val[1], b.reg.val[1]),
vec_add(reg.val[2], b.reg.val[2]),
vec_add(reg.val[3], b.reg.val[3])}));
}
FP32Vec16 operator-(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_sub(reg.val[0], b.reg.val[0]),
vec_sub(reg.val[1], b.reg.val[1]),
vec_sub(reg.val[2], b.reg.val[2]),
vec_sub(reg.val[3], b.reg.val[3])}));
}
FP32Vec16 operator/(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_div(reg.val[0], b.reg.val[0]),
vec_div(reg.val[1], b.reg.val[1]),
vec_div(reg.val[2], b.reg.val[2]),
vec_div(reg.val[3], b.reg.val[3])}));
}
FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const {
return FP32Vec16(f32x4x4_t(
{vec_min(max.reg.val[0], vec_max(min.reg.val[0], reg.val[0])),
vec_min(max.reg.val[1], vec_max(min.reg.val[1], reg.val[1])),
vec_min(max.reg.val[2], vec_max(min.reg.val[2], reg.val[2])),
vec_min(max.reg.val[3], vec_max(min.reg.val[3], reg.val[3]))}));
}
FP32Vec16 max(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_max(reg.val[0], b.reg.val[0]),
vec_max(reg.val[1], b.reg.val[1]),
vec_max(reg.val[2], b.reg.val[2]),
vec_max(reg.val[3], b.reg.val[3])}));
}
FP32Vec16 max(const FP32Vec16& b, int elem_num) const {
FP32Vec16 result;
// Create a vector of element indices for each chunk
__vector unsigned int indices = {0, 1, 2, 3};
__vector unsigned int elem_num_vec =
vec_splats(static_cast<unsigned int>(elem_num));
    // Per-chunk lane index offsets (elements 0-3, 4-7, 8-11, 12-15)
__vector unsigned int chunk_offset0 = {0, 0, 0,
0}; // Chunk 0: Elements 0-3
__vector unsigned int chunk_offset1 = {4, 4, 4,
4}; // Chunk 1: Elements 4-7
__vector unsigned int chunk_offset2 = {8, 8, 8,
8}; // Chunk 2: Elements 8-11
__vector unsigned int chunk_offset3 = {12, 12, 12,
12}; // Chunk 3: Elements 12-15
// Compute masks for each chunk
    __vector __bool int mask0 =
        vec_cmplt(indices + chunk_offset0, elem_num_vec);
    __vector __bool int mask1 =
        vec_cmplt(indices + chunk_offset1, elem_num_vec);
    __vector __bool int mask2 =
        vec_cmplt(indices + chunk_offset2, elem_num_vec);
    __vector __bool int mask3 =
        vec_cmplt(indices + chunk_offset3, elem_num_vec);
// Apply masks to compute the result for each chunk
result.reg.val[0] = vec_sel(this->reg.val[0],
vec_max(this->reg.val[0], b.reg.val[0]), mask0);
result.reg.val[1] = vec_sel(this->reg.val[1],
vec_max(this->reg.val[1], b.reg.val[1]), mask1);
result.reg.val[2] = vec_sel(this->reg.val[2],
vec_max(this->reg.val[2], b.reg.val[2]), mask2);
result.reg.val[3] = vec_sel(this->reg.val[3],
vec_max(this->reg.val[3], b.reg.val[3]), mask3);
return FP32Vec16(result.reg);
}
FP32Vec16 min(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_min(reg.val[0], b.reg.val[0]),
vec_min(reg.val[1], b.reg.val[1]),
vec_min(reg.val[2], b.reg.val[2]),
vec_min(reg.val[3], b.reg.val[3])}));
}
FP32Vec16 min(const FP32Vec16& b, int elem_num) const {
FP32Vec16 result;
    __vector unsigned int indices = {0, 1, 2, 3};
    __vector unsigned int elem_num_vec =
        vec_splats(static_cast<unsigned int>(elem_num));
    __vector unsigned int chunk_offset0 = {0, 0, 0, 0};
    __vector unsigned int chunk_offset1 = {4, 4, 4, 4};
    __vector unsigned int chunk_offset2 = {8, 8, 8, 8};
    __vector unsigned int chunk_offset3 = {12, 12, 12, 12};
    __vector __bool int mask0 =
        vec_cmplt(indices + chunk_offset0, elem_num_vec);
    __vector __bool int mask1 =
        vec_cmplt(indices + chunk_offset1, elem_num_vec);
    __vector __bool int mask2 =
        vec_cmplt(indices + chunk_offset2, elem_num_vec);
    __vector __bool int mask3 =
        vec_cmplt(indices + chunk_offset3, elem_num_vec);
result.reg.val[0] = vec_sel(this->reg.val[0],
vec_min(this->reg.val[0], b.reg.val[0]), mask0);
result.reg.val[1] = vec_sel(this->reg.val[1],
vec_min(this->reg.val[1], b.reg.val[1]), mask1);
result.reg.val[2] = vec_sel(this->reg.val[2],
vec_min(this->reg.val[2], b.reg.val[2]), mask2);
result.reg.val[3] = vec_sel(this->reg.val[3],
vec_min(this->reg.val[3], b.reg.val[3]), mask3);
return FP32Vec16(result.reg);
}
FP32Vec16 abs() const {
return FP32Vec16(f32x4x4_t({vec_abs(reg.val[0]), vec_abs(reg.val[1]),
vec_abs(reg.val[2]), vec_abs(reg.val[3])}));
}
  float reduce_max() const {
__vector float max01 = vec_max(reg.val[0], reg.val[1]);
__vector float max23 = vec_max(reg.val[2], reg.val[3]);
__vector float max_all = vec_max(max01, max23);
__vector float temp = vec_max(max_all, vec_sld(max_all, max_all, 8));
temp = vec_max(temp, vec_sld(temp, temp, 4));
return vec_extract(temp, 0);
}
  float reduce_min() const {
__vector float min01 = vec_min(reg.val[0], reg.val[1]);
__vector float min23 = vec_min(reg.val[2], reg.val[3]);
__vector float min_all = vec_min(min01, min23);
__vector float temp = vec_min(min_all, vec_sld(min_all, min_all, 8));
temp = vec_min(temp, vec_sld(temp, temp, 4));
return vec_extract(temp, 0);
}
float reduce_sum() const {
AliasReg ar;
ar.reg = reg;
float result = 0;
unroll_loop<int, VEC_ELEM_NUM>(
[&result, &ar](int i) { result += ar.values[i]; });
return result;
}
template <int group_size>
float reduce_sub_sum(int idx) {
static_assert(VEC_ELEM_NUM % group_size == 0);
AliasReg ar;
ar.reg = reg;
float result = 0;
const int start = idx * group_size;
unroll_loop<int, group_size>(
[&result, &start, ar](int i) { result += ar.values[start + i]; });
return result;
}
void save(float* ptr) const {
vec_xst(reg.val[0], 0, ptr);
vec_xst(reg.val[1], 16, ptr);
vec_xst(reg.val[2], 32, ptr);
vec_xst(reg.val[3], 48, ptr);
}
void save(float* ptr, const int elem_num) const {
const int elements_in_chunk1 =
(elem_num >= 0) ? ((elem_num >= 4) ? 4 : elem_num) : 0;
const int elements_in_chunk2 =
(elem_num > 4) ? ((elem_num >= 8) ? 4 : elem_num - 4) : 0;
const int elements_in_chunk3 =
(elem_num > 8) ? ((elem_num >= 12) ? 4 : elem_num - 8) : 0;
const int elements_in_chunk4 =
(elem_num > 12) ? ((elem_num >= 16) ? 4 : elem_num - 12) : 0;
const size_t bytes_chunk1 =
static_cast<size_t>(elements_in_chunk1 * sizeof(float));
const size_t bytes_chunk2 =
static_cast<size_t>(elements_in_chunk2 * sizeof(float));
const size_t bytes_chunk3 =
static_cast<size_t>(elements_in_chunk3 * sizeof(float));
const size_t bytes_chunk4 =
static_cast<size_t>(elements_in_chunk4 * sizeof(float));
vec_xst_len(reg.val[0], ptr, bytes_chunk1);
vec_xst_len(reg.val[1],
reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 16),
bytes_chunk2);
vec_xst_len(reg.val[2],
reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 32),
bytes_chunk3);
vec_xst_len(reg.val[3],
reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 48),
bytes_chunk4);
}
};
struct INT8Vec16 : public Vec<INT8Vec16> {
  // Named VEC_ELEM_NUM so Vec<T>::get_elem_num() finds it.
  constexpr static int VEC_ELEM_NUM = 16;  // 128 bits / 8 bits = 16
  union AliasReg {
    __vector signed char reg;
    int8_t values[VEC_ELEM_NUM];
};
__vector signed char reg;
explicit INT8Vec16(const FP32Vec16& vec) {
__vector signed int ret[4];
ret[0] = vec_cts(vec.reg.val[0], 0);
ret[1] = vec_cts(vec.reg.val[1], 0);
ret[2] = vec_cts(vec.reg.val[2], 0);
ret[3] = vec_cts(vec.reg.val[3], 0);
__vector signed short packed1 = vec_packs(ret[0], ret[1]);
__vector signed short packed2 = vec_packs(ret[2], ret[3]);
reg = vec_packs(packed1, packed2);
}
void save(void* ptr) const {
*reinterpret_cast<__vector signed char*>(ptr) = reg;
}
void save(signed char* ptr, const int elem_num) {
vec_xst_len(reg, ptr, static_cast<size_t>(elem_num));
}
};
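// Illustrative usage sketch (not part of the original file): quantize 16
// floats to int8 through the saturating conversion above. The caller is
// assumed to have applied any quantization scale beforehand, and the
// FP32Vec16 float* constructor is assumed from earlier in this file.
inline void fp32_to_int8_block_example(const float* in, signed char* out) {
  FP32Vec16 v(in);                       // load 16 floats
  INT8Vec16 q(v);                        // vec_cts truncates toward zero,
                                         // vec_packs saturates to [-128, 127]
  q.save(out, INT8Vec16::VEC_NUM_ELEM);  // contiguous 16-byte store
}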
template <typename T>
struct VecType {
using vec_type = void;
};
template <typename T>
using vec_t = typename VecType<T>::vec_type;
template <>
struct VecType<float> {
using vec_type = FP32Vec8;
};
template <>
struct VecType<c10::BFloat16> {
using vec_type = BF16Vec8;
};
template <typename T>
void storeFP32(float v, T* ptr) {
*ptr = v;
}
// Generic multiply-add fallback: computes acc + a * b with a separate
// multiply and add (not a fused vec_madd).
inline void fma(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) {
  acc = acc + a * b;
}
template <>
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
c10::BFloat16 __attribute__((__may_alias__))* v_ptr =
reinterpret_cast<c10::BFloat16*>(&v);
*ptr = *(v_ptr + 1);
}
#ifndef __VEC_CLASS_FP_NAN
#define __VEC_CLASS_FP_NAN (1 << 6)
#endif
const static __vector unsigned char omask = {0, 1, 4, 5, 8, 9, 12, 13,
16, 17, 20, 21, 24, 25, 28, 29};
#ifndef _ARCH_PWR10
const static __vector unsigned int bias = {0x00007fff, 0x00007fff, 0x00007fff,
0x00007fff};
const static __vector unsigned int nan = {0x7fc00000, 0x7fc00000, 0x7fc00000,
0x7fc00000};
const static __vector unsigned int sh16 = {16, 16, 16, 16};
const static __vector unsigned int one = {1, 1, 1, 1};
#endif
inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
#ifdef _ARCH_PWR10
__vector signed short ret[2];
ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16(
(__vector unsigned char)v.reg.val[0]);
ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16(
(__vector unsigned char)v.reg.val[1]);
reg = vec_perm(ret[0], ret[1], omask);
#elif defined(_ARCH_PWR9)
__vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
__vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]);
__vector unsigned int lsb0 = vec_sr(inp0, sh16);
__vector unsigned int lsb1 = vec_sr(inp1, sh16);
lsb0 = vec_and(lsb0, one);
lsb1 = vec_and(lsb1, one);
__vector unsigned int rnd0 = vec_add(lsb0, bias);
__vector unsigned int rnd1 = vec_add(lsb1, bias);
inp0 = vec_add(inp0, rnd0);
inp1 = vec_add(inp1, rnd1);
__vector __bool int sel0 =
vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN);
__vector __bool int sel1 =
vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN);
inp0 = vec_sel(inp0, nan, sel0);
inp1 = vec_sel(inp1, nan, sel1);
inp0 = vec_sr(inp0, sh16);
inp1 = vec_sr(inp1, sh16);
reg = (__vector signed short)vec_perm(inp0, inp1, omask);
#endif
}
inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
#ifdef _ARCH_PWR10
__vector signed short ret[4];
ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16(
(__vector unsigned char)v.reg.val[0]);
ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16(
(__vector unsigned char)v.reg.val[1]);
ret[2] = (__vector signed short)__builtin_vsx_xvcvspbf16(
(__vector unsigned char)v.reg.val[2]);
ret[3] = (__vector signed short)__builtin_vsx_xvcvspbf16(
(__vector unsigned char)v.reg.val[3]);
reg.val[0] = vec_perm(ret[0], ret[1], omask);
reg.val[1] = vec_perm(ret[2], ret[3], omask);
#elif defined(_ARCH_PWR9)
__vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
__vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]);
__vector unsigned int inp2 = (__vector unsigned int)(v.reg.val[2]);
__vector unsigned int inp3 = (__vector unsigned int)(v.reg.val[3]);
__vector unsigned int lsb0 = vec_sr(inp0, sh16);
__vector unsigned int lsb1 = vec_sr(inp1, sh16);
__vector unsigned int lsb2 = vec_sr(inp2, sh16);
__vector unsigned int lsb3 = vec_sr(inp3, sh16);
lsb0 = vec_and(lsb0, one);
lsb1 = vec_and(lsb1, one);
lsb2 = vec_and(lsb2, one);
lsb3 = vec_and(lsb3, one);
__vector unsigned int rnd0 = vec_add(lsb0, bias);
__vector unsigned int rnd1 = vec_add(lsb1, bias);
__vector unsigned int rnd2 = vec_add(lsb2, bias);
__vector unsigned int rnd3 = vec_add(lsb3, bias);
inp0 = vec_add(inp0, rnd0);
inp1 = vec_add(inp1, rnd1);
inp2 = vec_add(inp2, rnd2);
inp3 = vec_add(inp3, rnd3);
__vector __bool int sel0 =
vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN);
__vector __bool int sel1 =
vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN);
__vector __bool int sel2 =
vec_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN);
__vector __bool int sel3 =
vec_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN);
inp0 = vec_sel(inp0, nan, sel0);
inp1 = vec_sel(inp1, nan, sel1);
inp2 = vec_sel(inp2, nan, sel2);
inp3 = vec_sel(inp3, nan, sel3);
inp0 = vec_sr(inp0, sh16);
inp1 = vec_sr(inp1, sh16);
inp2 = vec_sr(inp2, sh16);
inp3 = vec_sr(inp3, sh16);
reg.val[0] = (__vector signed short)vec_perm(inp0, inp1, omask);
reg.val[1] = (__vector signed short)vec_perm(inp2, inp3, omask);
#endif
}
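// Illustrative round trip (assumption, not part of the original file):
// narrow FP32 to BF16 with the round-to-nearest-even/NaN-quieting conversion
// above. Assumes BF16Vec16::save(void*) is defined earlier in this file.
inline void fp32_to_bf16_block_example(const float* in, void* out) {
  FP32Vec16 f(in);
  BF16Vec16 b(f);  // rounding and NaN handling as implemented above
  b.save(out);
}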
inline void prefetch(const void* addr) {
__asm__ __volatile__("dcbt 0, %0" : : "r"(addr) : "memory");
}
}; // namespace vec_op
#endif
#ifndef CPU_TYPES_VXE_HPP
#define CPU_TYPES_VXE_HPP
#include <vecintrin.h>
#include <cmath>
#include <cstring>  // for std::memcpy in the scalar FP16 store below
#include <limits>
#include <torch/all.h>
namespace vec_op {
#define vec_neg(a) (-(a))
#define vec_add(a, b) ((a) + (b))
#define vec_sub(a, b) ((a) - (b))
#define vec_mul(a, b) ((a) * (b))
#define vec_div(a, b) ((a) / (b))
#define vec_sr(a, b) ((a) >> (b))  // Vector Shift Right (logical for the unsigned element types used here)
#define vec_sl(a, b) ((a) << (b)) // Vector Shift Left
// NOTE: FP16 (Half) is supported on s390x via custom bit-manipulation
// conversion. PyTorch itself lacks native s390x FP16 support.
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#ifndef CPU_OP_GUARD
#define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME)
#else
#define CPU_KERNEL_GUARD_IN(NAME) \
std::cout << #NAME << " invoked." << std::endl;
#define CPU_KERNEL_GUARD_OUT(NAME) \
std::cout << #NAME << " exit." << std::endl;
#endif
#define FORCE_INLINE __attribute__((always_inline)) inline
namespace {
template <typename T, T... indexes, typename F>
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
(f(std::integral_constant<T, indexes>{}), ...);
}
}; // namespace
template <typename T, T count, typename F,
typename = std::enable_if_t<std::is_invocable_v<F, T>>>
constexpr void unroll_loop(F&& f) {
unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
}
template <typename T>
struct Vec {
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
};
typedef struct ss16x8x2_t {
__vector signed short val[2];
} ss16x8x2_t;
typedef struct ss16x8x4_t {
__vector signed short val[4];
} ss16x8x4_t;
typedef struct f32x4x2_t {
__vector float val[2];
} f32x4x2_t;
typedef struct f32x4x4_t {
__vector float val[4];
} f32x4x4_t;
struct FP32Vec8;
struct FP32Vec16;
struct BF16Vec8 : public Vec<BF16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
__vector signed short reg;
explicit BF16Vec8(const void* ptr) : reg(*(__vector signed short*)ptr) {}
explicit BF16Vec8(const FP32Vec8&);
void save(void* ptr) const {
*reinterpret_cast<__vector signed short*>(ptr) = reg;
}
};
struct FP16Vec8 : public Vec<FP16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
__vector signed short reg;
explicit FP16Vec8(const void* ptr) : reg(*(__vector signed short*)ptr) {}
explicit FP16Vec8(const FP32Vec8&);
void save(void* ptr) const {
*reinterpret_cast<__vector signed short*>(ptr) = reg;
}
};
struct FP16Vec16 : public Vec<FP16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
ss16x8x2_t reg;
explicit FP16Vec16(const void* ptr) {
// Load 256 bits (16 FP16 values) in two parts
reg.val[0] = (__vector signed short)vec_xl(0, (signed short*)ptr);
reg.val[1] = (__vector signed short)vec_xl(16, (signed short*)ptr);
}
explicit FP16Vec16(const FP32Vec16&);
void save(void* ptr) const {
// Save 256 bits in two parts
vec_xst(reg.val[0], 0, (signed short*)ptr);
vec_xst(reg.val[1], 16, (signed short*)ptr);
}
};
struct BF16Vec16 : public Vec<BF16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
ss16x8x2_t reg;
explicit BF16Vec16(const void* ptr) {
// Load 256 bits in two parts
reg.val[0] = (__vector signed short)vec_xl(0, (signed short*)ptr);
reg.val[1] = (__vector signed short)vec_xl(16, (signed short*)ptr);
}
explicit BF16Vec16(const FP32Vec16&);
void save(void* ptr) const {
// Save 256 bits in two parts
vec_xst(reg.val[0], 0, (signed short*)ptr);
vec_xst(reg.val[1], 16, (signed short*)ptr);
}
};
const static __vector signed short zero = vec_splats((signed short)0);
FORCE_INLINE __vector float fp16_to_fp32_bits(__vector unsigned int x) {
const __vector unsigned int mask_sign = {0x8000, 0x8000, 0x8000, 0x8000};
const __vector unsigned int mask_exp = {0x7C00, 0x7C00, 0x7C00, 0x7C00};
const __vector unsigned int mask_mant = {0x03FF, 0x03FF, 0x03FF, 0x03FF};
const __vector unsigned int bias_adj = {112, 112, 112, 112};
const __vector unsigned int exp_max_fp16 = {0x1F, 0x1F, 0x1F,
0x1F}; // FP16 NaN/Inf exponent
const __vector unsigned int exp_max_fp32 = {0xFF, 0xFF, 0xFF,
0xFF}; // FP32 NaN/Inf exponent
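  // Note (assumption): FP16 subnormal inputs (e == 0, m != 0) are not
  // special-cased here; they take the normal-bias path below and are not
  // decoded exactly.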
__vector unsigned int s = (x & mask_sign) << 16;
__vector unsigned int e = (x & mask_exp) >> 10;
__vector unsigned int m = (x & mask_mant) << 13;
// Check for NaN/Inf: exponent = 0x1F in FP16
__vector __bool int is_nan_inf = vec_cmpeq(e, exp_max_fp16);
// Normal: adjust bias; NaN/Inf: set to 0xFF
__vector unsigned int e_normal = e + bias_adj;
e = vec_sel(e_normal, exp_max_fp32, is_nan_inf);
return (__vector float)(s | (e << 23) | m);
}
FORCE_INLINE __vector unsigned int fp32_to_fp16_bits(__vector float f_in) {
__vector unsigned int in = (__vector unsigned int)f_in;
const __vector unsigned int mask_sign_32 = {0x80000000, 0x80000000,
0x80000000, 0x80000000};
const __vector unsigned int mask_exp_32 = {0x7F800000, 0x7F800000, 0x7F800000,
0x7F800000};
const __vector unsigned int mask_mant_32 = {0x007FFFFF, 0x007FFFFF,
0x007FFFFF, 0x007FFFFF};
// Use SIGNED integers for exponent math to handle underflow check
const __vector signed int bias_adj = {112, 112, 112, 112};
const __vector signed int zero = {0, 0, 0, 0};
const __vector signed int max_exp = {31, 31, 31, 31}; // Max FP16 exp
const __vector unsigned int exp_max_fp32 = {0xFF, 0xFF, 0xFF, 0xFF};
const __vector unsigned int exp_max_fp16 = {0x1F, 0x1F, 0x1F, 0x1F};
__vector unsigned int s = (in & mask_sign_32) >> 16;
__vector unsigned int e_u = (in & mask_exp_32) >> 23;
// Check for NaN/Inf: exponent = 0xFF in FP32
__vector __bool int is_nan_inf = vec_cmpeq(e_u, exp_max_fp32);
__vector signed int e_s = (__vector signed int)e_u;
e_s = vec_sub(e_s, bias_adj);
e_s = vec_max(e_s, zero);
e_s = vec_min(e_s, max_exp);
__vector unsigned int e_normal = (__vector unsigned int)e_s;
__vector unsigned int e_final = vec_sel(e_normal, exp_max_fp16, is_nan_inf);
const __vector unsigned int one_v = {1, 1, 1, 1};
const __vector unsigned int mask_sticky = {0xFFF, 0xFFF, 0xFFF, 0xFFF};
__vector unsigned int round_bit = (in >> 12) & one_v;
__vector unsigned int sticky = in & mask_sticky;
__vector unsigned int m = (in & mask_mant_32) >> 13;
__vector unsigned int lsb = m & one_v; // LSB of mantissa for tie-breaking
// Round up if: round_bit && (sticky || lsb)
__vector __bool int sticky_nonzero =
vec_cmpgt(sticky, (__vector unsigned int){0, 0, 0, 0});
__vector __bool int lsb_set = vec_cmpeq(lsb, one_v);
__vector __bool int round_up =
vec_and(vec_cmpeq(round_bit, one_v), vec_or(sticky_nonzero, lsb_set));
m = vec_sel(m, m + one_v, round_up);
const __vector unsigned int mant_mask = {0x3FF, 0x3FF, 0x3FF, 0x3FF};
const __vector unsigned int max_normal_exp = {0x1E, 0x1E, 0x1E, 0x1E};
__vector __bool int mant_overflows = vec_cmpgt(m, mant_mask);
__vector __bool int would_overflow_to_inf =
vec_and(mant_overflows, vec_cmpeq(e_final, max_normal_exp));
__vector unsigned int e_inc = vec_min(e_final + one_v, exp_max_fp16);
e_final = vec_sel(e_final, e_inc, mant_overflows);
m = vec_and(m, mant_mask);
e_final = vec_sel(e_final, max_normal_exp, would_overflow_to_inf);
m = vec_sel(m, mant_mask, would_overflow_to_inf);
return s | (e_final << 10) | m;
}
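// Illustrative self-check sketch (assumption, not part of the original file):
// round-trip a probe value through the two bit converters above. Exact
// equality holds whenever the probe is representable in FP16.
inline bool fp16_roundtrip_probe_example(float probe) {
  __vector float vf = vec_splats(probe);
  __vector unsigned int h = fp32_to_fp16_bits(vf);  // 16 payload bits per lane
  __vector float back = fp16_to_fp32_bits(h);
  return vec_extract(back, 0) == probe;
}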
struct BF16Vec32 : public Vec<BF16Vec32> {
constexpr static int VEC_ELEM_NUM = 32;
ss16x8x4_t reg;
explicit BF16Vec32(const void* ptr)
: reg(*reinterpret_cast<const ss16x8x4_t*>(ptr)) {}
explicit BF16Vec32(ss16x8x4_t data) : reg(data) {}
explicit BF16Vec32(const BF16Vec8& vec8_data)
: reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {}
void save(void* ptr) const { *reinterpret_cast<ss16x8x4_t*>(ptr) = reg; }
};
struct FP32Vec4 : public Vec<FP32Vec4> {
constexpr static int VEC_ELEM_NUM = 4;
union AliasReg {
__vector float reg;
float values[VEC_ELEM_NUM];
};
__vector float reg;
explicit FP32Vec4(float v) : reg(vec_splats(v)) {}
explicit FP32Vec4() : reg(vec_splats(0.0f)) {}
explicit FP32Vec4(const float* ptr) : reg(vec_xl(0, ptr)) {}
explicit FP32Vec4(__vector float data) : reg(data) {}
explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {}
};
struct FP32Vec8 : public Vec<FP32Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
union AliasReg {
f32x4x2_t reg;
float values[VEC_ELEM_NUM];
};
f32x4x2_t reg;
explicit FP32Vec8(float v) {
reg.val[0] = vec_splats(v);
reg.val[1] = vec_splats(v);
}
explicit FP32Vec8() {
reg.val[0] = vec_splats(0.0f);
reg.val[1] = vec_splats(0.0f);
}
explicit FP32Vec8(const float* ptr) {
reg.val[0] = vec_xl(0, ptr);
reg.val[1] = vec_xl(16, ptr);
}
explicit FP32Vec8(f32x4x2_t data) : reg(data) {}
explicit FP32Vec8(const FP32Vec8& data) {
reg.val[0] = data.reg.val[0];
reg.val[1] = data.reg.val[1];
}
explicit FP32Vec8(const BF16Vec8& v) {
// On big-endian s390x, place BF16 first to get correct byte order
reg.val[0] = (__vector float)vec_mergeh(v.reg, zero);
reg.val[1] = (__vector float)vec_mergel(v.reg, zero);
}
explicit FP32Vec8(const FP16Vec8& v) {
// Cast to UNSIGNED short vector to prevent sign-extension during unpack
__vector unsigned short raw_u = (__vector unsigned short)v.reg;
// Unpack 8x16-bit to two 4x32-bit vectors (Zero extended)
__vector unsigned int raw_hi = (__vector unsigned int)vec_unpackh(raw_u);
__vector unsigned int raw_lo = (__vector unsigned int)vec_unpackl(raw_u);
reg.val[0] = fp16_to_fp32_bits(raw_hi);
reg.val[1] = fp16_to_fp32_bits(raw_lo);
}
float reduce_sum() const {
AliasReg ar;
ar.reg = reg;
float result = 0;
unroll_loop<int, VEC_ELEM_NUM>(
[&result, &ar](int i) { result += ar.values[i]; });
return result;
}
FP32Vec8 exp() const {
f32x4x2_t out;
const __vector float log2e = vec_splats(1.44269504088896341f);
const __vector float one = vec_splats(1.0f);
const __vector float min_x = vec_splats(-87.3f);
const __vector float max_x = vec_splats(88.7f);
// 5th-degree minimax polynomial for 2^r (r in [0,1))
const __vector float c1 = vec_splats(0.6931471805599453f);
const __vector float c2 = vec_splats(0.240226506959101f);
const __vector float c3 = vec_splats(0.05550410866482158f);
const __vector float c4 = vec_splats(0.009618129107628477f);
const __vector float c5 = vec_splats(0.0013333558146428443f);
for (int i = 0; i < 2; i++) {
__vector float x = reg.val[i];
x = vec_max(x, min_x);
x = vec_min(x, max_x);
__vector float y = vec_mul(x, log2e);
__vector float kf = vec_floor(y);
__vector float r = vec_sub(y, kf);
__vector signed int k = vec_signed(kf);
const __vector signed int min_k = vec_splats((signed int)-126);
const __vector signed int max_k = vec_splats((signed int)127);
k = vec_min(vec_max(k, min_k), max_k);
// Build 2^k from exponent bits
__vector signed int exp_int = vec_add(k, vec_splats((signed int)127));
__vector unsigned int bits = (__vector unsigned int)exp_int;
bits = vec_sl(bits, vec_splats((unsigned int)23));
__vector float pow2k = (__vector float)bits;
// Improved minimax polynomial
__vector float poly = vec_madd(c5, r, c4);
poly = vec_madd(poly, r, c3);
poly = vec_madd(poly, r, c2);
poly = vec_madd(poly, r, c1);
poly = vec_madd(poly, r, one);
out.val[i] = vec_mul(pow2k, poly);
}
return FP32Vec8(out);
}
  FP32Vec8 tanh() const {
    // tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
    const __vector float one = vec_splats(1.0f);
    const __vector float two = vec_splats(2.0f);
    const __vector float zero = vec_splats(0.0f);
    const __vector float sat =
        vec_splats(9.0f);  // beyond this, tanh(x) ~ sign(x)
    // Compute exp(2x) once for both halves instead of once per half.
    f32x4x2_t tmp;
    tmp.val[0] = vec_mul(reg.val[0], two);
    tmp.val[1] = vec_mul(reg.val[1], two);
    FP32Vec8 e2x = FP32Vec8(tmp).exp();
    f32x4x2_t out;
    for (int i = 0; i < 2; i++) {
      __vector float x = reg.val[i];
      __vector float ax = vec_abs(x);
      // sign(x): +1 or -1
      __vector float sign = vec_sel(vec_splats(-1.0f), one, vec_cmpgt(x, zero));
      // saturation mask: |x| > sat
      __vector __bool int saturated = vec_cmpgt(ax, sat);
      __vector float e = e2x.reg.val[i];
      // tanh(x) = (e - 1) / (e + 1)
      __vector float t = vec_div(vec_sub(e, one), vec_add(e, one));
      // For large |x|, clamp to sign(x)
      out.val[i] = vec_sel(t, sign, saturated);
    }
    return FP32Vec8(out);
  }
  FP32Vec8 er() const {
    // A&S 7.1.26 approximation:
    //   erf(x) = sign(x) * (1 - ((((a5*t + a4)*t + a3)*t + a2)*t + a1) * t
    //                          * exp(-x^2))
    //   with t = 1 / (1 + p*|x|), p = 0.3275911
    const __vector float one = vec_splats(1.0f);
    const __vector float zero = vec_splats(0.0f);
    const __vector float p = vec_splats(0.3275911f);
    // Polynomial coeffs
    const __vector float a1 = vec_splats(0.254829592f);
    const __vector float a2 = vec_splats(-0.284496736f);
    const __vector float a3 = vec_splats(1.421413741f);
    const __vector float a4 = vec_splats(-1.453152027f);
    const __vector float a5 = vec_splats(1.061405429f);
    // Threshold where erf(x) ~ sign(x)
    const __vector float sat = vec_splats(6.0f);
    // Compute exp(-x^2) once for both halves instead of once per half.
    f32x4x2_t nx2;
    nx2.val[0] = vec_neg(vec_mul(reg.val[0], reg.val[0]));
    nx2.val[1] = vec_neg(vec_mul(reg.val[1], reg.val[1]));
    FP32Vec8 e = FP32Vec8(nx2).exp();
    f32x4x2_t out;
    for (int lane = 0; lane < 2; lane++) {
      __vector float x = reg.val[lane];
      __vector float ax = vec_abs(x);
      // sign(x)
      __vector float sign = vec_sel(vec_splats(-1.0f), one, vec_cmpgt(x, zero));
      // |x| > 6 -> erf(x) = +/-1
      __vector __bool int saturated = vec_cmpgt(ax, sat);
      // t = 1 / (1 + p * |x|)
      __vector float t = vec_div(one, vec_madd(p, ax, one));
      // Horner evaluation of the polynomial in t
      __vector float poly = a5;
      poly = vec_madd(poly, t, a4);
      poly = vec_madd(poly, t, a3);
      poly = vec_madd(poly, t, a2);
      poly = vec_madd(poly, t, a1);
      poly = vec_mul(poly, t);
      // erf(x) = sign * (1 - poly * exp(-x^2))
      __vector float y = vec_sub(one, vec_mul(poly, e.reg.val[lane]));
      y = vec_mul(y, sign);
      // saturated -> +/-1
      out.val[lane] = vec_sel(y, sign, saturated);
    }
    return FP32Vec8(out);
  }
// Elementwise sigmoid(x) = 1 / (1 + exp(-x))
FP32Vec8 sigmoid() const {
const __vector float one = vec_splats(1.0f);
f32x4x2_t neg;
for (int i = 0; i < 2; ++i) {
neg.val[i] = vec_neg(reg.val[i]);
}
FP32Vec8 neg_x(neg);
FP32Vec8 e = neg_x.exp(); // exp(-x)
f32x4x2_t denom;
for (int i = 0; i < 2; ++i) {
denom.val[i] = vec_add(one, e.reg.val[i]);
}
FP32Vec8 denom_vec(denom);
FP32Vec8 one_vec(1.0f);
return one_vec / denom_vec;
}
// Tanh-based GELU:
// gelu(x) = 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x^3)))
FP32Vec8 gelu_tanh() const {
const __vector float k_s2pi = vec_splats(0.7978845608028654f); // √(2/π)
const __vector float k_0_0447 = vec_splats(0.044715f);
f32x4x2_t x2, x3, inner;
for (int i = 0; i < 2; ++i) {
__vector float x = reg.val[i];
x2.val[i] = vec_mul(x, x); // x^2
x3.val[i] = vec_mul(x2.val[i], x); // x^3
__vector float t = vec_madd(k_0_0447, x3.val[i], x); // x + 0.044715*x^3
inner.val[i] = vec_mul(k_s2pi, t); // √(2/π)*(...)
}
FP32Vec8 inner_vec(inner);
FP32Vec8 t = inner_vec.tanh(); // tanh part
FP32Vec8 one_vec(1.0f);
FP32Vec8 half_vec(0.5f);
FP32Vec8 x_vec(*this);
return x_vec * half_vec * (one_vec + t);
}
// Erf-based GELU:
// gelu(x) = 0.5 * x * (1 + erf(x / √2))
FP32Vec8 gelu_erf() const {
const __vector float inv_sqrt2 = vec_splats(0.7071067811865476f); // 1/√2
FP32Vec8 x_vec(*this);
f32x4x2_t scaled;
for (int i = 0; i < 2; ++i) {
scaled.val[i] = vec_mul(reg.val[i], inv_sqrt2);
}
FP32Vec8 x_scaled(scaled);
FP32Vec8 erf_x = x_scaled.er();
FP32Vec8 one_vec(1.0f);
FP32Vec8 half_vec(0.5f);
return x_vec * half_vec * (one_vec + erf_x);
}
// Elementwise reciprocal: 1/x (scalar per lane, for correctness)
FP32Vec8 rcp() const {
AliasReg in, out;
in.reg = reg;
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
out.values[i] = 1.0f / in.values[i];
}
return FP32Vec8(out.reg);
}
// Elementwise rsqrt(x) = 1 / sqrt(x) (scalar per lane, for correctness)
FP32Vec8 rsqrt() const {
AliasReg in, out;
in.reg = reg;
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
out.values[i] = 1.0f / std::sqrt(in.values[i]);
}
return FP32Vec8(out.reg);
}
FP32Vec8 operator*(const FP32Vec8& b) const {
return FP32Vec8(
{vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])});
}
FP32Vec8 operator+(const FP32Vec8& b) const {
return FP32Vec8(
{vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])});
}
FP32Vec8 operator-(const FP32Vec8& b) const {
return FP32Vec8(
{vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])});
}
FP32Vec8 operator/(const FP32Vec8& b) const {
return FP32Vec8(
{vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])});
}
void save(float* ptr) const {
vec_xst(reg.val[0], 0, ptr);
vec_xst(reg.val[1], 16, ptr);
}
};
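// Illustrative helper (assumption, not part of the original file): apply the
// tanh-based GELU above over a contiguous float buffer, with a scalar tail
// using the same formula.
inline void gelu_tanh_inplace_example(float* data, int n) {
  int i = 0;
  for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
    FP32Vec8 v(data + i);
    v.gelu_tanh().save(data + i);
  }
  for (; i < n; ++i) {
    float x = data[i];
    float inner = 0.7978845608028654f * (x + 0.044715f * x * x * x);
    data[i] = 0.5f * x * (1.0f + std::tanh(inner));
  }
}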
struct FP32Vec16 : public Vec<FP32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
f32x4x4_t reg;
float values[VEC_ELEM_NUM];
};
f32x4x4_t reg;
explicit FP32Vec16(float v) {
reg.val[0] = vec_splats(v);
reg.val[1] = vec_splats(v);
reg.val[2] = vec_splats(v);
reg.val[3] = vec_splats(v);
}
explicit FP32Vec16() {
reg.val[0] = vec_splats(0.0f);
reg.val[1] = vec_splats(0.0f);
reg.val[2] = vec_splats(0.0f);
reg.val[3] = vec_splats(0.0f);
}
explicit FP32Vec16(const float* ptr) {
reg.val[0] = vec_xl(0, ptr);
reg.val[1] = vec_xl(16, ptr);
reg.val[2] = vec_xl(32, ptr);
reg.val[3] = vec_xl(48, ptr);
}
explicit FP32Vec16(f32x4x4_t data) : reg(data) {}
explicit FP32Vec16(const FP32Vec16& data) {
reg.val[0] = data.reg.val[0];
reg.val[1] = data.reg.val[1];
reg.val[2] = data.reg.val[2];
reg.val[3] = data.reg.val[3];
}
explicit FP32Vec16(const FP32Vec4& data) {
reg.val[0] = data.reg;
reg.val[1] = data.reg;
reg.val[2] = data.reg;
reg.val[3] = data.reg;
}
explicit FP32Vec16(const FP32Vec8& data) {
reg.val[0] = data.reg.val[0];
reg.val[1] = data.reg.val[1];
reg.val[2] = data.reg.val[0];
reg.val[3] = data.reg.val[1];
}
explicit FP32Vec16(const BF16Vec16& v) {
// On big-endian s390x, place BF16 first to get correct byte order
reg.val[0] = (__vector float)vec_mergeh(v.reg.val[0], zero);
reg.val[1] = (__vector float)vec_mergel(v.reg.val[0], zero);
reg.val[2] = (__vector float)vec_mergeh(v.reg.val[1], zero);
reg.val[3] = (__vector float)vec_mergel(v.reg.val[1], zero);
}
explicit FP32Vec16(const FP16Vec16& v) {
__vector unsigned int raw_hi_0 =
(__vector unsigned int)vec_unpackh(v.reg.val[0]);
__vector unsigned int raw_lo_0 =
(__vector unsigned int)vec_unpackl(v.reg.val[0]);
reg.val[0] = fp16_to_fp32_bits(raw_hi_0);
reg.val[1] = fp16_to_fp32_bits(raw_lo_0);
__vector unsigned int raw_hi_1 =
(__vector unsigned int)vec_unpackh(v.reg.val[1]);
__vector unsigned int raw_lo_1 =
(__vector unsigned int)vec_unpackl(v.reg.val[1]);
reg.val[2] = fp16_to_fp32_bits(raw_hi_1);
reg.val[3] = fp16_to_fp32_bits(raw_lo_1);
}
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
FP32Vec16 operator*(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]),
vec_mul(reg.val[1], b.reg.val[1]),
vec_mul(reg.val[2], b.reg.val[2]),
vec_mul(reg.val[3], b.reg.val[3])}));
}
FP32Vec16 operator+(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_add(reg.val[0], b.reg.val[0]),
vec_add(reg.val[1], b.reg.val[1]),
vec_add(reg.val[2], b.reg.val[2]),
vec_add(reg.val[3], b.reg.val[3])}));
}
FP32Vec16 operator-(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_sub(reg.val[0], b.reg.val[0]),
vec_sub(reg.val[1], b.reg.val[1]),
vec_sub(reg.val[2], b.reg.val[2]),
vec_sub(reg.val[3], b.reg.val[3])}));
}
FP32Vec16 operator/(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_div(reg.val[0], b.reg.val[0]),
vec_div(reg.val[1], b.reg.val[1]),
vec_div(reg.val[2], b.reg.val[2]),
vec_div(reg.val[3], b.reg.val[3])}));
}
float reduce_sum() const {
AliasReg ar;
ar.reg = reg;
float result = 0;
unroll_loop<int, VEC_ELEM_NUM>(
[&result, &ar](int i) { result += ar.values[i]; });
return result;
}
template <int group_size>
float reduce_sub_sum(int idx) {
static_assert(VEC_ELEM_NUM % group_size == 0);
AliasReg ar;
ar.reg = reg;
float result = 0;
const int start = idx * group_size;
unroll_loop<int, group_size>(
[&result, &start, ar](int i) { result += ar.values[start + i]; });
return result;
}
FP32Vec16 max(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_max(reg.val[0], b.reg.val[0]),
vec_max(reg.val[1], b.reg.val[1]),
vec_max(reg.val[2], b.reg.val[2]),
vec_max(reg.val[3], b.reg.val[3])}));
}
float reduce_max() const {
AliasReg ar;
ar.reg = reg;
float result = ar.values[0];
unroll_loop<int, VEC_ELEM_NUM>([&result, &ar](int i) {
if (ar.values[i] > result) result = ar.values[i];
});
return result;
}
void save(float* ptr) const {
vec_xst(reg.val[0], 0, ptr);
vec_xst(reg.val[1], 16, ptr);
vec_xst(reg.val[2], 32, ptr);
vec_xst(reg.val[3], 48, ptr);
}
};
template <typename T>
struct VecType {
using vec_type = void;
};
template <typename T>
using vec_t = typename VecType<T>::vec_type;
template <>
struct VecType<float> {
using vec_type = FP32Vec8;
};
template <>
struct VecType<c10::BFloat16> {
using vec_type = BF16Vec8;
};
template <>
struct VecType<c10::Half> {
using vec_type = FP16Vec8;
};
template <typename T>
void storeFP32(float v, T* ptr) {
*ptr = v;
}
// Note: c10::BFloat16 is provided by the PyTorch headers included above;
// redefining it locally inside vec_op would shadow the real type, so it is
// not declared here.
template <>
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
  c10::BFloat16 __attribute__((__may_alias__))* v_ptr =
      reinterpret_cast<c10::BFloat16*>(&v);
  *ptr = *(v_ptr + 1);
}
template <>
inline void storeFP32<::c10::Half>(float v, ::c10::Half* ptr) {
  // Use bit manipulation for the IEEE FP32 -> FP16 conversion, since the
  // vector FP32->FP16 intrinsics do not use IEEE rounding and can produce
  // incorrect results for some inputs. This scalar path mirrors the vector
  // fp32_to_fp16_bits() above.
uint32_t in;
std::memcpy(&in, &v, sizeof(in));
uint32_t s = (in & 0x80000000) >> 16; // Sign
uint32_t e = (in & 0x7F800000) >> 23; // Exponent
uint32_t round_bit = (in >> 12) & 1;
uint32_t sticky = (in & 0xFFF) != 0; // Any bits in [11..0]
uint32_t m = (in & 0x007FFFFF) >> 13;
uint32_t lsb = m & 1; // LSB of mantissa for tie-breaking
// Check for NaN/Inf before rounding
bool is_nan_inf = (e == 0xFF);
if (round_bit && (sticky || lsb)) {
m++;
// Handle mantissa overflow: if m overflows 10 bits, increment exponent
if (m > 0x3FF) {
m = 0;
e++;
}
}
if (is_nan_inf) {
// NaN/Inf: preserve it
e = 0x1F;
} else {
// Normal: adjust bias (127 - 15), flush subnormals to zero
e = (e >= 112) ? (e - 112) : 0;
// If exponent overflows to Inf range, saturate to max normal FP16 value
if (e > 0x1E) {
e = 0x1E; // Max normal exponent
m = 0x3FF; // Max mantissa
}
}
uint16_t fp16 = (uint16_t)(s | (e << 10) | m);
*reinterpret_cast<uint16_t*>(ptr) = fp16;
}
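// Example usage (illustrative): storeFP32 is the scalar counterpart used for
// tail elements, e.g.:
//   c10::Half h;
//   storeFP32<c10::Half>(1.5f, &h);  // h now holds 0x3E00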
#ifndef __VEC_CLASS_FP_NAN
#define __VEC_CLASS_FP_NAN (1 << 6)
#endif
// Optimized FMA (Fused Multiply-Add) implementations using IBM Z vector
// intrinsics
// FP32Vec4 FMA: acc = acc + (a * b) or equivalently acc = fma(a, b, acc)
FORCE_INLINE void fma(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) {
acc.reg = vec_madd(a.reg, b.reg, acc.reg);
}
// FP32Vec8 FMA: acc = acc + (a * b)
FORCE_INLINE void fma(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) {
acc.reg.val[0] = vec_madd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
acc.reg.val[1] = vec_madd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
}
// FP32Vec16 FMA: acc = acc + (a * b)
FORCE_INLINE void fma(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) {
acc.reg.val[0] = vec_madd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
acc.reg.val[1] = vec_madd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
acc.reg.val[2] = vec_madd(a.reg.val[2], b.reg.val[2], acc.reg.val[2]);
acc.reg.val[3] = vec_madd(a.reg.val[3], b.reg.val[3], acc.reg.val[3]);
}
// Multiply-Subtract: acc = acc - (a * b)
FORCE_INLINE void fms(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) {
acc.reg = vec_msub(a.reg, b.reg, acc.reg);
}
FORCE_INLINE void fms(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) {
acc.reg.val[0] = vec_msub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
acc.reg.val[1] = vec_msub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
}
FORCE_INLINE void fms(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) {
acc.reg.val[0] = vec_msub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
acc.reg.val[1] = vec_msub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
acc.reg.val[2] = vec_msub(a.reg.val[2], b.reg.val[2], acc.reg.val[2]);
acc.reg.val[3] = vec_msub(a.reg.val[3], b.reg.val[3], acc.reg.val[3]);
}
// Negative Multiply-Add: acc = -(a * b) + acc
FORCE_INLINE void nfma(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) {
acc.reg = vec_nmadd(a.reg, b.reg, acc.reg);
}
FORCE_INLINE void nfma(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) {
acc.reg.val[0] = vec_nmadd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
acc.reg.val[1] = vec_nmadd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
}
FORCE_INLINE void nfma(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) {
acc.reg.val[0] = vec_nmadd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
acc.reg.val[1] = vec_nmadd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
acc.reg.val[2] = vec_nmadd(a.reg.val[2], b.reg.val[2], acc.reg.val[2]);
acc.reg.val[3] = vec_nmadd(a.reg.val[3], b.reg.val[3], acc.reg.val[3]);
}
// Negative Multiply-Subtract: acc = -(a * b) - acc
FORCE_INLINE void nfms(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) {
acc.reg = vec_nmsub(a.reg, b.reg, acc.reg);
}
FORCE_INLINE void nfms(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) {
acc.reg.val[0] = vec_nmsub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
acc.reg.val[1] = vec_nmsub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
}
FORCE_INLINE void nfms(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) {
acc.reg.val[0] = vec_nmsub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
acc.reg.val[1] = vec_nmsub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
acc.reg.val[2] = vec_nmsub(a.reg.val[2], b.reg.val[2], acc.reg.val[2]);
acc.reg.val[3] = vec_nmsub(a.reg.val[3], b.reg.val[3], acc.reg.val[3]);
}
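// Illustrative sketch (assumption, not part of the original file): a dot
// product accumulated with the fused helpers above, plus a scalar tail.
inline float dot_fp32vec8_example(const float* a, const float* b, int n) {
  FP32Vec8 acc(0.0f);
  int i = 0;
  for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
    FP32Vec8 va(a + i);
    FP32Vec8 vb(b + i);
    fma(acc, va, vb);  // acc += va * vb, fused per lane via vec_madd
  }
  float result = acc.reduce_sum();
  for (; i < n; ++i) result += a[i] * b[i];
  return result;
}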
const static __vector unsigned char omask = {2, 3, 6, 7, 10, 11, 14, 15,
18, 19, 22, 23, 26, 27, 30, 31};
const static __vector unsigned int bias = {0x00007fff, 0x00007fff, 0x00007fff,
0x00007fff};
const static __vector unsigned int nan = {0x7fc00000, 0x7fc00000, 0x7fc00000,
0x7fc00000};
const static __vector unsigned int sh16 = {16, 16, 16, 16};
const static __vector unsigned int one = {1, 1, 1, 1};
inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
__vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
__vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]);
__vector unsigned int lsb0 = inp0 >> sh16;
__vector unsigned int lsb1 = inp1 >> sh16;
lsb0 = lsb0 & one;
lsb1 = lsb1 & one;
__vector unsigned int rnd0 = lsb0 + bias;
__vector unsigned int rnd1 = lsb1 + bias;
inp0 = inp0 + rnd0;
inp1 = inp1 + rnd1;
int cc;
__vector __bool int sel0 =
vec_fp_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN, &cc);
__vector __bool int sel1 =
vec_fp_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN, &cc);
inp0 = vec_sel(inp0, nan, sel0);
inp1 = vec_sel(inp1, nan, sel1);
inp0 = inp0 >> sh16;
inp1 = inp1 >> sh16;
reg = (__vector signed short)vec_perm(inp0, inp1, omask);
}
inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
__vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
__vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]);
__vector unsigned int inp2 = (__vector unsigned int)(v.reg.val[2]);
__vector unsigned int inp3 = (__vector unsigned int)(v.reg.val[3]);
__vector unsigned int lsb0 = inp0 >> sh16;
__vector unsigned int lsb1 = inp1 >> sh16;
__vector unsigned int lsb2 = inp2 >> sh16;
__vector unsigned int lsb3 = inp3 >> sh16;
lsb0 = lsb0 & one;
lsb1 = lsb1 & one;
lsb2 = lsb2 & one;
lsb3 = lsb3 & one;
__vector unsigned int rnd0 = lsb0 + bias;
__vector unsigned int rnd1 = lsb1 + bias;
__vector unsigned int rnd2 = lsb2 + bias;
__vector unsigned int rnd3 = lsb3 + bias;
inp0 = inp0 + rnd0;
inp1 = inp1 + rnd1;
inp2 = inp2 + rnd2;
inp3 = inp3 + rnd3;
int cc;
__vector __bool int sel0 =
vec_fp_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN, &cc);
__vector __bool int sel1 =
vec_fp_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN, &cc);
__vector __bool int sel2 =
vec_fp_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN, &cc);
__vector __bool int sel3 =
vec_fp_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN, &cc);
inp0 = vec_sel(inp0, nan, sel0);
inp1 = vec_sel(inp1, nan, sel1);
inp2 = vec_sel(inp2, nan, sel2);
inp3 = vec_sel(inp3, nan, sel3);
inp0 = inp0 >> sh16;
inp1 = inp1 >> sh16;
inp2 = inp2 >> sh16;
inp3 = inp3 >> sh16;
reg.val[0] = (__vector signed short)vec_perm(inp0, inp1, omask);
reg.val[1] = (__vector signed short)vec_perm(inp2, inp3, omask);
}
inline FP16Vec8::FP16Vec8(const FP32Vec8& v) {
  // Use bit manipulation for the IEEE FP32 -> FP16 conversion, since the
  // vector FP32->FP16 intrinsics do not use IEEE rounding and can produce
  // incorrect results for some inputs. Process each of the 2 vectors
  // separately.
__vector unsigned int res_hi = fp32_to_fp16_bits(v.reg.val[0]);
__vector unsigned int res_lo = fp32_to_fp16_bits(v.reg.val[1]);
const __vector unsigned char perm_pack = {
2, 3, 6, 7, 10, 11, 14, 15, // Select lower 2 bytes from res_hi
18, 19, 22, 23, 26, 27, 30, 31 // Select lower 2 bytes from res_lo
};
reg = vec_perm((__vector signed short)res_hi, (__vector signed short)res_lo,
perm_pack);
}
inline FP16Vec16::FP16Vec16(const FP32Vec16& v) {
  // Use bit manipulation for the IEEE FP32 -> FP16 conversion, since the
  // vector FP32->FP16 intrinsics do not use IEEE rounding and can produce
  // incorrect results for some inputs. Process each of the 4 vectors
  // separately.
__vector unsigned int res_0 = fp32_to_fp16_bits(v.reg.val[0]);
__vector unsigned int res_1 = fp32_to_fp16_bits(v.reg.val[1]);
__vector unsigned int res_2 = fp32_to_fp16_bits(v.reg.val[2]);
__vector unsigned int res_3 = fp32_to_fp16_bits(v.reg.val[3]);
const __vector unsigned char perm_pack = {
2, 3, 6, 7, 10, 11, 14, 15, // Lower 2 bytes from first vector
18, 19, 22, 23, 26, 27, 30, 31 // Lower 2 bytes from second vector
};
reg.val[0] = vec_perm((__vector signed short)res_0,
(__vector signed short)res_1, perm_pack);
reg.val[1] = vec_perm((__vector signed short)res_2,
(__vector signed short)res_3, perm_pack);
}
// 1D softmax over `n` elements in `input`, writes result to `output`.
// Uses FP32Vec8 for main body, scalar tail handling.
// Requirement: n > 0
FORCE_INLINE void softmax_fp32vec8(float* output, const float* input, int n) {
if (n <= 0) return;
// ---------- Pass 1: find max ----------
float max_val = -std::numeric_limits<float>::infinity();
int i = 0;
for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
FP32Vec8 v(input + i);
FP32Vec8::AliasReg ar;
ar.reg = v.reg;
for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
if (ar.values[j] > max_val) max_val = ar.values[j];
}
}
for (; i < n; ++i) {
if (input[i] > max_val) max_val = input[i];
}
// ---------- Pass 2: compute exp(x - max) and sum ----------
float sum = 0.0f;
i = 0;
for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
float tmp[FP32Vec8::VEC_ELEM_NUM];
for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
tmp[j] = input[i + j] - max_val;
}
FP32Vec8 v(tmp);
FP32Vec8 e = v.exp();
FP32Vec8::AliasReg ar;
ar.reg = e.reg;
for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
output[i + j] = ar.values[j];
sum += ar.values[j];
}
}
// Tail
for (; i < n; ++i) {
float x = input[i] - max_val;
float ex = std::exp(x); // scalar tail
output[i] = ex;
sum += ex;
}
// ---------- Pass 3: normalize ----------
float inv_sum = 1.0f / sum;
i = 0;
for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
float tmp[FP32Vec8::VEC_ELEM_NUM];
for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
tmp[j] = output[i + j] * inv_sum;
}
FP32Vec8 v(tmp);
v.save(output + i);
}
for (; i < n; ++i) {
output[i] *= inv_sum;
}
}
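// Example call (illustrative): given n logits, write probabilities:
//   softmax_fp32vec8(probs, logits, n);
// In-place use (output == input) is also safe: each pass reads a chunk fully
// before it overwrites it.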
// 1D RMSNorm kernel:
// input: x[0..n-1]
// weight: w[0..n-1] (gamma), may be nullptr
// output: y[i] = x[i] * inv_rms * (weight[i] if weight != nullptr else 1)
// eps: small epsilon for numerical stability
FORCE_INLINE void rmsnorm_fp32vec8(float* output, const float* input,
const float* weight, int n, float eps) {
if (n <= 0) return;
// ---------- Pass 1: compute sum of squares ----------
float sum_sq = 0.0f;
int i = 0;
for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
FP32Vec8 x_vec(input + i);
FP32Vec8 sq = x_vec * x_vec;
FP32Vec8::AliasReg ar;
ar.reg = sq.reg;
for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
sum_sq += ar.values[j];
}
}
// Tail
for (; i < n; ++i) {
float v = input[i];
sum_sq += v * v;
}
float mean_sq = sum_sq / static_cast<float>(n);
float inv_rms = 1.0f / std::sqrt(mean_sq + eps);
// ---------- Pass 2: scale (and apply weight if given) ----------
const float inv_rms_f = inv_rms;
i = 0;
if (weight) {
// with gamma
for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
FP32Vec8 x_vec(input + i);
float wtmp[FP32Vec8::VEC_ELEM_NUM];
for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
wtmp[j] = weight[i + j];
}
FP32Vec8 w_vec(wtmp);
FP32Vec8 scale_vec(inv_rms_f);
FP32Vec8 y = x_vec * scale_vec * w_vec;
y.save(output + i);
}
for (; i < n; ++i) {
output[i] = input[i] * inv_rms_f * weight[i];
}
} else {
// without gamma
for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
FP32Vec8 x_vec(input + i);
FP32Vec8 scale_vec(inv_rms_f);
FP32Vec8 y = x_vec * scale_vec;
y.save(output + i);
}
for (; i < n; ++i) {
output[i] = input[i] * inv_rms_f;
}
}
}
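// Example call (illustrative): normalize one hidden-state row of width n with
// a learned gain (pass nullptr for weight to skip the gain). The epsilon here
// is just a typical choice:
//   rmsnorm_fp32vec8(out_row, in_row, gamma, n, 1e-6f);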
// Prefetch data to cache for better memory access performance
FORCE_INLINE void prefetch(const void* addr) {
__builtin_prefetch(addr, 0, 3); // 0=read, 3=high temporal locality
}
}; // namespace vec_op
#endif