Commit e137941c authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Fix gfx950 conversions

parent 0ad5d7f7
...@@ -16,6 +16,10 @@ namespace ck { ...@@ -16,6 +16,10 @@ namespace ck {
#define __gfx94__ #define __gfx94__
#endif #endif
#if defined(__gfx90a__)
#define __gfx950__
#endif
// Convert X to Y, both X and Y are non-const data types. // Convert X to Y, both X and Y are non-const data types.
template <typename Y, template <typename Y,
typename X, typename X,
...@@ -512,12 +516,7 @@ inline __host__ __device__ f4_t f4_convert_rne(float x, float scale = 1.0f) ...@@ -512,12 +516,7 @@ inline __host__ __device__ f4_t f4_convert_rne(float x, float scale = 1.0f)
uint32_t bitwise; uint32_t bitwise;
f4_t f4_array[4]; f4_t f4_array[4];
} value{0}; } value{0};
value.bitwise = value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x, x, scale, 0);
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise,
in.template AsType<float>()(Number<0>{}),
in.template AsType<float>()(Number<1>{}),
scale,
0);
return value.f4_array[0]; return value.f4_array[0];
#else #else
return utils::sat_convert_to_type<f4_t>(x / scale); return utils::sat_convert_to_type<f4_t>(x / scale);
...@@ -531,7 +530,7 @@ inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f) ...@@ -531,7 +530,7 @@ inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f)
union union
{ {
uint32_t bitwise; uint32_t bitwise;
fp4x2_t f4x2_array[4]; f4x2_t f4x2_array[4];
} value{0}; } value{0};
value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[0], x[1], scale, 0); value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[0], x[1], scale, 0);
return value.f4x2_array[0]; return value.f4x2_array[0];
...@@ -735,7 +734,14 @@ inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f) ...@@ -735,7 +734,14 @@ inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f)
uint32_t bitwise; uint32_t bitwise;
f4_t f4_array[4]; f4_t f4_array[4];
} value{0}; } value{0};
value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(value.bitwise, x, rng, scale, 0); union
{
float float_array[2];
float2_t float2_array;
} float_values{{x}};
value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.float2_array, rng, scale, 0);
return value.f4_array[0]; return value.f4_array[0];
#else #else
return utils::sat_convert_to_type_sr<f4_t>(x / scale, rng); return utils::sat_convert_to_type_sr<f4_t>(x / scale, rng);
...@@ -787,55 +793,55 @@ inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f ...@@ -787,55 +793,55 @@ inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f
} float_values{0}; } float_values{0};
// TODO: pack in a loop // TODO: pack in a loop
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[0], rng, scale, 0); tmp_values.bitwise, float_values.floatx2_array[0], rng, scale, 0);
f4_values.f4x2_array[0] = tmp_values.f4x2_array[0]; f4_values.f4x2_array[0] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[1], rng, scale, 0); tmp_values.bitwise, float_values.floatx2_array[1], rng, scale, 0);
f4_values.f4x2_array[1] = tmp_values.f4x2_array[0]; f4_values.f4x2_array[1] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[2], rng, scale, 0); tmp_values.bitwise, float_values.floatx2_array[2], rng, scale, 0);
f4_values.f4x2_array[2] = tmp_values.f4x2_array[0]; f4_values.f4x2_array[2] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[3], rng, scale, 0); tmp_values.bitwise, float_values.floatx2_array[3], rng, scale, 0);
f4_values.f4x2_array[3] = tmp_values.f4x2_array[0]; f4_values.f4x2_array[3] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[4], rng, scale, 0); tmp_values.bitwise, float_values.floatx2_array[4], rng, scale, 0);
f4_values.f4x2_array[4] = tmp_values.f4x2_array[0]; f4_values.f4x2_array[4] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[5], rng, scale, 0); tmp_values.bitwise, float_values.floatx2_array[5], rng, scale, 0);
f4_values.f4x2_array[5] = tmp_values.f4x2_array[0]; f4_values.f4x2_array[5] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[6], rng, scale, 0); tmp_values.bitwise, float_values.floatx2_array[6], rng, scale, 0);
f4_values.f4x2_array[6] = tmp_values.f4x2_array[0]; f4_values.f4x2_array[6] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[7], rng, scale, 0); tmp_values.bitwise, float_values.floatx2_array[7], rng, scale, 0);
f4_values.f4x2_array[7] = tmp_values.f4x2_array[0]; f4_values.f4x2_array[7] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[8], rng, scale, 0); tmp_values.bitwise, float_values.floatx2_array[8], rng, scale, 0);
f4_values.f4x2_array[8] = tmp_values.f4x2_array[0]; f4_values.f4x2_array[8] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[9], rng, scale, 0); tmp_values.bitwise, float_values.floatx2_array[9], rng, scale, 0);
f4_values.f4x2_array[9] = tmp_values.f4x2_array[0]; f4_values.f4x2_array[9] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[10], rng, scale, 0); tmp_values.bitwise, float_values.floatx2_array[10], rng, scale, 0);
f4_values.f4x2_array[10] = tmp_values.f4x2_array[0]; f4_values.f4x2_array[10] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[11], rng, scale, 0); tmp_values.bitwise, float_values.floatx2_array[11], rng, scale, 0);
f4_values.f4x2_array[11] = tmp_values.f4x2_array[0]; f4_values.f4x2_array[11] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[12], rng, scale, 0); tmp_values.bitwise, float_values.floatx2_array[12], rng, scale, 0);
f4_values.f4x2_array[12] = tmp_values.f4x2_array[0]; f4_values.f4x2_array[12] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[13], rng, scale, 0); tmp_values.bitwise, float_values.floatx2_array[13], rng, scale, 0);
f4_values.f4x2_array[13] = tmp_values.f4x2_array[0]; f4_values.f4x2_array[13] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[14], rng, scale, 0); tmp_values.bitwise, float_values.floatx2_array[14], rng, scale, 0);
f4_values.f4x2_array[14] = tmp_values.f4x2_array[0]; f4_values.f4x2_array[14] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[15], rng, scale, 0); tmp_values.bitwise, float_values.floatx2_array[15], rng, scale, 0);
f4_values.f4x2_array[15] = tmp_values.f4x2_array[0]; f4_values.f4x2_array[15] = tmp_values.f4x2_array[0];
return f4_values.f4x32_array; return f4_values.f4x32_array;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment