Unverified commit 04eac6ba, authored by lyd1992 and committed by GitHub
Browse files

[Bugfix][CPU][RISC-V] Clamp exp() input to prevent NaN (#40428)


Signed-off-by: liuyudong <liuyudong@iscas.ac.cn>
parent 9047288b
...@@ -15,16 +15,12 @@ ...@@ -15,16 +15,12 @@
#include <torch/all.h> #include <torch/all.h>
namespace vec_op { namespace vec_op {
#ifdef RISCV_BF16_SUPPORT // BFloat16 is always supported on RISC-V: natively when RISCV_BF16_SUPPORT
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ // is defined, otherwise via the FP32-simulation fallback path.
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
#else AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
#endif
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
...@@ -486,9 +482,18 @@ struct FP32Vec8 : public Vec<FP32Vec8> { ...@@ -486,9 +482,18 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
} }
FP32Vec8 exp() const { FP32Vec8 exp() const {
// Clamp input to prevent NaN: exp(-inf) must return 0, not NaN.
// Without clamping, -inf * 0.0 = NaN in the final poly * scale step.
// Matches the clamping strategy used by x86 AVX-512 and ARM NEON.
constexpr float exp_lo = -87.3365447505f; // ln(FLT_MIN)
constexpr float exp_hi = 88.7228391117f; // ln(FLT_MAX)
fixed_fp32x8_t x = RVVI(__riscv_vfmin_vf_f32, LMUL_256)(
RVVI(__riscv_vfmax_vf_f32, LMUL_256)(reg, exp_lo, VEC_ELEM_NUM), exp_hi,
VEC_ELEM_NUM);
const float inv_ln2 = 1.44269504088896341f; const float inv_ln2 = 1.44269504088896341f;
fixed_fp32x8_t x_scaled = fixed_fp32x8_t x_scaled =
RVVI(__riscv_vfmul_vf_f32, LMUL_256)(reg, inv_ln2, VEC_ELEM_NUM); RVVI(__riscv_vfmul_vf_f32, LMUL_256)(x, inv_ln2, VEC_ELEM_NUM);
fixed_i32x8_t n_int = fixed_i32x8_t n_int =
RVVI(__riscv_vfcvt_x_f_v_i32, LMUL_256)(x_scaled, VEC_ELEM_NUM); RVVI(__riscv_vfcvt_x_f_v_i32, LMUL_256)(x_scaled, VEC_ELEM_NUM);
fixed_fp32x8_t n_float = fixed_fp32x8_t n_float =
...@@ -706,9 +711,18 @@ struct FP32Vec16 : public Vec<FP32Vec16> { ...@@ -706,9 +711,18 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
} }
FP32Vec16 exp() const { FP32Vec16 exp() const {
// Clamp input to prevent NaN: exp(-inf) must return 0, not NaN.
// Without clamping, -inf * 0.0 = NaN in the final poly * scale step.
// Matches the clamping strategy used by x86 AVX-512 and ARM NEON.
constexpr float exp_lo = -87.3365447505f; // ln(FLT_MIN)
constexpr float exp_hi = 88.7228391117f; // ln(FLT_MAX)
fixed_fp32x16_t x = RVVI(__riscv_vfmin_vf_f32, LMUL_512)(
RVVI(__riscv_vfmax_vf_f32, LMUL_512)(reg, exp_lo, VEC_ELEM_NUM), exp_hi,
VEC_ELEM_NUM);
const float inv_ln2 = 1.44269504088896341f; const float inv_ln2 = 1.44269504088896341f;
fixed_fp32x16_t x_scaled = fixed_fp32x16_t x_scaled =
RVVI(__riscv_vfmul_vf_f32, LMUL_512)(reg, inv_ln2, VEC_ELEM_NUM); RVVI(__riscv_vfmul_vf_f32, LMUL_512)(x, inv_ln2, VEC_ELEM_NUM);
fixed_i32x16_t n_int = fixed_i32x16_t n_int =
RVVI(__riscv_vfcvt_x_f_v_i32, LMUL_512)(x_scaled, VEC_ELEM_NUM); RVVI(__riscv_vfcvt_x_f_v_i32, LMUL_512)(x_scaled, VEC_ELEM_NUM);
fixed_fp32x16_t n_float = fixed_fp32x16_t n_float =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment