Unverified Commit 2f314bc5 authored by almayne's avatar almayne Committed by GitHub
Browse files

[CPU] Added faster exp routine for lower precision data types. (#38112)


Signed-off-by: default avatarAnna Mayne <anna.mayne@arm.com>
Co-authored-by: default avatarFadi Arafeh <fadi.arafeh@arm.com>
Co-authored-by: default avatarLi, Jiang <jiang1.li@intel.com>
parent 2196bac1
...@@ -61,8 +61,23 @@ ...@@ -61,8 +61,23 @@
#endif #endif
#ifdef __aarch64__ #ifdef __aarch64__
// Implementation copied from Arm Optimized Routines (expf AdvSIMD) // Implementation of neon_expf copied from Arm Optimized Routines (expf
// AdvSIMD)
// https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c // https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c
//
// Additional fast exponential intended for cases where outputs will be
// downcasted to FP16 / BF16 (e.g. attention softmax). Accurate within 1 ULP
// for FP16 Accurate within 1 ULP for BF16 for inputs in [-87.683, 88.376] &
// clamps inputs outside this range to 0 / inf. Implementation is similar to
// exp_u20, but:
// - uses a third degree polynomial approximation for exp(r) instead of a
// fifth degree one, with coefficients re-tuned.
// - does not split natural log (ln) into high / low parts
// - clamps exp(x) to 0 for x < -87.683113f and inf for x > 88.3762589f
// exp(x) = 2^n (exp(r))
// r = x - n*ln2, with n = round(x/ln2)
// exp(r) ~ poly(r) = 1 + r + r^2 * (c3 + c2 * r)
// n = round(x / ln2), r = x - n*ln2
#include <limits> #include <limits>
#define DEFINE_FAST_EXP \ #define DEFINE_FAST_EXP \
const float32x4_t inv_ln2 = vdupq_n_f32(0x1.715476p+0f); \ const float32x4_t inv_ln2 = vdupq_n_f32(0x1.715476p+0f); \
...@@ -106,6 +121,37 @@ ...@@ -106,6 +121,37 @@
result.val[2] = neon_expf(vec.reg.val[2]); \ result.val[2] = neon_expf(vec.reg.val[2]); \
result.val[3] = neon_expf(vec.reg.val[3]); \ result.val[3] = neon_expf(vec.reg.val[3]); \
return vec_op::FP32Vec16(result); \ return vec_op::FP32Vec16(result); \
}; \
const float32x4_t lower_bound = vdupq_n_f32(-0x1.5ebb82p+6f); \
const float32x4_t upper_bound = vdupq_n_f32(0x1.61814ap+6f); \
constexpr float ln2 = 0x1.62e43p-1f; \
constexpr float f_c2 = 0x1.5592ecp-3f; \
const float32x4_t f_c3 = vdupq_n_f32(0x1.017d34p-1f); \
auto neon_expf_f16 = [&](float32x4_t values) __attribute__(( \
always_inline)) { \
const uint32x4_t lt_lower = vcltq_f32(values, lower_bound); \
const uint32x4_t gt_upper = vcgtq_f32(values, upper_bound); \
float32x4_t n = vrndaq_f32(vmulq_f32(values, inv_ln2)); \
float32x4_t r = vfmsq_n_f32(values, n, ln2); \
uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_s32(vcvtq_s32_f32(n)), 23); \
float32x4_t r2 = vmulq_f32(r, r); \
float32x4_t q = vfmaq_n_f32(f_c3, r, f_c2); \
float32x4_t s = vaddq_f32(vdupq_n_f32(1.0f), r); \
float32x4_t p = vfmaq_f32(s, q, r2); \
float32x4_t y = \
vreinterpretq_f32_u32(vaddq_u32(vreinterpretq_u32_f32(p), e)); \
y = vbslq_f32(lt_lower, vdupq_n_f32(0.0f), y); \
y = vbslq_f32(gt_upper, vdupq_n_f32(INFINITY), y); \
return y; \
}; \
auto fast_exp_f16 = [&](const vec_op::FP32Vec16& vec) \
__attribute__((always_inline)) { \
float32x4x4_t result; \
result.val[0] = neon_expf_f16(vec.reg.val[0]); \
result.val[1] = neon_expf_f16(vec.reg.val[1]); \
result.val[2] = neon_expf_f16(vec.reg.val[2]); \
result.val[3] = neon_expf_f16(vec.reg.val[3]); \
return vec_op::FP32Vec16(result); \
}; };
#endif // __aarch64__ #endif // __aarch64__
......
...@@ -1152,7 +1152,11 @@ class AttentionMainLoop { ...@@ -1152,7 +1152,11 @@ class AttentionMainLoop {
bool use_sink) { bool use_sink) {
#ifdef DEFINE_FAST_EXP #ifdef DEFINE_FAST_EXP
DEFINE_FAST_EXP DEFINE_FAST_EXP
bool constexpr IsReducedPrecision =
std::is_same_v<query_t, c10::BFloat16> ||
std::is_same_v<query_t, c10::Half>;
#endif #endif
using prob_buffer_vec_t = typename VecTypeTrait<prob_buffer_t>::vec_t; using prob_buffer_vec_t = typename VecTypeTrait<prob_buffer_t>::vec_t;
static_assert(sizeof(prob_buffer_t) <= sizeof(logits_buffer_t)); static_assert(sizeof(prob_buffer_t) <= sizeof(logits_buffer_t));
...@@ -1201,8 +1205,17 @@ class AttentionMainLoop { ...@@ -1201,8 +1205,17 @@ class AttentionMainLoop {
vec = vec - max_vec; vec = vec - max_vec;
// compute exp // compute exp
#ifdef DEFINE_FAST_EXP
#if defined(DEFINE_FAST_EXP)
#ifdef __aarch64__
if constexpr (IsReducedPrecision) {
vec = fast_exp_f16(vec);
} else
#endif
{
vec = fast_exp(vec); vec = fast_exp(vec);
}
prob_buffer_vec_t output_vec(vec); prob_buffer_vec_t output_vec(vec);
output_vec.save(curr_prob_buffer_iter); output_vec.save(curr_prob_buffer_iter);
#else #else
...@@ -1258,7 +1271,11 @@ class AttentionMainLoop { ...@@ -1258,7 +1271,11 @@ class AttentionMainLoop {
int32_t kv_tile_token_num, float softcap_scale) { int32_t kv_tile_token_num, float softcap_scale) {
#ifdef DEFINE_FAST_EXP #ifdef DEFINE_FAST_EXP
DEFINE_FAST_EXP DEFINE_FAST_EXP
bool constexpr IsReducedPrecision =
std::is_same_v<query_t, c10::BFloat16> ||
std::is_same_v<query_t, c10::Half>;
#endif #endif
float inv_softcap_scale = 1.0 / softcap_scale; float inv_softcap_scale = 1.0 / softcap_scale;
vec_op::FP32Vec16 softcap_scale_vec(softcap_scale); vec_op::FP32Vec16 softcap_scale_vec(softcap_scale);
vec_op::FP32Vec16 inv_softcap_scale_vec(inv_softcap_scale); vec_op::FP32Vec16 inv_softcap_scale_vec(inv_softcap_scale);
...@@ -1272,8 +1289,15 @@ class AttentionMainLoop { ...@@ -1272,8 +1289,15 @@ class AttentionMainLoop {
vec_op::FP32Vec16 vec(curr_logits_buffer_iter); vec_op::FP32Vec16 vec(curr_logits_buffer_iter);
vec = vec * inv_softcap_scale_vec; vec = vec * inv_softcap_scale_vec;
#ifdef DEFINE_FAST_EXP #if defined(DEFINE_FAST_EXP)
#ifdef __aarch64__
if constexpr (IsReducedPrecision) {
vec = fast_exp_f16(vec);
} else
#endif
{
vec = fast_exp(vec); vec = fast_exp(vec);
}
vec_op::FP32Vec16 inv_vec = ones_vec / vec; vec_op::FP32Vec16 inv_vec = ones_vec / vec;
vec = (vec - inv_vec) / (vec + inv_vec); vec = (vec - inv_vec) / (vec + inv_vec);
#else #else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment