Commit b8f4de71 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Add conversions

parent c98974ee
...@@ -10,7 +10,11 @@ namespace ck { ...@@ -10,7 +10,11 @@ namespace ck {
// Declare a template function for scaled conversion.
// Converts x to type Y, taking an e8m0_bexp_t scale. Only the declaration lives here;
// the supported (Y, X) pairs are supplied as explicit specializations.
template <typename Y, typename X>
#if CK_USE_NATIVE_MX_SUPPORT || CK_USE_OCP_FP8
__host__ __device__ constexpr Y scaled_type_convert(e8m0_bexp_t scale, X x);
#else
// Neither native MX nor OCP FP8 support is enabled: declared for host code only.
__host__ constexpr Y scaled_type_convert(e8m0_bexp_t scale, X x);
#endif
// convert f8_ocp_t to fp32 // convert f8_ocp_t to fp32
template <> template <>
...@@ -200,27 +204,13 @@ inline __host__ float32_t scaled_type_convert<float32_t, bf8x32_ocp_t>(e8m0_bexp ...@@ -200,27 +204,13 @@ inline __host__ float32_t scaled_type_convert<float32_t, bf8x32_ocp_t>(e8m0_bexp
return out.float_1x32; return out.float_1x32;
} }
// convert fp4 to fp32
// Specialization of scaled_type_convert taking a single f4_t and returning a float.
// The scale is turned into a float via type_convert<float> before reaching the builtin.
template <>
inline __host__ __device__ float scaled_type_convert<float, f4_t>(e8m0_bexp_t scale, f4_t x)
{
#if defined(__gfx950__)
// The builtin's result is stored as a float2_t; the union lets element 0 be read
// back as a plain float.
union
{
float float_array[2];
float2_t float2_array;
} float_values{};
float_values.float2_array =
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, type_convert<float>(scale), 0);
return float_values.float_array[0];
#else
// Any other target: delegate to the utils::to_float helper.
return utils::to_float<f4_t>(scale, x);
#endif
}
// convert fp32 to fp8 // convert fp32 to fp8
template <> template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ f8_ocp_t scaled_type_convert<f8_ocp_t, float>(e8m0_bexp_t scale, float x) inline __host__ __device__ f8_ocp_t scaled_type_convert<f8_ocp_t, float>(e8m0_bexp_t scale, float x)
#else
inline __host__ f8_ocp_t scaled_type_convert<f8_ocp_t, float>(e8m0_bexp_t scale, float x)
#endif
{ {
#if CK_USE_SR_F8_CONVERSION #if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<f8_ocp_t>(x, type_convert<float>(scale)); return mxf8_convert_sr<f8_ocp_t>(x, type_convert<float>(scale));
...@@ -231,8 +221,12 @@ inline __host__ __device__ f8_ocp_t scaled_type_convert<f8_ocp_t, float>(e8m0_be ...@@ -231,8 +221,12 @@ inline __host__ __device__ f8_ocp_t scaled_type_convert<f8_ocp_t, float>(e8m0_be
// convert fp32 to bf8 // convert fp32 to bf8
template <> template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8_ocp_t scaled_type_convert<bf8_ocp_t, float>(e8m0_bexp_t scale, inline __host__ __device__ bf8_ocp_t scaled_type_convert<bf8_ocp_t, float>(e8m0_bexp_t scale,
float x) float x)
#else
inline __host__ bf8_ocp_t scaled_type_convert<bf8_ocp_t, float>(e8m0_bexp_t scale, float x)
#endif
{ {
#if CK_USE_SR_F8_CONVERSION #if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<bf8_ocp_t>(x, type_convert<float>(scale)); return mxf8_convert_sr<bf8_ocp_t>(x, type_convert<float>(scale));
...@@ -243,8 +237,12 @@ inline __host__ __device__ bf8_ocp_t scaled_type_convert<bf8_ocp_t, float>(e8m0_ ...@@ -243,8 +237,12 @@ inline __host__ __device__ bf8_ocp_t scaled_type_convert<bf8_ocp_t, float>(e8m0_
// convert fp32x2 to fp8x2 // convert fp32x2 to fp8x2
template <> template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ f8x2_ocp_t scaled_type_convert<f8x2_ocp_t, float2_t>(e8m0_bexp_t scale, inline __host__ __device__ f8x2_ocp_t scaled_type_convert<f8x2_ocp_t, float2_t>(e8m0_bexp_t scale,
float2_t x) float2_t x)
#else
inline __host__ f8x2_ocp_t scaled_type_convert<f8x2_ocp_t, float2_t>(e8m0_bexp_t scale, float2_t x)
#endif
{ {
#if CK_USE_SR_F8_CONVERSION #if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<f8x2_ocp_t>(x, type_convert<float>(scale)); return mxf8_convert_sr<f8x2_ocp_t>(x, type_convert<float>(scale));
...@@ -254,8 +252,13 @@ inline __host__ __device__ f8x2_ocp_t scaled_type_convert<f8x2_ocp_t, float2_t>( ...@@ -254,8 +252,13 @@ inline __host__ __device__ f8x2_ocp_t scaled_type_convert<f8x2_ocp_t, float2_t>(
} }
// convert fp32x2 to bf8x2 // convert fp32x2 to bf8x2
template <> template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8x2_ocp_t scaled_type_convert<bf8x2_ocp_t, float2_t>(e8m0_bexp_t scale, inline __host__ __device__ bf8x2_ocp_t scaled_type_convert<bf8x2_ocp_t, float2_t>(e8m0_bexp_t scale,
float2_t x) float2_t x)
#else
inline __host__ bf8x2_ocp_t scaled_type_convert<bf8x2_ocp_t, float2_t>(e8m0_bexp_t scale,
float2_t x)
#endif
{ {
#if CK_USE_SR_F8_CONVERSION #if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<bf8x2_ocp_t>(x, type_convert<float>(scale)); return mxf8_convert_sr<bf8x2_ocp_t>(x, type_convert<float>(scale));
...@@ -267,8 +270,13 @@ inline __host__ __device__ bf8x2_ocp_t scaled_type_convert<bf8x2_ocp_t, float2_t ...@@ -267,8 +270,13 @@ inline __host__ __device__ bf8x2_ocp_t scaled_type_convert<bf8x2_ocp_t, float2_t
// convert fp32x16 to fp8x16 // convert fp32x16 to fp8x16
// @note Host version gives compilation error. Requires extra compiler options. // @note Host version gives compilation error. Requires extra compiler options.
template <> template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ f8x16_ocp_t inline __host__ __device__ f8x16_ocp_t
scaled_type_convert<f8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x) scaled_type_convert<f8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
#else
inline __host__ f8x16_ocp_t scaled_type_convert<f8x16_ocp_t, float16_t>(e8m0_bexp_t scale,
float16_t x)
#endif
{ {
#if CK_USE_SR_F8_CONVERSION #if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<f8x16_ocp_t>(x, type_convert<float>(scale)); return mxf8_convert_sr<f8x16_ocp_t>(x, type_convert<float>(scale));
...@@ -280,8 +288,13 @@ scaled_type_convert<f8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x) ...@@ -280,8 +288,13 @@ scaled_type_convert<f8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
// convert fp32x16 to bf8x16 // convert fp32x16 to bf8x16
// @note Host version gives compilation error. Requires extra compiler options. // @note Host version gives compilation error. Requires extra compiler options.
template <> template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8x16_ocp_t inline __host__ __device__ bf8x16_ocp_t
scaled_type_convert<bf8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x) scaled_type_convert<bf8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
#else
inline __host__ bf8x16_ocp_t scaled_type_convert<bf8x16_ocp_t, float16_t>(e8m0_bexp_t scale,
float16_t x)
#endif
{ {
#if CK_USE_SR_F8_CONVERSION #if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<bf8x16_ocp_t>(x, type_convert<float>(scale)); return mxf8_convert_sr<bf8x16_ocp_t>(x, type_convert<float>(scale));
...@@ -293,8 +306,13 @@ scaled_type_convert<bf8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x) ...@@ -293,8 +306,13 @@ scaled_type_convert<bf8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
// convert fp32x32 to fp8x32 // convert fp32x32 to fp8x32
// @note Host version gives compilation error. Requires extra compiler options. // @note Host version gives compilation error. Requires extra compiler options.
template <> template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ f8x32_ocp_t inline __host__ __device__ f8x32_ocp_t
scaled_type_convert<f8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x) scaled_type_convert<f8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
#else
inline __host__ f8x32_ocp_t scaled_type_convert<f8x32_ocp_t, float32_t>(e8m0_bexp_t scale,
float32_t x)
#endif
{ {
#if CK_USE_SR_F8_CONVERSION #if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<f8x32_ocp_t>(x, type_convert<float>(scale)); return mxf8_convert_sr<f8x32_ocp_t>(x, type_convert<float>(scale));
...@@ -306,8 +324,13 @@ scaled_type_convert<f8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x) ...@@ -306,8 +324,13 @@ scaled_type_convert<f8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
// convert fp32x32 to bf8x32 // convert fp32x32 to bf8x32
// @note Host version gives compilation error. Requires extra compiler options. // @note Host version gives compilation error. Requires extra compiler options.
template <> template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8x32_ocp_t inline __host__ __device__ bf8x32_ocp_t
scaled_type_convert<bf8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x) scaled_type_convert<bf8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
#else
inline __host__ bf8x32_ocp_t scaled_type_convert<bf8x32_ocp_t, float32_t>(e8m0_bexp_t scale,
float32_t x)
#endif
{ {
#if CK_USE_SR_F8_CONVERSION #if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<bf8x32_ocp_t>(x, type_convert<float>(scale)); return mxf8_convert_sr<bf8x32_ocp_t>(x, type_convert<float>(scale));
...@@ -316,10 +339,36 @@ scaled_type_convert<bf8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x) ...@@ -316,10 +339,36 @@ scaled_type_convert<bf8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
#endif #endif
} }
// convert fp4 to fp32
// Host+device when CK_USE_NATIVE_MX_SUPPORT is set, host-only otherwise.
template <>
#if CK_USE_NATIVE_MX_SUPPORT
inline __host__ __device__ float scaled_type_convert<float, f4_t>(e8m0_bexp_t scale, f4_t x)
#else
inline __host__ float scaled_type_convert<float, f4_t>(e8m0_bexp_t scale, f4_t x)
#endif
{
#if !defined(__gfx950__)
    // No gfx950 builtin available: delegate to the utils::to_float helper.
    return utils::to_float<f4_t>(scale, x);
#else
    // The builtin's result is stored as a float2_t; view it as two floats and keep the first.
    union
    {
        float lanes[2];
        float2_t packed;
    } result{};
    const float scale_as_float = type_convert<float>(scale);
    result.packed              = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, scale_as_float, 0);
    return result.lanes[0];
#endif
}
// convert vector of 2 fp4 to vector of 2 fp32 // convert vector of 2 fp4 to vector of 2 fp32
template <> template <>
#if CK_USE_NATIVE_MX_SUPPORT
inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_bexp_t scale, inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_bexp_t scale,
f4x2_t x) f4x2_t x)
#else
inline __host__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_bexp_t scale, f4x2_t x)
#endif
{ {
#if defined(__gfx950__) #if defined(__gfx950__)
union union
...@@ -340,8 +389,12 @@ inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_b ...@@ -340,8 +389,12 @@ inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_b
// convert vector of 32 fp4 to vector of 32 fp32 // convert vector of 32 fp4 to vector of 32 fp32
template <> template <>
#if CK_USE_NATIVE_MX_SUPPORT
inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m0_bexp_t scale, inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m0_bexp_t scale,
f4x32_t x) f4x32_t x)
#else
inline __host__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m0_bexp_t scale, f4x32_t x)
#endif
{ {
#if defined(__gfx950__) #if defined(__gfx950__)
union union
...@@ -573,7 +626,11 @@ inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m ...@@ -573,7 +626,11 @@ inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m
// convert fp32 to fp4 // convert fp32 to fp4
template <> template <>
#if CK_USE_NATIVE_MX_SUPPORT
inline __host__ __device__ f4_t scaled_type_convert<f4_t, float>(e8m0_bexp_t scale, float x) inline __host__ __device__ f4_t scaled_type_convert<f4_t, float>(e8m0_bexp_t scale, float x)
#else
inline __host__ f4_t scaled_type_convert<f4_t, float>(e8m0_bexp_t scale, float x)
#endif
{ {
#if CK_USE_SR_F4_CONVERSION #if CK_USE_SR_F4_CONVERSION
return f4_convert_sr(x, type_convert<float>(scale)); return f4_convert_sr(x, type_convert<float>(scale));
...@@ -584,8 +641,12 @@ inline __host__ __device__ f4_t scaled_type_convert<f4_t, float>(e8m0_bexp_t sca ...@@ -584,8 +641,12 @@ inline __host__ __device__ f4_t scaled_type_convert<f4_t, float>(e8m0_bexp_t sca
// convert vector of 2 fp32 to vector of 2 fp4 // convert vector of 2 fp32 to vector of 2 fp4
template <> template <>
#if CK_USE_NATIVE_MX_SUPPORT
inline __host__ __device__ f4x2_t scaled_type_convert<f4x2_t, float2_t>(e8m0_bexp_t scale, inline __host__ __device__ f4x2_t scaled_type_convert<f4x2_t, float2_t>(e8m0_bexp_t scale,
float2_t x) float2_t x)
#else
inline __host__ f4x2_t scaled_type_convert<f4x2_t, float2_t>(e8m0_bexp_t scale, float2_t x)
#endif
{ {
#if CK_USE_SR_F4_CONVERSION #if CK_USE_SR_F4_CONVERSION
return f4_convert_sr(x, type_convert<float>(scale)); return f4_convert_sr(x, type_convert<float>(scale));
...@@ -596,8 +657,12 @@ inline __host__ __device__ f4x2_t scaled_type_convert<f4x2_t, float2_t>(e8m0_bex ...@@ -596,8 +657,12 @@ inline __host__ __device__ f4x2_t scaled_type_convert<f4x2_t, float2_t>(e8m0_bex
// convert vector of 32 fp32 to vector of 32 fp4 // convert vector of 32 fp32 to vector of 32 fp4
template <> template <>
#if CK_USE_NATIVE_MX_SUPPORT
inline __host__ __device__ f4x32_t scaled_type_convert<f4x32_t, float32_t>(e8m0_bexp_t scale, inline __host__ __device__ f4x32_t scaled_type_convert<f4x32_t, float32_t>(e8m0_bexp_t scale,
float32_t x) float32_t x)
#else
inline __host__ f4x32_t scaled_type_convert<f4x32_t, float32_t>(e8m0_bexp_t scale, float32_t x)
#endif
{ {
#if CK_USE_SR_F4_CONVERSION #if CK_USE_SR_F4_CONVERSION
return f4_convert_sr(x, type_convert<float>(scale)); return f4_convert_sr(x, type_convert<float>(scale));
...@@ -615,10 +680,61 @@ inline __host__ __device__ f4x32_t scaled_type_convert<f4x32_t, float32_t>(e8m0_ ...@@ -615,10 +680,61 @@ inline __host__ __device__ f4x32_t scaled_type_convert<f4x32_t, float32_t>(e8m0_
* @return The converted 32-bit float representation of the input. * @return The converted 32-bit float representation of the input.
*/ */
template <>
#if CK_USE_NATIVE_MX_SUPPORT
inline __host__ __device__ float scaled_type_convert<float, f6_t>(e8m0_bexp_t scale, f6_t x)
#else
inline __host__ float scaled_type_convert<float, f6_t>(e8m0_bexp_t scale, f6_t x)
#endif
{
#if defined(__gfx950__)
    // The builtin consumes a whole f6x32_t, so the scalar is placed in a vector-sized
    // union, converted, and only element 0 of the result is returned.
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    out.float_vector =
        __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(in.f6_vector, type_convert<float>(scale));
    return out.float_array[0];
#else
    // Any other target: delegate to the utils::to_float helper.
    return utils::to_float<f6_t>(scale, x);
#endif
}
// convert vector of 32 fp6 to vector of 32 fp32
// Host+device when CK_USE_NATIVE_MX_SUPPORT is set, host-only otherwise.
template <>
#if CK_USE_NATIVE_MX_SUPPORT
inline __host__ __device__ float32_t scaled_type_convert<float32_t, f6x32_t>(e8m0_bexp_t scale,
                                                                             f6x32_t x)
#else
inline __host__ float32_t scaled_type_convert<float32_t, f6x32_t>(e8m0_bexp_t scale, f6x32_t x)
#endif
{
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(x, type_convert<float>(scale));
#else
    // Any other target: convert the 32 elements one at a time through utils::to_float,
    // using unions to get array views of the input and output vectors.
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}(
        [&](auto i) { out.float_array[i] = utils::to_float<f6_t>(scale, in.f6_array[i]); });
    return out.float_vector;
#endif
}
/** /**
...@@ -630,10 +746,61 @@ inline __host__ __device__ float scaled_type_convert<float, f6_t>(e8m0_bexp_t sc ...@@ -630,10 +746,61 @@ inline __host__ __device__ float scaled_type_convert<float, f6_t>(e8m0_bexp_t sc
* @return The converted 32-bit float representation of the input. * @return The converted 32-bit float representation of the input.
*/ */
template <>
#if CK_USE_NATIVE_MX_SUPPORT
inline __host__ __device__ float scaled_type_convert<float, bf6_t>(e8m0_bexp_t scale, bf6_t x)
#else
inline __host__ float scaled_type_convert<float, bf6_t>(e8m0_bexp_t scale, bf6_t x)
#endif
{
#if defined(__gfx950__)
    // The builtin consumes a whole bf6x32_t, so the scalar is placed in a vector-sized
    // union, converted, and only element 0 of the result is returned.
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    out.float_vector =
        __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(in.bf6_vector, type_convert<float>(scale));
    return out.float_array[0];
#else
    // Any other target: delegate to the utils::to_float helper.
    return utils::to_float<bf6_t>(scale, x);
#endif
}
// convert vector of 32 bf6 to vector of 32 fp32
// Host+device when CK_USE_NATIVE_MX_SUPPORT is set, host-only otherwise.
template <>
#if CK_USE_NATIVE_MX_SUPPORT
inline __host__ __device__ float32_t scaled_type_convert<float32_t, bf6x32_t>(e8m0_bexp_t scale,
                                                                              bf6x32_t x)
#else
inline __host__ float32_t scaled_type_convert<float32_t, bf6x32_t>(e8m0_bexp_t scale, bf6x32_t x)
#endif
{
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(x, type_convert<float>(scale));
#else
    // Any other target: convert the 32 elements one at a time through utils::to_float,
    // using unions to get array views of the input and output vectors.
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}(
        [&](auto i) { out.float_array[i] = utils::to_float<bf6_t>(scale, in.bf6_array[i]); });
    return out.float_vector;
#endif
}
/** /**
...@@ -648,7 +815,26 @@ inline __host__ __device__ float scaled_type_convert<float, bf6_t>(e8m0_bexp_t s ...@@ -648,7 +815,26 @@ inline __host__ __device__ float scaled_type_convert<float, bf6_t>(e8m0_bexp_t s
* @return The converted 6-bit floating-point value (f6_t). * @return The converted 6-bit floating-point value (f6_t).
*/ */
template <>
#if CK_USE_NATIVE_MX_SUPPORT
inline __host__ __device__ f6_t scaled_type_convert<f6_t, float>(e8m0_bexp_t scale, float x)
#else
inline __host__ f6_t scaled_type_convert<f6_t, float>(e8m0_bexp_t scale, float x)
#endif
{
    // CK_USE_SR_F6_CONVERSION picks f6_convert_sr over f6_convert_rne; either way the
    // scale is passed on as a float.
#if CK_USE_SR_F6_CONVERSION
    return f6_convert_sr(x, type_convert<float>(scale));
#else
    return f6_convert_rne(x, type_convert<float>(scale));
#endif
}
template <>
#if CK_USE_NATIVE_MX_SUPPORT
inline __host__ __device__ f6x32_t scaled_type_convert<f6x32_t, float32_t>(e8m0_bexp_t scale,
float32_t x)
#else
inline __host__ f6x32_t scaled_type_convert<f6x32_t, float32_t>(e8m0_bexp_t scale, float32_t x)
#endif
{ {
#if CK_USE_SR_F6_CONVERSION #if CK_USE_SR_F6_CONVERSION
return f6_convert_sr(x, type_convert<float>(scale)); return f6_convert_sr(x, type_convert<float>(scale));
...@@ -669,7 +855,26 @@ inline __host__ __device__ f6_t scaled_type_convert<f6_t, float>(e8m0_bexp_t sca ...@@ -669,7 +855,26 @@ inline __host__ __device__ f6_t scaled_type_convert<f6_t, float>(e8m0_bexp_t sca
* @return The converted 6-bit floating-point value (bf6_t). * @return The converted 6-bit floating-point value (bf6_t).
*/ */
template <>
#if CK_USE_NATIVE_MX_SUPPORT
inline __host__ __device__ bf6_t scaled_type_convert<bf6_t, float>(e8m0_bexp_t scale, float x)
#else
inline __host__ bf6_t scaled_type_convert<bf6_t, float>(e8m0_bexp_t scale, float x)
#endif
{
    // CK_USE_SR_F6_CONVERSION picks bf6_convert_sr over bf6_convert_rne; either way the
    // scale is passed on as a float.
#if CK_USE_SR_F6_CONVERSION
    return bf6_convert_sr(x, type_convert<float>(scale));
#else
    return bf6_convert_rne(x, type_convert<float>(scale));
#endif
}
template <>
#if CK_USE_NATIVE_MX_SUPPORT
inline __host__ __device__ bf6x32_t scaled_type_convert<bf6x32_t, float32_t>(e8m0_bexp_t scale,
float32_t x)
#else
inline __host__ bf6x32_t scaled_type_convert<bf6x32_t, float32_t>(e8m0_bexp_t scale, float32_t x)
#endif
{ {
#if CK_USE_SR_F6_CONVERSION #if CK_USE_SR_F6_CONVERSION
return bf6_convert_sr(x, type_convert<float>(scale)); return bf6_convert_sr(x, type_convert<float>(scale));
......
...@@ -1400,8 +1400,79 @@ inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x) ...@@ -1400,8 +1400,79 @@ inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
*/ */
inline __host__ __device__ f6_t f6_convert_rne(float x, float scale = 1.0f)
{
#if defined(__gfx950__)
    // The builtin takes two float16_t inputs and produces an f6x32_t. x goes into
    // element 0 of the first input; every other element is value-initialized.
    float16_t in1{x};
    float16_t in2{};
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } out{};
    out.f6_vector = __builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(in1, in2, scale);
    // Only the element that corresponds to x is returned.
    return out.f6_array[0];
#else
    // Any other target: divide by the scale, then use the saturating conversion helper.
    return utils::sat_convert_to_type<f6_t>(x / scale);
#endif
}
// Vector overload: converts 32 floats to an f6x32_t, dividing by (or passing along) scale.
inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
    // The builtin takes the 32 inputs as two float16_t halves: elements 0-15 and 16-31.
    float16_t in1{x[0], x[1], x[2],  x[3],  x[4],  x[5],  x[6],  x[7],
                  x[8], x[9], x[10], x[11], x[12], x[13], x[14], x[15]};
    float16_t in2{x[16], x[17], x[18], x[19], x[20], x[21], x[22], x[23],
                  x[24], x[25], x[26], x[27], x[28], x[29], x[30], x[31]};
    return __builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(in1, in2, scale);
#else
    // Any other target: convert element by element, using unions to get array views
    // of the input and output vectors.
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.f6_array[i] = utils::sat_convert_to_type<f6_t>(in.float_array[i] / scale);
    });
    return out.f6_vector;
#endif
}
/** /**
...@@ -1418,15 +1489,65 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f) ...@@ -1418,15 +1489,65 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
{ {
constexpr int seed = 1254739; constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
// currently there is no native conversion instruction #if defined(__gfx950__)
union
{
float32_t float_vector;
float float_array[32];
} in{x};
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} out{};
out.f6_vector = __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(in.float_vector, rng, scale);
return out.f6_array[0];
#else
return utils::sat_convert_to_type_sr<f6_t>(x / scale, rng); return utils::sat_convert_to_type_sr<f6_t>(x / scale, rng);
#endif
}
// Vector overload of f6_convert_sr: converts 32 floats to an f6x32_t using one random
// value, derived from the address of x and its first element, for the whole vector.
inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    // Array view of the input: element 0 feeds the generator, and the fallback path
    // below reads every element from the same view.
    union
    {
        float32_t float_vector;
        float float_array[32];
    } float_values{x};
    uint32_t rng =
        prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), float_values.float_array[0]);
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale);
#else
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } out{};
    // The same rng value is handed to every element's conversion.
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.f6_array[i] =
            utils::sat_convert_to_type_sr<f6_t>(float_values.float_array[i] / scale, rng);
    });
    return out.f6_vector;
#endif
}
/**
 * @brief Specializes the type conversion template for converting a float into the 6-bit float type
 * (f6_t).
 *
 * Depending on the CK_USE_SR_F6_CONVERSION flag,
 * the conversion uses stochastic rounding
 * or round-to-nearest-even.
 *
...@@ -1436,7 +1557,17 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f) ...@@ -1436,7 +1557,17 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
template <>
inline __host__ __device__ f6_t type_convert<f6_t, float>(float x)
{
    // The rounding mode is chosen by the CK_USE_SR_F6_CONVERSION build flag, not by the
    // target architecture, matching the bf6_t specialization of type_convert.
#if CK_USE_SR_F6_CONVERSION
    return f6_convert_sr(x);
#else
    return f6_convert_rne(x);
#endif
}
template <>
inline __host__ __device__ f6x32_t type_convert<f6x32_t, float32_t>(float32_t x)
{
#if defined(__gfx950__)
return f6_convert_sr(x); return f6_convert_sr(x);
#else #else
return f6_convert_rne(x); return f6_convert_rne(x);
...@@ -1455,8 +1586,53 @@ inline __host__ __device__ f6_t type_convert<f6_t, float>(float x) ...@@ -1455,8 +1586,53 @@ inline __host__ __device__ f6_t type_convert<f6_t, float>(float x)
template <>
inline __host__ __device__ float type_convert<float, f6_t>(f6_t x)
{
    // Both paths use NumericLimits<e8m0_bexp_t>::Binary_1() as the scale
    // (presumably a scale of 1.0 -- confirm against the e8m0_bexp_t definition).
#if defined(__gfx950__)
    // The builtin consumes a whole f6x32_t, so the scalar is placed in a vector-sized
    // union, converted, and only element 0 of the result is returned.
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
        in.f6_vector, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
    return out.float_array[0];
#else
    return utils::to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif
}
// Vector form of the f6_t -> float conversion: 32 f6 elements to 32 floats, with
// NumericLimits<e8m0_bexp_t>::Binary_1() as the scale on both paths.
template <>
inline __host__ __device__ float32_t type_convert<float32_t, f6x32_t>(f6x32_t x)
{
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
        x, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
#else
    // Any other target: convert element by element through utils::to_float, using
    // unions to get array views of the input and output vectors.
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.float_array[i] =
            utils::to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), in.f6_array[i]);
    });
    return out.float_vector;
#endif
}
/** /**
...@@ -1471,8 +1647,79 @@ inline __host__ __device__ float type_convert<float, f6_t>(f6_t x) ...@@ -1471,8 +1647,79 @@ inline __host__ __device__ float type_convert<float, f6_t>(f6_t x)
*/ */
inline __host__ __device__ bf6_t bf6_convert_rne(float x, float scale = 1.0f)
{
#if defined(__gfx950__)
    // The builtin takes two float16_t inputs and produces a bf6x32_t. x goes into
    // element 0 of the first input; every other element is value-initialized.
    float16_t in1{x};
    float16_t in2{};
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } out{};
    out.bf6_vector = __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(in1, in2, scale);
    // Only the element that corresponds to x is returned.
    return out.bf6_array[0];
#else
    // Any other target: divide by the scale, then use the saturating conversion helper.
    return utils::sat_convert_to_type<bf6_t>(x / scale);
#endif
}
// Vector overload: converts 32 floats to a bf6x32_t, dividing by (or passing along) scale.
inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
    // The builtin takes the 32 inputs as two float16_t halves: elements 0-15 and 16-31.
    float16_t in1{x[0], x[1], x[2],  x[3],  x[4],  x[5],  x[6],  x[7],
                  x[8], x[9], x[10], x[11], x[12], x[13], x[14], x[15]};
    float16_t in2{x[16], x[17], x[18], x[19], x[20], x[21], x[22], x[23],
                  x[24], x[25], x[26], x[27], x[28], x[29], x[30], x[31]};
    return __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(in1, in2, scale);
#else
    // Any other target: convert element by element, using unions to get array views
    // of the input and output vectors.
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.bf6_array[i] = utils::sat_convert_to_type<bf6_t>(in.float_array[i] / scale);
    });
    return out.bf6_vector;
#endif
}
/** /**
...@@ -1490,14 +1737,64 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f) ...@@ -1490,14 +1737,64 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
{ {
constexpr int seed = 1254739; constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
// currently there is no native conversion instruction #if defined(__gfx950__)
union
{
float32_t float_vector;
float float_array[32];
} in{x};
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} out{};
out.bf6_vector = __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(in.float_vector, rng, scale);
return out.bf6_array[0];
#else
return utils::sat_convert_to_type_sr<bf6_t>(x / scale, rng); return utils::sat_convert_to_type_sr<bf6_t>(x / scale, rng);
#endif
}
// Vector overload of bf6_convert_sr: converts 32 floats to a bf6x32_t using one random
// value, derived from the address of x and its first element, for the whole vector.
inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    // Array view of the input: element 0 feeds the generator, and the fallback path
    // below reads every element from the same view.
    union
    {
        float32_t float_vector;
        float float_array[32];
    } float_values{x};
    uint32_t rng =
        prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), float_values.float_array[0]);
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale);
#else
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } out{};
    // The same rng value is handed to every element's conversion.
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.bf6_array[i] =
            utils::sat_convert_to_type_sr<bf6_t>(float_values.float_array[i] / scale, rng);
    });
    return out.bf6_vector;
#endif
}
/**
 * @brief Specializes float-to-bf6_t conversion.
 *
 * Uses stochastic rounding if CK_USE_SR_F6_CONVERSION is defined,
 * otherwise uses round-to-nearest-even.
 *
 * @param x Input float value to convert.
...@@ -1506,7 +1803,17 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f) ...@@ -1506,7 +1803,17 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
template <>
inline __host__ __device__ bf6_t type_convert<bf6_t, float>(float x)
{
    // The rounding mode is chosen by the CK_USE_SR_F6_CONVERSION build flag.
#if CK_USE_SR_F6_CONVERSION
    return bf6_convert_sr(x);
#else
    return bf6_convert_rne(x);
#endif
}
template <>
inline __host__ __device__ bf6x32_t type_convert<bf6x32_t, float32_t>(float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
return bf6_convert_sr(x); return bf6_convert_sr(x);
#else #else
return bf6_convert_rne(x); return bf6_convert_rne(x);
...@@ -1525,8 +1832,53 @@ inline __host__ __device__ bf6_t type_convert<bf6_t, float>(float x) ...@@ -1525,8 +1832,53 @@ inline __host__ __device__ bf6_t type_convert<bf6_t, float>(float x)
template <>
inline __host__ __device__ float type_convert<float, bf6_t>(bf6_t x)
{
    // Both paths use NumericLimits<e8m0_bexp_t>::Binary_1() as the scale
    // (presumably a scale of 1.0 -- confirm against the e8m0_bexp_t definition).
#if defined(__gfx950__)
    // The builtin consumes a whole bf6x32_t, so the scalar is placed in a vector-sized
    // union, converted, and only element 0 of the result is returned.
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
        in.bf6_vector, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
    return out.float_array[0];
#else
    return utils::to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif
}
// Vector form of the bf6_t -> float conversion: 32 bf6 elements to 32 floats, with
// NumericLimits<e8m0_bexp_t>::Binary_1() as the scale on both paths.
template <>
inline __host__ __device__ float32_t type_convert<float32_t, bf6x32_t>(bf6x32_t x)
{
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
        x, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
#else
    // Any other target: convert element by element through utils::to_float, using
    // unions to get array views of the input and output vectors.
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.float_array[i] =
            utils::to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), in.bf6_array[i]);
    });
    return out.float_vector;
#endif
}
template <typename Y, typename X, std::size_t NumElems> template <typename Y, typename X, std::size_t NumElems>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment