Unverified Commit 71971e33 authored by Jacket's avatar Jacket Committed by GitHub
Browse files

Fix exp2f_rcp to properly handle nan and 0xFE cases (#2647)


Signed-off-by: default avatarKaining Zhong <kainingz@nvidia.com>
parent 59f6f387
...@@ -425,10 +425,14 @@ inline fp8e8m0 float_to_e8m0(float val) { ...@@ -425,10 +425,14 @@ inline fp8e8m0 float_to_e8m0(float val) {
} }
inline float exp2f_rcp(fp8e8m0 biased_exp) { inline float exp2f_rcp(fp8e8m0 biased_exp) {
if (biased_exp == 0) { int32_t int_val = 0;
return 1.0f; if (biased_exp == 255) {
int_val = 0x7fffffff;
} else if (biased_exp == 254) {
int_val = 0x00400000;
} else {
int_val = (254 - biased_exp) << FP32_MANTISSA_BITS; // 127 - (biased_exp - 127)
} }
int32_t int_val = (254 - biased_exp) << FP32_MANTISSA_BITS; // 127 - (biased_exp - 127)
float fp32_val = *reinterpret_cast<float*>(&int_val); float fp32_val = *reinterpret_cast<float*>(&int_val);
return fp32_val; return fp32_val;
} }
......
...@@ -328,9 +328,13 @@ constexpr uint32_t FP32_MANTISSA_BITS = 23; ...@@ -328,9 +328,13 @@ constexpr uint32_t FP32_MANTISSA_BITS = 23;
constexpr uint32_t FP32_EXPONENT_BIAS = 127; constexpr uint32_t FP32_EXPONENT_BIAS = 127;
__device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) { __device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) {
return (biased_exp == 0) ? 1 // Handle the special case of NaN.
: __int_as_float((254 - biased_exp) if (biased_exp == 255) return __int_as_float(0x7fffffff);
<< FP32_MANTISSA_BITS); // 127 - (biased_exp - 127) // Handle the special case where the unbiased exponent is 127, so the reciprocal is 2^-127 which needs the first bit of
// the mantissa to be 1, which can't be obtained by shifting `FP32_MANTISSA_BITS` bits to the left.
if (biased_exp == 254) return __int_as_float(0x00400000);
// Fast calculation when the unbiased exp is in [-126, 126], and only the exponent part is used to express the reciprocal.
return __int_as_float((254 - biased_exp) << FP32_MANTISSA_BITS);
} }
__device__ __forceinline__ float exp2f(e8m0_t biased_exp) { __device__ __forceinline__ float exp2f(e8m0_t biased_exp) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment