Commit 807a4818 authored by Andriy Roshchenko's avatar Andriy Roshchenko
Browse files

Add constexpr where applicable.

parent b1a7d2a7
...@@ -111,7 +111,7 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x) ...@@ -111,7 +111,7 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
fmax = bit_cast<_Float16>(ifmax); fmax = bit_cast<_Float16>(ifmax);
fmin = bit_cast<_Float16>(ifmin); fmin = bit_cast<_Float16>(ifmin);
} }
else if(is_float) else if constexpr(is_float)
{ {
const unsigned int ifInf = 0x7F800000; const unsigned int ifInf = 0x7F800000;
const unsigned int ifNegInf = 0xFF800000; const unsigned int ifNegInf = 0xFF800000;
...@@ -128,7 +128,7 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x) ...@@ -128,7 +128,7 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
fmax = bit_cast<float>(ifmax); fmax = bit_cast<float>(ifmax);
fmin = bit_cast<float>(ifmin); fmin = bit_cast<float>(ifmin);
} }
else if(is_double) else if constexpr(is_double)
{ {
const unsigned long long ifInf = 0x7FF0000000000000ull; const unsigned long long ifInf = 0x7FF0000000000000ull;
const unsigned long long ifNegInf = 0xFFF0000000000000ull; const unsigned long long ifNegInf = 0xFFF0000000000000ull;
...@@ -167,7 +167,7 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x) ...@@ -167,7 +167,7 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
{ {
return fNeg0; return fNeg0;
} }
if(we == 4) if constexpr(we == 4)
{ // e4m3 { // e4m3
if((x & 0x7F) == 0x7F) if((x & 0x7F) == 0x7F)
{ {
...@@ -178,7 +178,7 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x) ...@@ -178,7 +178,7 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
{ // e5m2 { // e5m2
if((x & 0x3) == 0) if((x & 0x3) == 0)
{ {
if(clip) if constexpr(clip)
{ {
return sign ? fmin : fmax; return sign ? fmin : fmax;
} }
...@@ -194,7 +194,7 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x) ...@@ -194,7 +194,7 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
typename __hip_internal::conditional<sizeof(T) == 4, unsigned int, unsigned long long>:: typename __hip_internal::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::
type>::type retval; type>::type retval;
if(we == 5 && is_half && !is_fnuz) if constexpr(we == 5 && is_half && !is_fnuz)
{ {
retval = x << 8; retval = x << 8;
return bit_cast<T>(retval); return bit_cast<T>(retval);
...@@ -228,10 +228,11 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x) ...@@ -228,10 +228,11 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
if constexpr(sizeof(T) == 2) if constexpr(sizeof(T) == 2)
retval = (sign << 15) | (exponent << 10) | mantissa; retval = (sign << 15) | (exponent << 10) | mantissa;
else if(sizeof(T) == 4) else if constexpr(sizeof(T) == 4)
retval = (sign << 31) | (exponent << 23) | mantissa; retval = (sign << 31) | (exponent << 23) | mantissa;
else else
retval = (sign << 63) | (static_cast<unsigned long long>(exponent) << 52) | mantissa; retval = (sign << 63) | (static_cast<unsigned long long>(exponent) << 52) | mantissa;
return bit_cast<T>(retval); return bit_cast<T>(retval);
} }
...@@ -498,7 +499,7 @@ static __device__ fp8_storage_t cast_to_f8_from_f32(float v, unsigned int rng = ...@@ -498,7 +499,7 @@ static __device__ fp8_storage_t cast_to_f8_from_f32(float v, unsigned int rng =
val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0); val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
} }
} }
else if(interpret == CK_E4M3_OCP) else if constexpr(interpret == CK_E4M3_OCP)
{ // OCP type { // OCP type
if((val.i32val & 0x7F800000) != 0x7F800000) if((val.i32val & 0x7F800000) != 0x7F800000)
{ /// propagate NAN/INF, no clipping { /// propagate NAN/INF, no clipping
...@@ -575,7 +576,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn ...@@ -575,7 +576,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
fInf = 0x7FF0000000000000ull; fInf = 0x7FF0000000000000ull;
mask = 0x7FFFFFFFFFFFFFFFull; mask = 0x7FFFFFFFFFFFFFFFull;
} }
else if(sizeof(T) == 4) else if constexpr(sizeof(T) == 4)
{ {
head = x & 0xFF800000; head = x & 0xFF800000;
mantissa = x & 0x7FFFFF; mantissa = x & 0x7FFFFF;
...@@ -604,7 +605,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn ...@@ -604,7 +605,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
} }
else else
{ {
if(we == 4) if constexpr(we == 4)
{ // e4m3 { // e4m3
signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f); signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
} }
...@@ -618,13 +619,13 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn ...@@ -618,13 +619,13 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
unsigned long long ifmax = 0; unsigned long long ifmax = 0;
if constexpr(sizeof(T) == 8) if constexpr(sizeof(T) == 8)
{ {
if(we == 5) if constexpr(we == 5)
{ // 57344 { // 57344
ifmax = 0x40EC000000000000ull; ifmax = 0x40EC000000000000ull;
} }
else else
{ {
if(is_fnuz) if constexpr(is_fnuz)
{ // 240 { // 240
ifmax = 0x406E000000000000ull; ifmax = 0x406E000000000000ull;
} }
...@@ -636,13 +637,13 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn ...@@ -636,13 +637,13 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
} }
else if(sizeof(T) == 4) else if(sizeof(T) == 4)
{ {
if(we == 5) if constexpr(we == 5)
{ {
ifmax = 0x47600000; ifmax = 0x47600000;
} }
else else
{ {
if(is_fnuz) if constexpr(is_fnuz)
{ {
ifmax = 0x43700000; ifmax = 0x43700000;
} }
...@@ -654,13 +655,13 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn ...@@ -654,13 +655,13 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
} }
else else
{ {
if(we == 5) if constexpr(we == 5)
{ {
ifmax = 0x7B00; ifmax = 0x7B00;
} }
else else
{ {
if(is_fnuz) if constexpr(is_fnuz)
{ {
ifmax = 0x5B80; ifmax = 0x5B80;
} }
...@@ -673,7 +674,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn ...@@ -673,7 +674,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
// Deal with inf and NaNs // Deal with inf and NaNs
if((x & fInf) == fInf) if((x & fInf) == fInf)
{ {
if(is_fnuz) if constexpr(is_fnuz)
return signed_inf; return signed_inf;
return mantissa != 0 ? nan : signed_inf; return mantissa != 0 ? nan : signed_inf;
...@@ -788,7 +789,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn ...@@ -788,7 +789,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
const int max_exp = (1 << we) - 1; const int max_exp = (1 << we) - 1;
if(f8_exponent > max_exp) if(f8_exponent > max_exp)
{ {
if(clip) if constexpr(clip)
{ {
mantissa = (1 << wm) - 1; mantissa = (1 << wm) - 1;
f8_exponent = max_exp; f8_exponent = max_exp;
...@@ -846,15 +847,15 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f) ...@@ -846,15 +847,15 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
{ {
return cast_to_f8<float, 3, 4, true, sat == CK_SATFINITE, stochastic_rounding>(f, rng); return cast_to_f8<float, 3, 4, true, sat == CK_SATFINITE, stochastic_rounding>(f, rng);
} }
else if(interp == CK_E5M2_FNUZ) else if constexpr(interp == CK_E5M2_FNUZ)
{ {
return cast_to_f8<float, 2, 5, true, sat == CK_SATFINITE, stochastic_rounding>(f, rng); return cast_to_f8<float, 2, 5, true, sat == CK_SATFINITE, stochastic_rounding>(f, rng);
} }
else if(interp == CK_E4M3_OCP) else if constexpr(interp == CK_E4M3_OCP)
{ {
return cast_to_f8<float, 3, 4, false, sat == CK_SATFINITE, stochastic_rounding>(f, rng); return cast_to_f8<float, 3, 4, false, sat == CK_SATFINITE, stochastic_rounding>(f, rng);
} }
else if(interp == CK_E5M2_OCP) else if constexpr(interp == CK_E5M2_OCP)
{ {
return cast_to_f8<float, 2, 5, false, sat == CK_SATFINITE, stochastic_rounding>(f, rng); return cast_to_f8<float, 2, 5, false, sat == CK_SATFINITE, stochastic_rounding>(f, rng);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment