Unverified Commit acfb3392 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer Committed by GitHub
Browse files

Update clipping for fp8/bf8 conversion (#1182)

* Update clipping for fp8 conversion

* Add clipping for bf8 conversion

* Format
parent a776978c
......@@ -109,9 +109,6 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
{
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
float max_fp8 = 240.0f;
if(!std::isinf(x))
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
#if defined(__gfx94__)
union
{
......@@ -121,6 +118,11 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
} val;
val.fval = x;
uint32_t ival = 0;
const float max_fp8 = 240.0f;
// if x is not +/- infinity or nan
if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
// clip float value
val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8);
ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
return val.i8val[0]; // little endian
......@@ -168,6 +170,11 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
} val;
val.fval = x;
uint32_t ival = 0;
const float max_bf8 = 57344.0f;
// if x is not +/- infinity or nan
if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
// clip float value
val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8);
ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
return val.i8val[0]; // little endian
......@@ -208,9 +215,6 @@ __host__ __device__ constexpr Y f8_convert_rne(X x);
template <>
inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
{
float max_fp8 = 240.0f;
if(!std::isinf(x))
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
#if defined(__gfx94__)
union
{
......@@ -220,6 +224,11 @@ inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
} val;
val.fval = x;
uint32_t ival = 0;
const float max_fp8 = 240.0f;
// if x is not +/- infinity or nan
if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
// clip float value
val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8);
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0
val.i32val = ival;
return val.i8val[0];
......@@ -265,6 +274,11 @@ inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
} val;
val.fval = x;
uint32_t ival = 0;
const float max_bf8 = 57344.0f;
// if x is not +/- infinity or nan
if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
// clip float value
val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8);
ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
val.i32val = ival;
return val.i8val[0];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment