"docs/source/en/api/pipelines/overview.md" did not exist on "df80ccf7de4cd7409141fe881fd4d630cd69fc4c"
Commit 5f1a24a8 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Add device conversions

parent 1bca7134
...@@ -504,52 +504,41 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x) ...@@ -504,52 +504,41 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
} }
// convert fp32 to fp4 with rounding to nearest even // convert fp32 to fp4 with rounding to nearest even
inline __host__ __device__ f4_t f4_convert_rne(float x) inline __host__ __device__ f4_t f4_convert_rne(float x, float scale = 1.0f)
{ {
#if defined(__gfx94__) #if defined(__gfx950__)
// union union
// { {
// float fval; uint32_t bitwise;
// uint32_t i32val; f4_t f4_array[4];
// uint8_t i8val[4]; // not endian independent } value{0};
// } val; value.bitwise =
// val.fval = x; __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise,
// uint32_t ival = 0; in.template AsType<float>()(Number<0>{}),
// const float max_fp8 = 240.0f; in.template AsType<float>()(Number<1>{}),
// // if x is not +/- infinity or nan scale,
// if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf) 0);
// // clip float value return value.f4_array[0];
// val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8);
// ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false ->
// WORD0 val.i32val = ival; return val.i8val[0];
#else #else
return utils::sat_convert_to_type<f4_t>(x); return utils::sat_convert_to_type<f4_t>(x / scale);
#endif #endif
} }
// convert fp32 to fp4 with stochastic rounding // convert fp32 to fp4 with stochastic rounding
inline __host__ __device__ f4_t f4_convert_sr(float x) inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f)
{ {
constexpr int seed = 1254739; constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx94__) #if defined(__gfx950__)
// union union
// { {
// float fval; uint32_t bitwise;
// uint32_t i32val; f4_t f4_array[4];
// uint8_t i8val[4]; // not endian independent } value{0};
// } val; value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(value.bitwise, x, rng, scale, 0);
// val.fval = x; return value.f4_array[0];
// uint32_t ival = 0;
// const float max_fp8 = 240.0f;
// // if x is not +/- infinity or nan
// if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
// // clip float value
// val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8);
// ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false ->
// WORD0 val.i32val = ival; return val.i8val[0];
#else #else
return utils::sat_convert_to_type_sr<f4_t>(x, rng); return utils::sat_convert_to_type_sr<f4_t>(x / scale, rng);
#endif #endif
} }
...@@ -568,12 +557,10 @@ inline __host__ __device__ f4_t type_convert<f4_t, float>(float x) ...@@ -568,12 +557,10 @@ inline __host__ __device__ f4_t type_convert<f4_t, float>(float x)
template <> template <>
inline __host__ __device__ float type_convert<float, f4_t>(f4_t data) inline __host__ __device__ float type_convert<float, f4_t>(f4_t data)
{ {
#if defined(__gfx94__) #if defined(__gfx950__)
// float fval; float scale = 1.0f;
// uint32_t i32val = static_cast<uint32_t>(x); return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(data, scale, 0)
// fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0); .template AsType<float>()(Number<0>{});
// // asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
// return fval;
#else #else
return utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), data); return utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), data);
#endif #endif
...@@ -585,16 +572,24 @@ __host__ __device__ constexpr Y scaled_type_convert(e8m0_scale_t scale, X x); ...@@ -585,16 +572,24 @@ __host__ __device__ constexpr Y scaled_type_convert(e8m0_scale_t scale, X x);
// convert fp4 to fp32 // convert fp4 to fp32
template <> template <>
inline __host__ __device__ float scaled_type_convert<float, f4_t>(e8m0_scale_t scale, f4_t data) inline __host__ __device__ float scaled_type_convert<float, f4_t>(e8m0_scale_t scale, f4_t x)
{ {
#if defined(__gfx94__) #if defined(__gfx950__)
// float fval; return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, scale, 0)
// uint32_t i32val = static_cast<uint32_t>(x); .template AsType<float>()(Number<0>{});
// fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0); #else
// // asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); return utils::to_float<f4_t>(scale, x);
// return fval; #endif
}
// convert fp32 to fp4
template <>
inline __host__ __device__ f4_t scaled_type_convert<f4_t, float>(e8m0_scale_t scale, float x)
{
#if CK_USE_SR_F4_CONVERSION
return f4_convert_sr(x, type_convert<float>(scale));
#else #else
return utils::to_float<f4_t>(scale, data); return f4_convert_rne(x, type_convert<float>(scale));
#endif #endif
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment