Commit a8cd34d6 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Add vector conversions

parent b5ac2abd
......@@ -524,6 +524,206 @@ inline __host__ __device__ f4_t f4_convert_rne(float x, float scale = 1.0f)
#endif
}
// convert vector of 2 fp32 to vector of 2 fp4 with rne
inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
union
{
uint32_t bitwise;
fp4x2_t f4x2_array[4];
} value{0};
value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[0], x[1], scale, 0);
return value.f4x2_array[0];
#else
union
{
uint32_t bitwise;
f4x2_t f4x2_array[4];
} value{0};
uint8_t l = utils::sat_convert_to_type<f4_t>(x[1] / scale);
uint8_t h = utils::sat_convert_to_type<f4_t>(x[0] / scale);
value.bitwise = (h << 4) | l;
return value.f4x2_array[0];
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4 with rne
inline __host__ __device__ f4x32_t f4_convert_rne(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
union
{
__uint128_t bitwise;
f4x2_t f4x2_array[16];
f4x32_t f4x32_array;
} f4_values{}, tmp_values{};
// TODO: pack in a loop
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[0], x[1], scale, 0);
f4_values.f4x2_array[0] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[2], x[3], scale, 0);
f4_values.f4x2_array[1] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[4], x[5], scale, 0);
f4_values.f4x2_array[2] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[6], x[7], scale, 0);
f4_values.f4x2_array[3] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[8], x[9], scale, 0);
f4_values.f4x2_array[4] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[10], x[11], scale, 0);
f4_values.f4x2_array[5] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[12], x[13], scale, 0);
f4_values.f4x2_array[6] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[14], x[15], scale, 0);
f4_values.f4x2_array[7] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[16], x[17], scale, 0);
f4_values.f4x2_array[8] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[18], x[19], scale, 0);
f4_values.f4x2_array[9] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[20], x[21], scale, 0);
f4_values.f4x2_array[10] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[22], x[23], scale, 0);
f4_values.f4x2_array[11] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[24], x[25], scale, 0);
f4_values.f4x2_array[12] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[26], x[27], scale, 0);
f4_values.f4x2_array[13] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[28], x[29], scale, 0);
f4_values.f4x2_array[14] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[30], x[31], scale, 0);
f4_values.f4x2_array[15] = tmp_values.f4x2_array[0];
return f4_values.f4x32_array;
#else
union
{
__uint128_t bitwise;
f4x2_t f4x2_array[16];
f4x32_t f4x32_array;
} f4_values{};
// TODO: pack in a loop
auto tmp = utils::sat_convert_to_type<f4_t>(x[0] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[1] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[2] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[3] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[4] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[5] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[6] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[7] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[8] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[9] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[10] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[11] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[12] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[13] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[14] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[15] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[16] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[17] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[18] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[19] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[20] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[21] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[22] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[23] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[24] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[25] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[26] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[27] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[28] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[29] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[30] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[31] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
return f4_values.f4x32_array;
#endif
}
// convert fp32 to fp4 with stochastic rounding
inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f)
{
......@@ -542,6 +742,215 @@ inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f)
#endif
}
// convert vector of 2 fp32 to vector of 2 fp4 with sr
inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
{
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
#if defined(__gfx950__)
union
{
uint32_t bitwise;
f4x2_t f4x2_array[4];
} value{0};
value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(value.bitwise, x, rng, scale, 0);
return value.f4x2_array[0];
#else
union
{
uint32_t bitwise;
f4x2_t f4x2_array[4];
} value{0};
uint8_t l = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
uint8_t h = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
value.bitwise = (h << 4) | l;
return value.f4x2_array[0];
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4 with sr
inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f)
{
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
#if defined(__gfx950__)
union
{
__uint128_t bitwise;
f4x2_t f4x2_array[16];
f4x32_t f4x32_array;
} f4_values{0}, tmp_values{0};
union
{
float2_t floatx2_array[16];
float32_t floatx32_array;
} float_values{0};
// TODO: pack in a loop
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[0], rng, scale, 0);
f4_values.f4x2_array[0] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[1], rng, scale, 0);
f4_values.f4x2_array[1] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[2], rng, scale, 0);
f4_values.f4x2_array[2] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[3], rng, scale, 0);
f4_values.f4x2_array[3] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[4], rng, scale, 0);
f4_values.f4x2_array[4] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[5], rng, scale, 0);
f4_values.f4x2_array[5] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[6], rng, scale, 0);
f4_values.f4x2_array[6] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[7], rng, scale, 0);
f4_values.f4x2_array[7] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[8], rng, scale, 0);
f4_values.f4x2_array[8] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[9], rng, scale, 0);
f4_values.f4x2_array[9] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[10], rng, scale, 0);
f4_values.f4x2_array[10] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[11], rng, scale, 0);
f4_values.f4x2_array[11] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[12], rng, scale, 0);
f4_values.f4x2_array[12] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[13], rng, scale, 0);
f4_values.f4x2_array[13] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[14], rng, scale, 0);
f4_values.f4x2_array[14] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.floatx2_array[15], rng, scale, 0);
f4_values.f4x2_array[15] = tmp_values.f4x2_array[0];
return f4_values.f4x32_array;
#else
union
{
__uint128_t bitwise;
f4x2_t f4x2_array[16];
f4x32_t f4x32_array;
} f4_values{0};
// TODO: pack in a loop
auto tmp = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[2] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[3] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[4] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[5] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[6] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[7] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[8] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[9] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[10] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[11] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[12] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[13] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[14] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[15] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[16] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[17] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[18] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[19] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[20] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[21] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[22] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[23] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[24] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[25] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[26] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[27] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[28] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[29] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[30] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[31] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
return f4_values.f4x32_array;
#endif
}
// convert fp32 to fp4
template <>
inline __host__ __device__ f4_t type_convert<f4_t, float>(float x)
......@@ -553,16 +962,204 @@ inline __host__ __device__ f4_t type_convert<f4_t, float>(float x)
#endif
}
// convert vector of 2 fp32 to vector of 2 fp4
template <>
inline __host__ __device__ f4x2_t type_convert<f4x2_t, float2_t>(float2_t x)
{
#if CK_USE_SR_F4_CONVERSION
return f4_convert_sr(x);
#else
return f4_convert_rne(x);
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4
template <>
inline __host__ __device__ f4x32_t type_convert<f4x32_t, float32_t>(float32_t x)
{
#if CK_USE_SR_F4_CONVERSION
return f4_convert_sr(x);
#else
return f4_convert_rne(x);
#endif
}
// convert fp4 to fp32
template <>
inline __host__ __device__ float type_convert<float, f4_t>(f4_t data)
inline __host__ __device__ float type_convert<float, f4_t>(f4_t x)
{
#if defined(__gfx950__)
float scale = 1.0f;
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(data, scale, 0)
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, scale, 0)
.template AsType<float>()(Number<0>{});
#else
return utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), data);
return utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), x);
#endif
}
// convert vector of 2 fp4 to vector of 2 fp32
template <>
inline __host__ __device__ float2_t type_convert<float2_t, f4x2_t>(f4x2_t x)
{
#if defined(__gfx950__)
float scale = 1.0f;
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, scale, 0);
#else
float2_t ret{utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), x.unpack(1)),
utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(), x.unpack(0))};
return ret;
#endif
}
// convert vector of 32 fp4 to vector of 32 fp32
template <>
inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
{
#if defined(__gfx950__)
union
{
f4x32_t f4x32_array;
f4x2_t fp4x2[16];
} value{x};
float2_t op;
float32_t ret;
float scale = 1.0f;
// TODO: pack in a loop
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[0], type_convert<float>(scale), 0);
ret[0] = op[0];
ret[1] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[1], type_convert<float>(scale), 0);
ret[2] = op[0];
ret[3] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[2], type_convert<float>(scale), 0);
ret[4] = op[0];
ret[5] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[3], type_convert<float>(scale), 0);
ret[6] = op[0];
ret[7] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[4], type_convert<float>(scale), 0);
ret[8] = op[0];
ret[9] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[5], type_convert<float>(scale), 0);
ret[10] = op[0];
ret[11] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[6], type_convert<float>(scale), 0);
ret[12] = op[0];
ret[13] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[7], type_convert<float>(scale), 0);
ret[14] = op[0];
ret[15] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[8], type_convert<float>(scale), 0);
ret[16] = op[0];
ret[17] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[9], type_convert<float>(scale), 0);
ret[18] = op[0];
ret[19] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[10], type_convert<float>(scale), 0);
ret[20] = op[0];
ret[21] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[11], type_convert<float>(scale), 0);
ret[22] = op[0];
ret[23] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[12], type_convert<float>(scale), 0);
ret[24] = op[0];
ret[25] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[13], type_convert<float>(scale), 0);
ret[26] = op[0];
ret[27] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[14], type_convert<float>(scale), 0);
ret[28] = op[0];
ret[29] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[15], type_convert<float>(scale), 0);
ret[30] = op[0];
ret[31] = op[1];
return ret;
#else
union
{
float32_t float32_array;
float float_array[32];
} float_values{};
union
{
__uint128_t bitwise;
f4x2_t f4x2_array[16];
f4x32_t f4x32_array;
} f4_values{bit_cast<__uint128_t>(x)};
// TODO: pack in a loop
float_values.float_array[0] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[0].unpack(0));
float_values.float_array[1] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[0].unpack(1));
float_values.float_array[2] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[1].unpack(0));
float_values.float_array[3] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[1].unpack(1));
float_values.float_array[4] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[2].unpack(0));
float_values.float_array[5] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[2].unpack(1));
float_values.float_array[6] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[3].unpack(0));
float_values.float_array[7] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[3].unpack(1));
float_values.float_array[0] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[4].unpack(0));
float_values.float_array[1] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[4].unpack(1));
float_values.float_array[2] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[5].unpack(0));
float_values.float_array[3] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[5].unpack(1));
float_values.float_array[4] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[6].unpack(0));
float_values.float_array[5] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[6].unpack(1));
float_values.float_array[6] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[7].unpack(0));
float_values.float_array[7] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[7].unpack(1));
float_values.float_array[0] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[8].unpack(0));
float_values.float_array[1] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[8].unpack(1));
float_values.float_array[2] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[9].unpack(0));
float_values.float_array[3] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[9].unpack(1));
float_values.float_array[4] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[10].unpack(0));
float_values.float_array[5] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[10].unpack(1));
float_values.float_array[6] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[11].unpack(0));
float_values.float_array[7] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[11].unpack(1));
float_values.float_array[0] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[12].unpack(0));
float_values.float_array[1] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[12].unpack(1));
float_values.float_array[2] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[13].unpack(0));
float_values.float_array[3] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[13].unpack(1));
float_values.float_array[4] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[14].unpack(0));
float_values.float_array[5] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[14].unpack(1));
float_values.float_array[6] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[15].unpack(0));
float_values.float_array[7] = utils::to_float<f4_t>(NumericLimits<e8m0_scale_t>::Binary_1(),
f4_values.f4x2_array[15].unpack(1));
return float_values.float32_array;
#endif
}
......@@ -587,13 +1184,147 @@ template <>
inline __host__ __device__ float scaled_type_convert<float, f4_t>(e8m0_scale_t scale, f4_t x)
{
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, scale, 0)
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, type_convert<float>(scale), 0)
.template AsType<float>()(Number<0>{});
#else
return utils::to_float<f4_t>(scale, x);
#endif
}
// convert vector of 2 fp4 to vector of 2 fp32
template <>
inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_scale_t scale,
f4x2_t x)
{
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, type_convert<float>(scale), 0);
#else
float2_t ret{utils::to_float<f4_t>(scale, x.unpack(1)),
utils::to_float<f4_t>(scale, x.unpack(0))};
return ret;
#endif
}
// convert vector of 32 fp4 to vector of 32 fp32
template <>
inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m0_scale_t scale,
f4x32_t x)
{
#if defined(__gfx950__)
union
{
f4x32_t f4x32_array;
f4x2_t fp4x2[16];
} value{x};
float2_t op;
float32_t ret;
// TODO: pack in a loop
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[0], type_convert<float>(scale), 0);
ret[0] = op[0];
ret[1] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[1], type_convert<float>(scale), 0);
ret[2] = op[0];
ret[3] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[2], type_convert<float>(scale), 0);
ret[4] = op[0];
ret[5] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[3], type_convert<float>(scale), 0);
ret[6] = op[0];
ret[7] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[4], type_convert<float>(scale), 0);
ret[8] = op[0];
ret[9] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[5], type_convert<float>(scale), 0);
ret[10] = op[0];
ret[11] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[6], type_convert<float>(scale), 0);
ret[12] = op[0];
ret[13] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[7], type_convert<float>(scale), 0);
ret[14] = op[0];
ret[15] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[8], type_convert<float>(scale), 0);
ret[16] = op[0];
ret[17] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[9], type_convert<float>(scale), 0);
ret[18] = op[0];
ret[19] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[10], type_convert<float>(scale), 0);
ret[20] = op[0];
ret[21] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[11], type_convert<float>(scale), 0);
ret[22] = op[0];
ret[23] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[12], type_convert<float>(scale), 0);
ret[24] = op[0];
ret[25] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[13], type_convert<float>(scale), 0);
ret[26] = op[0];
ret[27] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[14], type_convert<float>(scale), 0);
ret[28] = op[0];
ret[29] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[15], type_convert<float>(scale), 0);
ret[30] = op[0];
ret[31] = op[1];
return ret;
#else
union
{
float32_t float32_array;
float float_array[32];
} float_values{};
union
{
__uint128_t bitwise;
f4x2_t f4x2_array[16];
f4x32_t f4x32_array;
} f4_values{bit_cast<__uint128_t>(x)};
// TODO: pack in a loop
float_values.float_array[0] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[0].unpack(0));
float_values.float_array[1] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[0].unpack(1));
float_values.float_array[2] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[1].unpack(0));
float_values.float_array[3] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[1].unpack(1));
float_values.float_array[4] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[2].unpack(0));
float_values.float_array[5] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[2].unpack(1));
float_values.float_array[6] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[3].unpack(0));
float_values.float_array[7] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[3].unpack(1));
float_values.float_array[0] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[4].unpack(0));
float_values.float_array[1] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[4].unpack(1));
float_values.float_array[2] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[5].unpack(0));
float_values.float_array[3] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[5].unpack(1));
float_values.float_array[4] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[6].unpack(0));
float_values.float_array[5] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[6].unpack(1));
float_values.float_array[6] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[7].unpack(0));
float_values.float_array[7] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[7].unpack(1));
float_values.float_array[0] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[8].unpack(0));
float_values.float_array[1] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[8].unpack(1));
float_values.float_array[2] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[9].unpack(0));
float_values.float_array[3] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[9].unpack(1));
float_values.float_array[4] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[10].unpack(0));
float_values.float_array[5] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[10].unpack(1));
float_values.float_array[6] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[11].unpack(0));
float_values.float_array[7] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[11].unpack(1));
float_values.float_array[0] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[12].unpack(0));
float_values.float_array[1] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[12].unpack(1));
float_values.float_array[2] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[13].unpack(0));
float_values.float_array[3] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[13].unpack(1));
float_values.float_array[4] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[14].unpack(0));
float_values.float_array[5] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[14].unpack(1));
float_values.float_array[6] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[15].unpack(0));
float_values.float_array[7] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[15].unpack(1));
return float_values.float32_array;
#endif
}
// convert fp32 to fp4
template <>
inline __host__ __device__ f4_t scaled_type_convert<f4_t, float>(e8m0_scale_t scale, float x)
......@@ -605,6 +1336,30 @@ inline __host__ __device__ f4_t scaled_type_convert<f4_t, float>(e8m0_scale_t sc
#endif
}
// convert vector of 2 fp32 to vector of 2 fp4
template <>
inline __host__ __device__ f4x2_t scaled_type_convert<f4x2_t, float2_t>(e8m0_scale_t scale,
float2_t x)
{
#if CK_USE_SR_F4_CONVERSION
return f4_convert_sr(x, type_convert<float>(scale));
#else
return f4_convert_rne(x, type_convert<float>(scale));
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4
template <>
inline __host__ __device__ f4x32_t scaled_type_convert<f4x32_t, float32_t>(e8m0_scale_t scale,
float32_t x)
{
#if CK_USE_SR_F4_CONVERSION
return f4_convert_sr(x, type_convert<float>(scale));
#else
return f4_convert_rne(x, type_convert<float>(scale));
#endif
}
template <typename Y, typename X, std::size_t NumElems>
inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
const std::array<X, NumElems>& x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment