Commit 5f1a24a8 authored by Rostyslav Geyyer
Browse files

Add device conversions

parent 1bca7134
......@@ -504,52 +504,41 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
}
// Convert fp32 to fp4 with rounding to nearest even.
//
//   x     - value to convert
//   scale - scaling factor; defaults to 1.0f. On the fallback path x is divided by
//           scale before the saturating conversion. On gfx950 scale is handed to the
//           hardware conversion builtin.
//           NOTE(review): presumably the builtin applies the scale the same way as
//           the fallback (x / scale) - confirm against the ISA documentation.
//
// Returns the converted f4_t value.
inline __host__ __device__ f4_t f4_convert_rne(float x, float scale = 1.0f)
{
#if defined(__gfx950__)
    union
    {
        uint32_t bitwise;
        f4_t f4_array[4];
    } value{0};
    // The builtin converts two floats per call. This scalar overload has a single
    // input, so x is supplied for both sources; only element 0 of the union is
    // returned.
    value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x, x, scale, 0);
    return value.f4_array[0];
#else
    // Software path: saturating conversion of the pre-scaled value.
    return utils::sat_convert_to_type<f4_t>(x / scale);
#endif
}
// Convert fp32 to fp4 with stochastic rounding.
//
//   x     - value to convert
//   scale - scaling factor; defaults to 1.0f. On the fallback path x is divided by
//           scale before the saturating conversion. On gfx950 scale is handed to the
//           hardware conversion builtin.
//           NOTE(review): presumably the builtin applies the scale the same way as
//           the fallback (x / scale) - confirm against the ISA documentation.
//
// The random bits come from prand_generator, fed with the address of the local copy
// of x and its value, using a fixed compile-time seed.
//
// Returns the converted f4_t value.
inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    uint32_t rng       = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx950__)
    // Packed pair of floats, the source operand form taken by the SR builtin.
    using f32x2 = float __attribute__((ext_vector_type(2)));

    union
    {
        uint32_t bitwise;
        f4_t f4_array[4];
    } value{0};
    // Only one input exists here: lane 0 carries x, lane 1 is zero-filled, and only
    // element 0 of the union is returned.
    const f32x2 src = {x, 0.0f};
    value.bitwise   = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(value.bitwise, src, rng, scale, 0);
    return value.f4_array[0];
#else
    // Software path: saturating stochastic-rounding conversion of the pre-scaled value.
    return utils::sat_convert_to_type_sr<f4_t>(x / scale, rng);
#endif
}
......@@ -568,12 +557,10 @@ inline __host__ __device__ f4_t type_convert<f4_t, float>(float x)
template <>
inline __host__ __device__ float type_convert<float, f4_t>(f4_t data)
{
#if defined(__gfx94__)
// float fval;
// uint32_t i32val = static_cast<uint32_t>(x);
// fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
// // asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
// return fval;
#if defined(__gfx950__)
float scale = 1.0f;
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(data, scale, 0)
.template AsType<float>()(Number<0>{});
#else
return utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), data);
#endif
......@@ -585,16 +572,24 @@ __host__ __device__ constexpr Y scaled_type_convert(e8m0_scale_t scale, X x);
// Convert fp4 to fp32, applying an e8m0 scale.
//
//   scale - e8m0 scale associated with x
//   x     - fp4 value to convert
//
// Returns the scaled value as float.
template <>
inline __host__ __device__ float scaled_type_convert<float, f4_t>(e8m0_scale_t scale, f4_t x)
{
#if defined(__gfx950__)
    // The builtin yields a packed pair of floats; it is a raw vector, so it is
    // indexed directly rather than through a wrapper accessor.
    using f32x2 = float __attribute__((ext_vector_type(2)));

    // The builtin takes the scale as a float, so the e8m0 scale is converted first
    // (same conversion the fp32 -> fp4 scaled overload uses).
    const f32x2 unpacked =
        __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, type_convert<float>(scale), 0);
    return unpacked[0];
#else
    // Software path.
    return utils::to_float<f4_t>(scale, x);
#endif
}
// Convert fp32 to fp4, applying an e8m0 scale.
//
//   scale - e8m0 scale; converted to float and forwarded to the fp4 converter
//   x     - value to convert
//
// The rounding mode is selected at compile time: stochastic rounding when
// CK_USE_SR_F4_CONVERSION is non-zero, round-to-nearest-even otherwise.
template <>
inline __host__ __device__ f4_t scaled_type_convert<f4_t, float>(e8m0_scale_t scale, float x)
{
#if CK_USE_SR_F4_CONVERSION
    return f4_convert_sr(x, type_convert<float>(scale));
#else
    return f4_convert_rne(x, type_convert<float>(scale));
#endif
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment