Unverified commit 04eac6ba, authored by lyd1992 and committed by GitHub
Browse files

[Bugfix][CPU][RISC-V] Clamp exp() input to prevent NaN (#40428)


Signed-off-by: liuyudong <liuyudong@iscas.ac.cn>
parent 9047288b
......@@ -15,16 +15,12 @@
#include <torch/all.h>
namespace vec_op {
#ifdef RISCV_BF16_SUPPORT
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
// BFloat16 is always supported on RISC-V: natively when RISCV_BF16_SUPPORT
// is defined, otherwise via the FP32-simulation fallback path.
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#else
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
#endif
// Dispatch entry point: expands to AT_DISPATCH_SWITCH with TYPE and NAME
// passed through, using VLLM_DISPATCH_CASE_FLOATING_TYPES as the case list.
// The variadic arguments are forwarded unchanged to that case list.
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
......@@ -486,9 +482,18 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
}
FP32Vec8 exp() const {
// Clamp input to prevent NaN: exp(-inf) must return 0, not NaN.
// Without clamping, -inf * 0.0 = NaN in the final poly * scale step.
// Matches the clamping strategy used by x86 AVX-512 and ARM NEON.
constexpr float exp_lo = -87.3365447505f; // ln(FLT_MIN)
constexpr float exp_hi = 88.7228391117f; // ln(FLT_MAX)
fixed_fp32x8_t x = RVVI(__riscv_vfmin_vf_f32, LMUL_256)(
RVVI(__riscv_vfmax_vf_f32, LMUL_256)(reg, exp_lo, VEC_ELEM_NUM), exp_hi,
VEC_ELEM_NUM);
const float inv_ln2 = 1.44269504088896341f;
fixed_fp32x8_t x_scaled =
RVVI(__riscv_vfmul_vf_f32, LMUL_256)(reg, inv_ln2, VEC_ELEM_NUM);
RVVI(__riscv_vfmul_vf_f32, LMUL_256)(x, inv_ln2, VEC_ELEM_NUM);
fixed_i32x8_t n_int =
RVVI(__riscv_vfcvt_x_f_v_i32, LMUL_256)(x_scaled, VEC_ELEM_NUM);
fixed_fp32x8_t n_float =
......@@ -706,9 +711,18 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
}
FP32Vec16 exp() const {
// Clamp input to prevent NaN: exp(-inf) must return 0, not NaN.
// Without clamping, -inf * 0.0 = NaN in the final poly * scale step.
// Matches the clamping strategy used by x86 AVX-512 and ARM NEON.
constexpr float exp_lo = -87.3365447505f; // ln(FLT_MIN)
constexpr float exp_hi = 88.7228391117f; // ln(FLT_MAX)
fixed_fp32x16_t x = RVVI(__riscv_vfmin_vf_f32, LMUL_512)(
RVVI(__riscv_vfmax_vf_f32, LMUL_512)(reg, exp_lo, VEC_ELEM_NUM), exp_hi,
VEC_ELEM_NUM);
const float inv_ln2 = 1.44269504088896341f;
fixed_fp32x16_t x_scaled =
RVVI(__riscv_vfmul_vf_f32, LMUL_512)(reg, inv_ln2, VEC_ELEM_NUM);
RVVI(__riscv_vfmul_vf_f32, LMUL_512)(x, inv_ln2, VEC_ELEM_NUM);
fixed_i32x16_t n_int =
RVVI(__riscv_vfcvt_x_f_v_i32, LMUL_512)(x_scaled, VEC_ELEM_NUM);
fixed_fp32x16_t n_float =
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment