"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "339e84ce86ee0cb50d1709dae1232ce4cda56398"
Unverified commit 81516095, authored by ihb2032 and committed by GitHub
Browse files

refactor(cpu_types_scalar.hpp): Unify scalar loop implementations using unroll_loop (#28847)


Signed-off-by: ihb2032 <1355790728@qq.com>
Co-authored-by: lyd1992 <liuyudong@iscas.ac.cn>
parent fdf93486
...@@ -26,10 +26,6 @@ namespace vec_op { ...@@ -26,10 +26,6 @@ namespace vec_op {
#define FORCE_INLINE __attribute__((always_inline)) inline #define FORCE_INLINE __attribute__((always_inline)) inline
#define __max(a, b) ((a) > (b) ? (a) : (b))
#define __min(a, b) ((a) < (b) ? (a) : (b))
#define __abs(a) ((a) < (0) ? (0 - a) : (a))
typedef struct f16x8_t { typedef struct f16x8_t {
uint16_t val[8]; uint16_t val[8];
} f16x8_t; } f16x8_t;
...@@ -99,7 +95,7 @@ struct FP16Vec16 : public Vec<FP16Vec16> { ...@@ -99,7 +95,7 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
void save(void* ptr) const { *reinterpret_cast<f16x16_t*>(ptr) = reg; } void save(void* ptr) const { *reinterpret_cast<f16x16_t*>(ptr) = reg; }
void save(void* ptr, const int elem_num) const { void save(void* ptr, const int elem_num) const {
int num = __min(elem_num, VEC_ELEM_NUM); int num = std::min(elem_num, VEC_ELEM_NUM);
std::memcpy(ptr, &(reg.val[0]), num * sizeof(uint16_t)); std::memcpy(ptr, &(reg.val[0]), num * sizeof(uint16_t));
} }
}; };
...@@ -128,7 +124,7 @@ struct BF16Vec16 : public Vec<BF16Vec16> { ...@@ -128,7 +124,7 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
void save(void* ptr) const { *reinterpret_cast<f16x16_t*>(ptr) = reg; } void save(void* ptr) const { *reinterpret_cast<f16x16_t*>(ptr) = reg; }
void save(void* ptr, const int elem_num) const { void save(void* ptr, const int elem_num) const {
int num = __min(elem_num, VEC_ELEM_NUM); int num = std::min(elem_num, VEC_ELEM_NUM);
std::memcpy(ptr, &(reg.val[0]), num * sizeof(uint16_t)); std::memcpy(ptr, &(reg.val[0]), num * sizeof(uint16_t));
} }
}; };
...@@ -143,9 +139,9 @@ struct BF16Vec32 : public Vec<BF16Vec32> { ...@@ -143,9 +139,9 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
explicit BF16Vec32(f16x32_t data) : reg(data) {}; explicit BF16Vec32(f16x32_t data) : reg(data) {};
explicit BF16Vec32(BF16Vec8& vec8_data) { explicit BF16Vec32(BF16Vec8& vec8_data) {
for (int i = 0; i < VEC_ELEM_NUM; ++i) { unroll_loop<int, VEC_ELEM_NUM>([&vec8_data, this](int i) {
reg.val[i] = vec8_data.reg.val[i % BF16Vec8::VEC_ELEM_NUM]; reg.val[i] = vec8_data.reg.val[i % BF16Vec8::VEC_ELEM_NUM];
} });
} }
void save(void* ptr) const { *reinterpret_cast<f16x32_t*>(ptr) = reg; } void save(void* ptr) const { *reinterpret_cast<f16x32_t*>(ptr) = reg; }
...@@ -157,15 +153,11 @@ struct FP32Vec4 : public Vec<FP32Vec4> { ...@@ -157,15 +153,11 @@ struct FP32Vec4 : public Vec<FP32Vec4> {
f32x4_t reg; f32x4_t reg;
explicit FP32Vec4(float v) { explicit FP32Vec4(float v) {
for (int i = 0; i < VEC_ELEM_NUM; ++i) { unroll_loop<int, VEC_ELEM_NUM>([&v, this](int i) { reg.val[i] = v; });
reg.val[i] = v;
}
} }
explicit FP32Vec4() { explicit FP32Vec4() {
for (int i = 0; i < VEC_ELEM_NUM; ++i) { unroll_loop<int, VEC_ELEM_NUM>([this](int i) { reg.val[i] = 0.0f; });
reg.val[i] = 0.0f;
}
} }
explicit FP32Vec4(const float* ptr) explicit FP32Vec4(const float* ptr)
...@@ -182,15 +174,11 @@ struct FP32Vec8 : public Vec<FP32Vec8> { ...@@ -182,15 +174,11 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
f32x8_t reg; f32x8_t reg;
explicit FP32Vec8(float v) { explicit FP32Vec8(float v) {
for (int i = 0; i < VEC_ELEM_NUM; ++i) { unroll_loop<int, VEC_ELEM_NUM>([&v, this](int i) { reg.val[i] = v; });
reg.val[i] = v;
}
} }
explicit FP32Vec8() { explicit FP32Vec8() {
for (int i = 0; i < VEC_ELEM_NUM; ++i) { unroll_loop<int, VEC_ELEM_NUM>([this](int i) { reg.val[i] = 0.0f; });
reg.val[i] = 0.0f;
}
} }
explicit FP32Vec8(const float* ptr) explicit FP32Vec8(const float* ptr)
...@@ -201,78 +189,68 @@ struct FP32Vec8 : public Vec<FP32Vec8> { ...@@ -201,78 +189,68 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
explicit FP32Vec8(const FP32Vec8& data) : reg(data.reg) {}; explicit FP32Vec8(const FP32Vec8& data) : reg(data.reg) {};
explicit FP32Vec8(const FP16Vec8& v) { explicit FP32Vec8(const FP16Vec8& v) {
for (int i = 0; i < VEC_ELEM_NUM; ++i) { unroll_loop<int, VEC_ELEM_NUM>(
reg.val[i] = fp16_to_float(v.reg.val[i]); [&v, this](int i) { reg.val[i] = fp16_to_float(v.reg.val[i]); });
}
} }
FP32Vec8(const BF16Vec8& v) { FP32Vec8(const BF16Vec8& v) {
for (int i = 0; i < VEC_ELEM_NUM; ++i) { unroll_loop<int, VEC_ELEM_NUM>(
reg.val[i] = bf16_to_float(v.reg.val[i]); [&v, this](int i) { reg.val[i] = bf16_to_float(v.reg.val[i]); });
}
} }
float reduce_sum() const { float reduce_sum() const {
float result = 0; float result = 0;
for (int i = 0; i < VEC_ELEM_NUM; ++i) { unroll_loop<int, VEC_ELEM_NUM>(
result += reg.val[i]; [&result, this](int i) { result += reg.val[i]; });
}
return result; return result;
} }
FP32Vec8 exp() const { FP32Vec8 exp() const {
f32x8_t ret; f32x8_t ret;
for (int i = 0; i < VEC_ELEM_NUM; ++i) { unroll_loop<int, VEC_ELEM_NUM>(
ret.val[i] = expf(reg.val[i]); [&ret, this](int i) { ret.val[i] = expf(reg.val[i]); });
}
return FP32Vec8(ret); return FP32Vec8(ret);
} }
FP32Vec8 tanh() const { FP32Vec8 tanh() const {
f32x8_t ret; f32x8_t ret;
for (int i = 0; i < VEC_ELEM_NUM; ++i) { unroll_loop<int, VEC_ELEM_NUM>(
ret.val[i] = tanhf(reg.val[i]); [&ret, this](int i) { ret.val[i] = tanhf(reg.val[i]); });
}
return FP32Vec8(ret); return FP32Vec8(ret);
} }
FP32Vec8 er() const { FP32Vec8 er() const {
f32x8_t ret; f32x8_t ret;
for (int i = 0; i < VEC_ELEM_NUM; ++i) { unroll_loop<int, VEC_ELEM_NUM>(
ret.val[i] = erf(reg.val[i]); [&ret, this](int i) { ret.val[i] = erf(reg.val[i]); });
}
return FP32Vec8(ret); return FP32Vec8(ret);
} }
FP32Vec8 operator*(const FP32Vec8& b) const { FP32Vec8 operator*(const FP32Vec8& b) const {
f32x8_t ret; f32x8_t ret;
for (int i = 0; i < VEC_ELEM_NUM; ++i) { unroll_loop<int, VEC_ELEM_NUM>(
ret.val[i] = reg.val[i] * b.reg.val[i]; [&ret, &b, this](int i) { ret.val[i] = reg.val[i] * b.reg.val[i]; });
}
return FP32Vec8(ret); return FP32Vec8(ret);
} }
FP32Vec8 operator+(const FP32Vec8& b) const { FP32Vec8 operator+(const FP32Vec8& b) const {
f32x8_t ret; f32x8_t ret;
for (int i = 0; i < VEC_ELEM_NUM; ++i) { unroll_loop<int, VEC_ELEM_NUM>(
ret.val[i] = reg.val[i] + b.reg.val[i]; [&ret, &b, this](int i) { ret.val[i] = reg.val[i] + b.reg.val[i]; });
}
return FP32Vec8(ret); return FP32Vec8(ret);
} }
FP32Vec8 operator-(const FP32Vec8& b) const { FP32Vec8 operator-(const FP32Vec8& b) const {
f32x8_t ret; f32x8_t ret;
for (int i = 0; i < VEC_ELEM_NUM; ++i) { unroll_loop<int, VEC_ELEM_NUM>(
ret.val[i] = reg.val[i] - b.reg.val[i]; [&ret, &b, this](int i) { ret.val[i] = reg.val[i] - b.reg.val[i]; });
}
return FP32Vec8(ret); return FP32Vec8(ret);
} }
FP32Vec8 operator/(const FP32Vec8& b) const { FP32Vec8 operator/(const FP32Vec8& b) const {
f32x8_t ret; f32x8_t ret;
for (int i = 0; i < VEC_ELEM_NUM; ++i) { unroll_loop<int, VEC_ELEM_NUM>(
ret.val[i] = reg.val[i] / b.reg.val[i]; [&ret, &b, this](int i) { ret.val[i] = reg.val[i] / b.reg.val[i]; });
}
return FP32Vec8(ret); return FP32Vec8(ret);
} }
...@@ -284,15 +262,11 @@ struct FP32Vec16 : public Vec<FP32Vec16> { ...@@ -284,15 +262,11 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
f32x16_t reg; f32x16_t reg;
explicit FP32Vec16(float v) { explicit FP32Vec16(float v) {
for (int i = 0; i < VEC_ELEM_NUM; ++i) { unroll_loop<int, VEC_ELEM_NUM>([&v, this](int i) { reg.val[i] = v; });
reg.val[i] = v;
}
} }
explicit FP32Vec16() { explicit FP32Vec16() {
for (int i = 0; i < VEC_ELEM_NUM; ++i) { unroll_loop<int, VEC_ELEM_NUM>([this](int i) { reg.val[i] = 0.0f; });
reg.val[i] = 0.0f;
}
} }
explicit FP32Vec16(const float* ptr) explicit FP32Vec16(const float* ptr)
...@@ -301,29 +275,27 @@ struct FP32Vec16 : public Vec<FP32Vec16> { ...@@ -301,29 +275,27 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
explicit FP32Vec16(f32x16_t data) : reg(data) {}; explicit FP32Vec16(f32x16_t data) : reg(data) {};
FP32Vec16(const FP32Vec4& data) { FP32Vec16(const FP32Vec4& data) {
for (int i = 0; i < VEC_ELEM_NUM; ++i) { unroll_loop<int, VEC_ELEM_NUM>([&data, this](int i) {
reg.val[i] = data.reg.val[i % FP32Vec4::VEC_ELEM_NUM]; reg.val[i] = data.reg.val[i % FP32Vec4::VEC_ELEM_NUM];
} });
} }
FP32Vec16(const FP32Vec8& data) { FP32Vec16(const FP32Vec8& data) {
for (int i = 0; i < VEC_ELEM_NUM; ++i) { unroll_loop<int, VEC_ELEM_NUM>([&data, this](int i) {
reg.val[i] = data.reg.val[i % FP32Vec8::VEC_ELEM_NUM]; reg.val[i] = data.reg.val[i % FP32Vec8::VEC_ELEM_NUM];
} });
} }
FP32Vec16(const FP32Vec16& data) : reg(data.reg) {}; FP32Vec16(const FP32Vec16& data) : reg(data.reg) {};
explicit FP32Vec16(const FP16Vec16& v) { explicit FP32Vec16(const FP16Vec16& v) {
for (int i = 0; i < VEC_ELEM_NUM; ++i) { unroll_loop<int, VEC_ELEM_NUM>(
reg.val[i] = fp16_to_float(v.reg.val[i]); [&v, this](int i) { reg.val[i] = fp16_to_float(v.reg.val[i]); });
}
} }
explicit FP32Vec16(const BF16Vec16& v) { explicit FP32Vec16(const BF16Vec16& v) {
for (int i = 0; i < VEC_ELEM_NUM; ++i) { unroll_loop<int, VEC_ELEM_NUM>(
reg.val[i] = bf16_to_float(v.reg.val[i]); [&v, this](int i) { reg.val[i] = bf16_to_float(v.reg.val[i]); });
}
} }
explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}; explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v)) {};
...@@ -331,82 +303,74 @@ struct FP32Vec16 : public Vec<FP32Vec16> { ...@@ -331,82 +303,74 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}; FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {};
FP32Vec16 operator*(const FP32Vec16& b) const { FP32Vec16 operator*(const FP32Vec16& b) const {
FP32Vec16 result(0.0f); f32x16_t ret;
for (int i = 0; i < VEC_ELEM_NUM; ++i) { unroll_loop<int, VEC_ELEM_NUM>(
result.reg.val[i] = reg.val[i] * b.reg.val[i]; [&ret, &b, this](int i) { ret.val[i] = reg.val[i] * b.reg.val[i]; });
} return FP32Vec16(ret);
return result;
} }
FP32Vec16 operator+(const FP32Vec16& b) const { FP32Vec16 operator+(const FP32Vec16& b) const {
FP32Vec16 result(0.0f); f32x16_t ret;
for (int i = 0; i < VEC_ELEM_NUM; ++i) { unroll_loop<int, VEC_ELEM_NUM>(
result.reg.val[i] = reg.val[i] + b.reg.val[i]; [&ret, &b, this](int i) { ret.val[i] = reg.val[i] + b.reg.val[i]; });
} return FP32Vec16(ret);
return result;
} }
FP32Vec16 operator-(const FP32Vec16& b) const { FP32Vec16 operator-(const FP32Vec16& b) const {
FP32Vec16 result(0.0f); f32x16_t ret;
for (int i = 0; i < VEC_ELEM_NUM; ++i) { unroll_loop<int, VEC_ELEM_NUM>(
result.reg.val[i] = reg.val[i] - b.reg.val[i]; [&ret, &b, this](int i) { ret.val[i] = reg.val[i] - b.reg.val[i]; });
} return FP32Vec16(ret);
return result;
} }
FP32Vec16 operator/(const FP32Vec16& b) const { FP32Vec16 operator/(const FP32Vec16& b) const {
FP32Vec16 result(0.0f); f32x16_t ret;
for (int i = 0; i < VEC_ELEM_NUM; ++i) { unroll_loop<int, VEC_ELEM_NUM>(
result.reg.val[i] = reg.val[i] / b.reg.val[i]; [&ret, &b, this](int i) { ret.val[i] = reg.val[i] / b.reg.val[i]; });
} return FP32Vec16(ret);
return result;
} }
FP32Vec16 max(const FP32Vec16& b) const { FP32Vec16 max(const FP32Vec16& b) const {
FP32Vec16 result(0.0f); f32x16_t ret;
for (int i = 0; i < VEC_ELEM_NUM; ++i) { unroll_loop<int, VEC_ELEM_NUM>([&ret, &b, this](int i) {
result.reg.val[i] = __max(reg.val[i], b.reg.val[i]); ret.val[i] = std::max(reg.val[i], b.reg.val[i]);
} });
return result; return FP32Vec16(ret);
} }
FP32Vec16 min(const FP32Vec16& b) const { FP32Vec16 min(const FP32Vec16& b) const {
FP32Vec16 result(0.0f); f32x16_t ret;
for (int i = 0; i < VEC_ELEM_NUM; ++i) { unroll_loop<int, VEC_ELEM_NUM>([&ret, &b, this](int i) {
result.reg.val[i] = __min(reg.val[i], b.reg.val[i]); ret.val[i] = std::min(reg.val[i], b.reg.val[i]);
} });
return result; return FP32Vec16(ret);
} }
FP32Vec16 abs() const { FP32Vec16 abs() const {
FP32Vec16 result(0.0f); f32x16_t ret;
for (int i = 0; i < VEC_ELEM_NUM; ++i) { unroll_loop<int, VEC_ELEM_NUM>(
result.reg.val[i] = __abs(reg.val[i]); [&ret, this](int i) { ret.val[i] = std::abs(reg.val[i]); });
} return FP32Vec16(ret);
return result;
} }
float reduce_sum() const { float reduce_sum() const {
float result = 0.0f; float result = 0.0f;
for (int i = 0; i < VEC_ELEM_NUM; ++i) { unroll_loop<int, VEC_ELEM_NUM>(
result += reg.val[i]; [&result, this](int i) { result += reg.val[i]; });
}
return result; return result;
} }
float reduce_max() const { float reduce_max() const {
float result = reg.val[0]; float result = std::numeric_limits<float>::lowest();
for (int i = 0; i < VEC_ELEM_NUM; ++i) { unroll_loop<int, VEC_ELEM_NUM>(
result = __max(reg.val[i], result); [&result, this](int i) { result = std::max(reg.val[i], result); });
}
return result; return result;
} }
float reduce_min() const { float reduce_min() const {
float result = reg.val[0]; float result = std::numeric_limits<float>::max();
for (int i = 0; i < VEC_ELEM_NUM; ++i) { unroll_loop<int, VEC_ELEM_NUM>(
result = __min(reg.val[i], result); [&result, this](int i) { result = std::min(reg.val[i], result); });
}
return result; return result;
} }
...@@ -414,13 +378,9 @@ struct FP32Vec16 : public Vec<FP32Vec16> { ...@@ -414,13 +378,9 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
float reduce_sub_sum(int idx) { float reduce_sub_sum(int idx) {
static_assert(VEC_ELEM_NUM % group_size == 0); static_assert(VEC_ELEM_NUM % group_size == 0);
float sum = 0.0; float sum = 0.0;
int start = idx * group_size; const int start = idx * group_size;
int end = (idx + 1) * group_size; unroll_loop<int, group_size>(
[&sum, &start, this](int i) { sum += reg.val[start + i]; });
for (; (start < VEC_ELEM_NUM) && (start < end); ++start) {
sum += reg.val[start];
}
return sum; return sum;
} }
...@@ -477,17 +437,13 @@ inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) { ...@@ -477,17 +437,13 @@ inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
} }
inline FP16Vec16::FP16Vec16(const FP32Vec16& v) { inline FP16Vec16::FP16Vec16(const FP32Vec16& v) {
int i = 0; unroll_loop<int, FP16Vec16::VEC_ELEM_NUM>(
for (i = 0; i < FP16Vec16::VEC_ELEM_NUM; ++i) { [&v, this](int i) { reg.val[i] = float_to_fp16(v.reg.val[i]); });
reg.val[i] = float_to_fp16(v.reg.val[i]);
}
} }
inline FP16Vec8 ::FP16Vec8(const FP32Vec8& v) { inline FP16Vec8 ::FP16Vec8(const FP32Vec8& v) {
int i = 0; unroll_loop<int, FP16Vec8::VEC_ELEM_NUM>(
for (i = 0; i < FP16Vec8::VEC_ELEM_NUM; ++i) { [&v, this](int i) { reg.val[i] = float_to_fp16(v.reg.val[i]); });
reg.val[i] = float_to_fp16(v.reg.val[i]);
}
} }
inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) { inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
...@@ -495,17 +451,13 @@ inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) { ...@@ -495,17 +451,13 @@ inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
} }
inline BF16Vec8::BF16Vec8(const FP32Vec8& v) { inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
int i = 0; unroll_loop<int, BF16Vec8::VEC_ELEM_NUM>(
for (i = 0; i < BF16Vec8::VEC_ELEM_NUM; ++i) { [&v, this](int i) { reg.val[i] = float_to_bf16(v.reg.val[i]); });
reg.val[i] = float_to_bf16(v.reg.val[i]);
}
} }
inline BF16Vec16::BF16Vec16(const FP32Vec16& v) { inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
int i = 0; unroll_loop<int, BF16Vec16::VEC_ELEM_NUM>(
for (i = 0; i < BF16Vec16::VEC_ELEM_NUM; ++i) { [&v, this](int i) { reg.val[i] = float_to_bf16(v.reg.val[i]); });
reg.val[i] = float_to_bf16(v.reg.val[i]);
}
} }
inline void prefetch(const void* addr) { __builtin_prefetch(addr, 0, 3); } inline void prefetch(const void* addr) { __builtin_prefetch(addr, 0, 3); }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment