Unverified Commit 3c5717df authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merge branch 'develop' into gemm_elementwise_gemm

parents 171b9030 d9f1ead3
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -19,7 +19,7 @@ extern "C" __device__ float __ocml_native_recip_f32(float); ...@@ -19,7 +19,7 @@ extern "C" __device__ float __ocml_native_recip_f32(float);
#endif #endif
// math functions for the host, some are implemented by calling C++ std functions // math functions for the host, some are implemented by calling C++ std functions
#ifndef CK_CODE_GEN_RTC
static inline __host__ float abs(float x) { return std::abs(x); }; static inline __host__ float abs(float x) { return std::abs(x); };
static inline __host__ double abs(double x) { return std::abs(x); }; static inline __host__ double abs(double x) { return std::abs(x); };
...@@ -80,7 +80,7 @@ static inline __host__ bool isnan(half_t x) ...@@ -80,7 +80,7 @@ static inline __host__ bool isnan(half_t x)
return (xx & 0x7FFF) > 0x7C00; return (xx & 0x7FFF) > 0x7C00;
}; };
static inline __host__ bool isnan(f8_t x) { return (x & 0x80); }; static inline __host__ bool isnan(f8_t x) { return ck::fp8_is_nan(x); };
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static inline __host__ bool isnan(int4_t x) static inline __host__ bool isnan(int4_t x)
...@@ -459,7 +459,7 @@ inline __host__ double expm1<double>(double x) ...@@ -459,7 +459,7 @@ inline __host__ double expm1<double>(double x)
{ {
return std::expm1(x); return std::expm1(x);
} }
#endif
// math functions for the HIP kernel, some are implemented by calling hip builtin functions // math functions for the HIP kernel, some are implemented by calling hip builtin functions
static inline __device__ float abs(float x) { return ::abs(x); }; static inline __device__ float abs(float x) { return ::abs(x); };
...@@ -531,7 +531,7 @@ static inline __device__ bool isnan(half_t x) ...@@ -531,7 +531,7 @@ static inline __device__ bool isnan(half_t x)
return (xx & 0x7FFF) > 0x7C00; return (xx & 0x7FFF) > 0x7C00;
}; };
static inline __device__ bool isnan(f8_t x) { return (x & 0x80); }; static inline __device__ bool isnan(f8_t x) { return ck::fp8_is_nan(x); };
static inline __device__ half_t sqrt(half_t x) static inline __device__ half_t sqrt(half_t x)
{ {
...@@ -611,7 +611,7 @@ inline __device__ int8_t neg<int8_t>(int8_t x) ...@@ -611,7 +611,7 @@ inline __device__ int8_t neg<int8_t>(int8_t x)
template <> template <>
inline __device__ half_t neg<half_t>(half_t x) inline __device__ half_t neg<half_t>(half_t x)
{ {
return __hneg(x); return __hneg(static_cast<__half>(x));
}; };
template <typename T> template <typename T>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"
namespace ck::utils {
/**
 * @brief NaN test for a scaled f4_t value.
 *
 * The data argument is ignored; the result depends only on whether the scale
 * equals NumericLimits<e8m0_bexp_t>::QuietNaN().
 *
 * @param scale     The e8m0_bexp_t scale associated with the value.
 * @param dataBytes The f4_t payload (unused).
 * @return true if the scale compares equal to the quiet-NaN scale, false otherwise.
 */
template <>
__host__ __device__ inline bool is_nan<f4_t>(e8m0_bexp_t const scale,
                                             f4_t const dataBytes [[maybe_unused]])
{
    // the data itself has no NaN representation, so only the scale is inspected
    const bool scale_is_nan = (scale == NumericLimits<e8m0_bexp_t>::QuietNaN());
    return scale_is_nan;
}
/**
 * @brief Infinity test for a scaled f4_t value.
 *
 * ocp_e2m1_mxfp4 has no infinity representation, so both arguments are ignored.
 *
 * @param scale The e8m0_bexp_t scale (unused).
 * @param data  The f4_t payload (unused).
 * @return Always false.
 */
template <>
__host__ __device__ inline bool is_inf<f4_t>(e8m0_bexp_t const scale [[maybe_unused]],
                                             f4_t const data [[maybe_unused]])
{
    constexpr bool has_inf_encoding = false;
    return has_inf_encoding;
}
/**
 * @brief Zero test for a scaled f4_t value.
 *
 * A NaN (as reported by is_nan<f4_t>) is never zero. Otherwise the low four
 * bits of the data are masked with NumericUtils<f4_t>::set_sign_mask and the
 * result is compared with zero.
 *
 * @param scale The e8m0_bexp_t scale associated with the value.
 * @param data  The f4_t payload.
 * @return true if the masked payload is zero and the value is not NaN.
 */
template <>
__host__ __device__ inline bool is_zero<f4_t>(e8m0_bexp_t const scale, f4_t const data)
{
    if(is_nan<f4_t>(scale, data))
    {
        return false;
    }

    // the scale is not examined: it has no zero representation
    f4_t masked_payload = (data & 0b00001111) & NumericUtils<f4_t>::set_sign_mask;
    return masked_payload == 0b0;
}
/**
 * @brief Converts a scaled f4_t value to float.
 *
 * NaN and zero inputs are handled up front; everything else is forwarded to
 * convert_to_float<f4_t> together with the exponent extracted from the scale.
 *
 * @param scale The e8m0_bexp_t scale associated with the value.
 * @param data  The f4_t payload.
 * @return Quiet NaN for NaN inputs, 0.0f for zero inputs, otherwise the
 *         result of convert_to_float<f4_t>.
 */
template <>
__host__ __device__ inline float to_float<f4_t>(e8m0_bexp_t const scale, f4_t const data)
{
    if(is_nan<f4_t>(scale, data))
    {
        return std::numeric_limits<float>::quiet_NaN();
    }

    if(is_zero<f4_t>(scale, data))
    {
        return 0.0f;
    }

    // keep only the low four bits of the payload
    f4_t low_nibble         = data & 0b00001111;
    int scale_exponent_bits = get_exponent_value<e8m0_bexp_t>(scale);

    return convert_to_float<f4_t>(low_nibble, scale_exponent_bits);
}
/**
 * @brief Converts a float to f4_t with saturation.
 *
 * NaN inputs and inputs whose magnitude exceeds NumericLimits<f4_t>::Max()
 * (infinities included) return the max-normal mask selected by the float's
 * sign bit. Results whose magnitude is below
 * NumericLimits<f4_t>::DataMinSubnorm() are replaced by a zero mask.
 *
 * @param value The float to convert.
 * @return The saturated f4_t encoding.
 */
template <>
__host__ __device__ inline f4_t sat_convert_to_type<f4_t>(float value)
{
    // extract the IEEE sign bit from the raw float bits
    cvt bits;
    bits.value_float    = value;
    const uint32_t sign = bits.value_bitwise >> 31;

    // NaN, infinities and anything larger than Max() clamp to +/- max normal
    if(std::isnan(value) || std::abs(value) > NumericLimits<f4_t>::Max())
    {
        return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
                    : NumericUtils<f4_t>::data_max_positive_normal_mask;
    }

    f4_t converted = convert_to_type<f4_t>(value);

    // decode with the Binary_1 scale to see whether the result underflowed
    const float decoded_magnitude =
        std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), converted));
    if(decoded_magnitude < NumericLimits<f4_t>::DataMinSubnorm())
    {
        // sign comes from `value < 0`, so -0.0f produces the positive zero mask
        return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
                         : NumericUtils<f4_t>::positive_zero_mask;
    }

    return converted;
}
/**
 * @brief Converts a float to f4_t with saturation, using convert_to_type_sr.
 *
 * NaN inputs and inputs whose magnitude exceeds NumericLimits<f4_t>::Max()
 * (infinities included) return the max-normal mask selected by the float's
 * sign bit. Results whose magnitude is below
 * NumericLimits<f4_t>::DataMinSubnorm() are replaced by a zero mask.
 *
 * @param value The float to convert.
 * @param seed  Value forwarded to convert_to_type_sr<f4_t>.
 * @return The saturated f4_t encoding.
 */
template <>
__host__ __device__ inline f4_t sat_convert_to_type_sr<f4_t>(float value, uint32_t seed)
{
    // extract the IEEE sign bit from the raw float bits
    cvt bits;
    bits.value_float    = value;
    const uint32_t sign = bits.value_bitwise >> 31;

    // NaN, infinities and anything larger than Max() clamp to +/- max normal
    if(std::isnan(value) || std::abs(value) > NumericLimits<f4_t>::Max())
    {
        return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
                    : NumericUtils<f4_t>::data_max_positive_normal_mask;
    }

    f4_t converted = convert_to_type_sr<f4_t>(value, seed);

    // decode with the Binary_1 scale to see whether the result underflowed
    const float decoded_magnitude =
        std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), converted));
    if(decoded_magnitude < NumericLimits<f4_t>::DataMinSubnorm())
    {
        // sign comes from `value < 0`, so -0.0f produces the positive zero mask
        return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
                         : NumericUtils<f4_t>::positive_zero_mask;
    }

    return converted;
}
} // namespace ck::utils
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"
namespace ck::utils {
/**
 * @brief Reports whether a scaled f6_t value is NaN.
 *
 * f6_t data has no NaN representation of its own, so the answer is taken
 * entirely from the scale's is_nan() query.
 *
 * @param scale     The exponent scale factor (e8m0_bexp_t) used for f6_t.
 * @param dataBytes The f6_t value (unused).
 * @return true if scale.is_nan() is true, false otherwise.
 */
template <>
__host__ __device__ inline bool is_nan<f6_t>(e8m0_bexp_t const scale,
                                             f6_t const dataBytes [[maybe_unused]])
{
    // only the scale can signal NaN
    const bool scale_is_nan = scale.is_nan();
    return scale_is_nan;
}
/**
 * @brief Reports whether a scaled bf6_t value is NaN.
 *
 * bf6_t data has no NaN representation of its own, so the answer is taken
 * entirely from the scale's is_nan() query.
 *
 * @param scale     The exponent scale factor (e8m0_bexp_t) used for bf6_t.
 * @param dataBytes The bf6_t value (unused).
 * @return true if scale.is_nan() is true, false otherwise.
 */
template <>
__host__ __device__ inline bool is_nan<bf6_t>(e8m0_bexp_t const scale,
                                              bf6_t const dataBytes [[maybe_unused]])
{
    // only the scale can signal NaN
    const bool scale_is_nan = scale.is_nan();
    return scale_is_nan;
}
/**
 * @brief Reports whether a scaled f6_t value is infinite.
 *
 * f6_t has no infinity representation, so both arguments are ignored.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t (unused).
 * @param data  The f6_t value (unused).
 * @return Always false.
 */
template <>
__host__ __device__ inline bool is_inf<f6_t>(e8m0_bexp_t const scale [[maybe_unused]],
                                             f6_t const data [[maybe_unused]])
{
    constexpr bool has_inf_encoding = false;
    return has_inf_encoding;
}
/**
 * @brief Reports whether a scaled bf6_t value is infinite.
 *
 * bf6_t has no infinity representation, so both arguments are ignored.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t (unused).
 * @param data  The bf6_t value (unused).
 * @return Always false.
 */
template <>
__host__ __device__ inline bool is_inf<bf6_t>(e8m0_bexp_t const scale [[maybe_unused]],
                                              bf6_t const data [[maybe_unused]])
{
    constexpr bool has_inf_encoding = false;
    return has_inf_encoding;
}
/**
 * @brief Reports whether a scaled f6_t value is zero.
 *
 * A NaN (as reported by is_nan<f6_t>) is never zero. Otherwise the low six
 * bits of the data are masked with NumericUtils<f6_t>::set_sign_mask and the
 * result is compared with zero.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
 * @param data  The f6_t value to check.
 * @return true if the masked value is zero and the input is not NaN.
 */
template <>
__host__ __device__ inline bool is_zero<f6_t>(e8m0_bexp_t const scale, f6_t const data)
{
    if(is_nan<f6_t>(scale, data))
    {
        return false;
    }

    // the scale is not examined: it has no zero representation
    f6_t masked_payload = (data & 0b00111111) & NumericUtils<f6_t>::set_sign_mask;
    return masked_payload == 0b0;
}
/**
 * @brief Reports whether a scaled bf6_t value is zero.
 *
 * A NaN (as reported by is_nan<bf6_t>) is never zero. Otherwise the low six
 * bits of the data are masked with NumericUtils<bf6_t>::set_sign_mask and the
 * result is compared with zero.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
 * @param data  The bf6_t value to check.
 * @return true if the masked value is zero and the input is not NaN.
 */
template <>
__host__ __device__ inline bool is_zero<bf6_t>(e8m0_bexp_t const scale, bf6_t const data)
{
    if(is_nan<bf6_t>(scale, data))
    {
        return false;
    }

    // the scale is not examined: it has no zero representation
    bf6_t masked_payload = (data & 0b00111111) & NumericUtils<bf6_t>::set_sign_mask;
    return masked_payload == 0b0;
}
/**
 * @brief Converts a scaled f6_t value to float.
 *
 * NaN and zero inputs are handled up front; everything else is forwarded to
 * convert_to_float<f6_t> together with the exponent extracted from the scale.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
 * @param data  The f6_t value to convert.
 * @return Quiet NaN for NaN inputs, 0.0f for zero inputs, otherwise the
 *         result of convert_to_float<f6_t>.
 */
template <>
__host__ __device__ inline float to_float<f6_t>(e8m0_bexp_t const scale, f6_t const data)
{
    if(is_nan<f6_t>(scale, data))
    {
        return std::numeric_limits<float>::quiet_NaN();
    }

    if(is_zero<f6_t>(scale, data))
    {
        return 0.0f;
    }

    // keep only the low six bits of the payload
    f6_t low_six_bits       = data & 0b00111111;
    int scale_exponent_bits = get_exponent_value<e8m0_bexp_t>(scale);

    return convert_to_float<f6_t>(low_six_bits, scale_exponent_bits);
}
/**
 * @brief Converts a scaled bf6_t value to float.
 *
 * NaN and zero inputs are handled up front; everything else is forwarded to
 * convert_to_float<bf6_t> together with the exponent extracted from the scale.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
 * @param data  The bf6_t value to convert.
 * @return Quiet NaN for NaN inputs, 0.0f for zero inputs, otherwise the
 *         result of convert_to_float<bf6_t>.
 */
template <>
__host__ __device__ inline float to_float<bf6_t>(e8m0_bexp_t const scale, bf6_t const data)
{
    if(is_nan<bf6_t>(scale, data))
    {
        return std::numeric_limits<float>::quiet_NaN();
    }

    if(is_zero<bf6_t>(scale, data))
    {
        return 0.0f;
    }

    // keep only the low six bits of the payload
    bf6_t low_six_bits      = data & 0b00111111;
    int scale_exponent_bits = get_exponent_value<e8m0_bexp_t>(scale);

    return convert_to_float<bf6_t>(low_six_bits, scale_exponent_bits);
}
/**
 * @brief Converts a float to f6_t with saturation.
 *
 * NaN inputs and inputs whose magnitude exceeds NumericLimits<f6_t>::Max()
 * (infinities included) return the max-normal mask selected by the float's
 * sign bit. Results whose magnitude is below
 * NumericLimits<f6_t>::DataMinSubnorm() are replaced by the zero mask with
 * the same sign bit.
 *
 * @param value The float value to be converted.
 * @return The saturated f6_t value.
 */
template <>
__host__ __device__ inline f6_t sat_convert_to_type<f6_t>(float value)
{
    // extract the IEEE sign bit from the raw float bits
    cvt bits;
    bits.value_float    = value;
    const uint32_t sign = bits.value_bitwise >> 31;

    // NaN, infinities and anything larger than Max() clamp to +/- max normal
    if(std::isnan(value) || std::abs(value) > NumericLimits<f6_t>::Max())
    {
        return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
                    : NumericUtils<f6_t>::data_max_positive_normal_mask;
    }

    f6_t converted = convert_to_type<f6_t>(value);

    // decode with the Binary_1 scale to see whether the result underflowed
    const float decoded_magnitude =
        std::abs(to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), converted));
    if(decoded_magnitude < NumericLimits<f6_t>::DataMinSubnorm())
    {
        return sign ? NumericUtils<f6_t>::negative_zero_mask
                    : NumericUtils<f6_t>::positive_zero_mask;
    }

    return converted;
}
/**
 * @brief Converts a float to bf6_t with saturation.
 *
 * NaN inputs and inputs whose magnitude exceeds NumericLimits<bf6_t>::Max()
 * (infinities included) return the max-normal mask selected by the float's
 * sign bit. Results whose magnitude is below
 * NumericLimits<bf6_t>::DataMinSubnorm() are replaced by the zero mask with
 * the same sign bit.
 *
 * @param value The float value to be converted.
 * @return The saturated bf6_t value.
 */
template <>
__host__ __device__ inline bf6_t sat_convert_to_type<bf6_t>(float value)
{
    // extract the IEEE sign bit from the raw float bits
    cvt bits;
    bits.value_float    = value;
    const uint32_t sign = bits.value_bitwise >> 31;

    // NaN, infinities and anything larger than Max() clamp to +/- max normal
    if(std::isnan(value) || std::abs(value) > NumericLimits<bf6_t>::Max())
    {
        return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
                    : NumericUtils<bf6_t>::data_max_positive_normal_mask;
    }

    bf6_t converted = convert_to_type<bf6_t>(value);

    // decode with the Binary_1 scale to see whether the result underflowed
    const float decoded_magnitude =
        std::abs(to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), converted));
    if(decoded_magnitude < NumericLimits<bf6_t>::DataMinSubnorm())
    {
        return sign ? NumericUtils<bf6_t>::negative_zero_mask
                    : NumericUtils<bf6_t>::positive_zero_mask;
    }

    return converted;
}
/**
 * @brief Converts a float to f6_t with saturation and stochastic rounding.
 *
 * NaN inputs and inputs whose magnitude exceeds NumericLimits<f6_t>::Max()
 * (infinities included) return the max-normal mask selected by the float's
 * sign bit. Results whose magnitude is below
 * NumericLimits<f6_t>::DataMinSubnorm() are replaced by the zero mask with
 * the same sign bit.
 *
 * @param value The float value to be converted.
 * @param seed  Value forwarded to convert_to_type_sr<f6_t>.
 * @return The saturated f6_t value.
 */
template <>
__host__ __device__ inline f6_t sat_convert_to_type_sr<f6_t>(float value, uint32_t seed)
{
    // extract the IEEE sign bit from the raw float bits
    cvt bits;
    bits.value_float    = value;
    const uint32_t sign = bits.value_bitwise >> 31;

    // NaN, infinities and anything larger than Max() clamp to +/- max normal
    if(std::isnan(value) || std::abs(value) > NumericLimits<f6_t>::Max())
    {
        return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
                    : NumericUtils<f6_t>::data_max_positive_normal_mask;
    }

    f6_t converted = convert_to_type_sr<f6_t>(value, seed);

    // decode with the Binary_1 scale to see whether the result underflowed
    const float decoded_magnitude =
        std::abs(to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), converted));
    if(decoded_magnitude < NumericLimits<f6_t>::DataMinSubnorm())
    {
        return sign ? NumericUtils<f6_t>::negative_zero_mask
                    : NumericUtils<f6_t>::positive_zero_mask;
    }

    return converted;
}
/**
 * @brief Converts a float to bf6_t with saturation and stochastic rounding.
 *
 * NaN inputs and inputs whose magnitude exceeds NumericLimits<bf6_t>::Max()
 * (infinities included) return the max-normal mask selected by the float's
 * sign bit. Results whose magnitude is below
 * NumericLimits<bf6_t>::DataMinSubnorm() are replaced by the zero mask with
 * the same sign bit.
 *
 * @param value The float value to be converted.
 * @param seed  Value forwarded to convert_to_type_sr<bf6_t>.
 * @return The saturated bf6_t value.
 */
template <>
__host__ __device__ inline bf6_t sat_convert_to_type_sr<bf6_t>(float value, uint32_t seed)
{
    // extract the IEEE sign bit from the raw float bits
    cvt bits;
    bits.value_float    = value;
    const uint32_t sign = bits.value_bitwise >> 31;

    // NaN, infinities and anything larger than Max() clamp to +/- max normal
    if(std::isnan(value) || std::abs(value) > NumericLimits<bf6_t>::Max())
    {
        return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
                    : NumericUtils<bf6_t>::data_max_positive_normal_mask;
    }

    bf6_t converted = convert_to_type_sr<bf6_t>(value, seed);

    // decode with the Binary_1 scale to see whether the result underflowed
    const float decoded_magnitude =
        std::abs(to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), converted));
    if(decoded_magnitude < NumericLimits<bf6_t>::DataMinSubnorm())
    {
        return sign ? NumericUtils<bf6_t>::negative_zero_mask
                    : NumericUtils<bf6_t>::positive_zero_mask;
    }

    return converted;
}
} // namespace ck::utils
#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"
#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__
#define CK_MX_FP8_CVT_FAST_PATH 1
#else
#define CK_MX_FP8_CVT_FAST_PATH 0
#endif
namespace ck {
namespace fp8_impl {
#if CK_MX_FP8_CVT_FAST_PATH
/**
 * \brief convert one fp8 value to float with scaling (fast path only)
 *
 * Compiled only when CK_MX_FP8_CVT_FAST_PATH is set. The byte is placed in
 * byte 0 of a 32-bit word and handed to the matching
 * __builtin_amdgcn_cvt_scalef32_f32_{fp8,bf8} builtin with a last argument of 0.
 *
 * \tparam interpret interpretation of fp8 (CK_E4M3_OCP or CK_E5M2_OCP only)
 * \param scale scaling factor passed to the builtin
 * \param v fp8 bits to convert
 * \return float result of the builtin
 */
template <ck_fp8_interpretation_t interpret>
static __device__ float cast_to_f32_from_f8_scaled(float scale, fp8_storage_t v)
{
    static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");

    // zero-initialise so the three bytes that are never written are not
    // indeterminate when the whole 32-bit word is read below
    union
    {
        unsigned int i32val;
        unsigned char i8val[4];
    } val{};

    val.i8val[0] = v;

    if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
    {
        return __builtin_amdgcn_cvt_scalef32_f32_fp8(val.i32val, scale, 0);
    }
    else
    {
        return __builtin_amdgcn_cvt_scalef32_f32_bf8(val.i32val, scale, 0);
    }
}
/**
 * \brief convert two packed fp8 values to two floats with scaling (fast path only)
 *
 * The pair is bit-cast to a 16-bit integer and handed to the matching
 * __builtin_amdgcn_cvt_scalef32_pk_f32_{fp8,bf8} builtin with a last argument of 0.
 *
 * \tparam interpret interpretation of fp8 (CK_E4M3_OCP or CK_E5M2_OCP only)
 * \param scale scaling factor passed to the builtin
 * \param v packed pair of fp8 bits
 * \return float2_t result of the builtin
 */
template <ck_fp8_interpretation_t interpret>
static __device__ float2_t cast_to_f32x2_from_f8x2_scaled(float scale, fp8x2_storage_t v)
{
    static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");

    const auto packed_bits = bit_cast<uint16_t>(v);

    if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
    {
        return __builtin_amdgcn_cvt_scalef32_pk_f32_fp8(packed_bits, scale, 0);
    }
    else
    {
        return __builtin_amdgcn_cvt_scalef32_pk_f32_bf8(packed_bits, scale, 0);
    }
}
/**
 * \brief convert one float to fp8 bits with scaling (fast path only)
 *
 * With stochastic rounding the __builtin_amdgcn_cvt_scalef32_sr_{fp8,bf8}_f32
 * builtin is used; otherwise the packed round-to-nearest-even builtin is fed
 * the same value in both source slots. In both cases byte 0 of the
 * zero-initialised destination word is returned.
 *
 * \tparam interpret interpretation of fp8
 * \tparam stochastic_rounding selects the stochastic-rounding builtin
 * \param v float number
 * \param rng random bits handed to the stochastic-rounding builtin
 * \param scale scaling factor
 * \return fp8_storage_t
 */
template <ck_fp8_interpretation_t interpret, bool stochastic_rounding = false>
static __device__ fp8_storage_t cast_to_f8_from_f32_scaled(float v,
                                                           unsigned int rng = 0,
                                                           float scale      = 1.0f)
{
    // destination word, viewable as the operand types the builtins expect
    union
    {
        uint32_t ival;
        vector_type<int16_t, 2>::type v2i16;
        fp8_storage_t v4i8[4];
    } ret{};

    if constexpr(stochastic_rounding)
    {
        if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
        {
            ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v, rng, scale, 0);
        }
        else
        {
            ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v, rng, scale, 0);
        }
    }
    else
    {
        // RNE CVT
        // llvm.amdgcn.cvt.scalef32.pk.fp8.f32
        // v2i16 old_vdst, float srcA, float srcB, float scale, bool dst_lo_hi_sel
        if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
        {
            // If fval / scale > max fp8, returns Nan
            ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32(/*old_vdst*/ ret.v2i16,
                                                                 v,
                                                                 v,
                                                                 scale,
                                                                 /*dst_lo_hi_sel*/ false);
        }
        else
        {
            // If fval / scale > max bf8, returns Inf
            ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32(/*old_vdst*/ ret.v2i16,
                                                                 v,
                                                                 v,
                                                                 scale,
                                                                 /*dst_lo_hi_sel*/ false);
        }
    }

    // both paths leave the converted value in byte 0
    return ret.v4i8[0];
}
/**
 * \brief convert two floats to two packed fp8 values with scaling (fast path only)
 *
 * With stochastic rounding each element is converted separately through the
 * __builtin_amdgcn_cvt_scalef32_sr_{fp8,bf8}_f32 builtin; both conversions
 * receive the same rng value. Otherwise a single packed
 * round-to-nearest-even builtin converts both elements.
 *
 * \tparam interpret interpretation of fp8
 * \tparam stochastic_rounding selects the stochastic-rounding builtin
 * \param v 2xfloat
 * \param rng random bits handed to the stochastic-rounding builtin
 * \param scale scaling factor
 * \return fp8x2_storage_t
 */
template <ck_fp8_interpretation_t interpret, bool stochastic_rounding = false>
static __device__ fp8x2_storage_t cast_to_f8_from_f32_scaled(float2_t v,
                                                             unsigned int rng = 0,
                                                             float scale      = 1.0f)
{
    // destination word, viewable as the operand types the builtins expect
    union
    {
        uint32_t ival;
        vector_type<int16_t, 2>::type v2i16;
        StaticallyIndexedArray<fp8x2_storage_t, 2> v2f8x2;
    } ret{};

    if constexpr(stochastic_rounding)
    {
        fp8x2_storage_t packed;

        if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
        {
            // element 0
            ret.ival  = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[0], rng, scale, 0);
            packed[0] = ret.v2f8x2(Number<0>{})[0];
            // element 1
            ret.ival  = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[1], rng, scale, 0);
            packed[1] = ret.v2f8x2(Number<0>{})[0];
        }
        else
        {
            // element 0
            ret.ival  = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[0], rng, scale, 0);
            packed[0] = ret.v2f8x2(Number<0>{})[0];
            // element 1
            ret.ival  = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[1], rng, scale, 0);
            packed[1] = ret.v2f8x2(Number<0>{})[0];
        }

        return packed;
    }
    else
    {
        // RNE CVT
        // llvm.amdgcn.cvt.scalef32.pk.fp8.f32
        // v2i16 old_vdst, float srcA, float srcB, float scale, bool dst_lo_hi_sel
        if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
        {
            // If fval / scale > max fp8, returns Nan
            ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32(/*old_vdst*/ ret.v2i16,
                                                                 v[0],
                                                                 v[1],
                                                                 scale,
                                                                 /*dst_lo_hi_sel*/ false);
        }
        else
        {
            // If fval / scale > max bf8, returns Inf
            ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32(/*old_vdst*/ ret.v2i16,
                                                                 v[0],
                                                                 v[1],
                                                                 scale,
                                                                 /*dst_lo_hi_sel*/ false);
        }

        return ret.v2f8x2(Number<0>{});
    }
}
#endif // CK_MX_FP8_CVT_FAST_PATH
#if CK_MX_FP8_CVT_FAST_PATH
/**
 * \brief convert float to @p fp8_storage_t with scaling
 *
 * This version is used when the fast path (MX FP8 hardware) is available.
 * It forwards to cast_to_f8_from_f32_scaled.
 *
 * \tparam interp interpretation of fp8
 * \tparam stochastic_rounding when true, random bits are drawn from
 *         prand_generator (using the address and value of @p f) and passed on
 * \param f float number
 * \param scale scaling factor
 * \return fp8_storage_t
 */
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8_scaled(const float f, float scale)
{
    __is_interpret_supported(interp);

    // stays 0 unless stochastic rounding is requested
    uint32_t random_bits = 0;
    if constexpr(stochastic_rounding)
    {
        constexpr int seed   = 1254739;
        const auto f_address = reinterpret_cast<uintptr_t>(&f);
        random_bits          = prand_generator<float, seed>(f_address, f);
    }

    return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, random_bits, scale);
}
/**
 * \brief convert 2xfloat to @p 2xfp8_storage_t with scaling
 *
 * This version is used when the fast path (MX FP8 hardware) is available.
 * It forwards to cast_to_f8_from_f32_scaled.
 *
 * \tparam interp interpretation of fp8
 * \tparam stochastic_rounding when true, random bits are drawn from
 *         prand_generator (using the address of @p f and the value of its
 *         first element) and passed on
 * \param f 2xfloat
 * \param scale scaling factor
 * \return 2xfp8_storage_t
 */
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const float2_t f,
                                                                          float scale)
{
    __is_interpret_supported(interp);

    // stays 0 unless stochastic rounding is requested
    uint32_t random_bits = 0;
    if constexpr(stochastic_rounding)
    {
        constexpr int seed   = 1254739;
        const auto f_address = reinterpret_cast<uintptr_t>(&f);
        random_bits          = prand_generator<float, seed>(f_address, f[0]);
    }

    return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, random_bits, scale);
}
#else
/**
 * \brief convert float to @p fp8_storage_t with scaling
 *
 * This version is used when the fast path (MX FP8 hardware) is not available.
 * The input is divided by the scale and handed to cast_to_f8.
 *
 * \tparam interp interpretation of fp8 (CK_E4M3_OCP or CK_E5M2_OCP only)
 * \tparam stochastic_rounding when true, random bits are drawn from
 *         prand_generator (using the address and value of @p f) and passed on
 * \param f float number
 * \param scale scaling factor
 * \return fp8_storage_t
 */
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8_scaled(const float f, float scale)
{
    static_assert(interp == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interp == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");

    // stays 0 unless stochastic rounding is requested
    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
        constexpr int seed = 1254739;
        rng                = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
    }

    // The static_assert above leaves exactly two interpretations, so a
    // two-way dispatch is exhaustive; no runtime "unsupported" branch is needed.
    if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
    {
        return cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f / scale, rng);
    }
    else
    {
        return cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f / scale, rng);
    }
}
/**
 * \brief convert two float to @p 2xfp8_storage_t with scaling
 *
 * This version is used when the fast path (MX FP8 hardware) is not available.
 * Each element is divided by the scale and handed to cast_to_f8; both
 * conversions receive the same rng value.
 *
 * \tparam interp interpretation of fp8 (CK_E4M3_OCP or CK_E5M2_OCP only)
 * \tparam stochastic_rounding when true, random bits are drawn from
 *         prand_generator (using the address of @p f and the value of its
 *         first element) and passed on
 * \param f 2xfloat
 * \param scale scaling factor
 * \return 2xfp8_storage_t
 */
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const float2_t f,
                                                                          float scale)
{
    static_assert(interp == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interp == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");

    // stays 0 unless stochastic rounding is requested
    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
        constexpr int seed = 1254739;
        rng                = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
    }

    // The static_assert above leaves exactly two interpretations, so a
    // two-way dispatch is exhaustive; no runtime "unsupported" branch is needed.
    if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
    {
        return {cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f[0] / scale, rng),
                cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f[1] / scale, rng)};
    }
    else
    {
        return {cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f[0] / scale, rng),
                cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f[1] / scale, rng)};
    }
}
#endif // CK_MX_FP8_CVT_FAST_PATH
} // namespace fp8_impl
// Declare a template function for fp8 conversion using SR (stochastic rounding).
// Y is the destination type, X the source type; only explicit specializations
// are provided in this header.
template <typename Y, typename X>
__host__ __device__ constexpr Y mxf8_convert_sr(X x, float scale);
// Declare a template function for fp8 conversion using RNE (round to nearest even).
// Y is the destination type, X the source type; only explicit specializations
// are provided in this header.
template <typename Y, typename X>
__host__ __device__ constexpr Y mxf8_convert_rne(X x, float scale);
// fp32 -> fp8 (f8_ocp_t), round to nearest even, with scaling
template <>
inline __host__ __device__ f8_ocp_t mxf8_convert_rne<f8_ocp_t, float>(float x, float scale)
{
    const auto raw_bits = fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret>(x, scale);
    return f8_ocp_t{raw_bits};
}
// fp32 -> bf8 (bf8_ocp_t), round to nearest even, with scaling
template <>
inline __host__ __device__ bf8_ocp_t mxf8_convert_rne<bf8_ocp_t, float>(float x, float scale)
{
    const auto raw_bits =
        fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret>(x, scale);
    return bf8_ocp_t{raw_bits};
}
// fp32x2 -> fp8x2 (f8x2_ocp_t), round to nearest even, one scale for both elements
template <>
inline __host__ __device__ f8x2_ocp_t mxf8_convert_rne<f8x2_ocp_t, float2_t>(float2_t x,
                                                                             float scale)
{
    const auto raw_pair = fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret>(x, scale);
    return f8x2_ocp_t{raw_pair};
}
// fp32x2 -> bf8x2 (bf8x2_ocp_t), round to nearest even, one scale for both elements
template <>
inline __host__ __device__ bf8x2_ocp_t mxf8_convert_rne<bf8x2_ocp_t, float2_t>(float2_t x,
                                                                               float scale)
{
    const auto raw_pair =
        fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret>(x, scale);
    return bf8x2_ocp_t{raw_pair};
}
// fp32x16 -> fp8x16, round to nearest even.
// The input is processed as eight fp32x2 -> fp8x2 conversions sharing one scale.
template <>
inline __host__ __device__ f8x16_ocp_t mxf8_convert_rne<f8x16_ocp_t, float16_t>(float16_t x,
                                                                                float scale)
{
    // view the 16 inputs as 8 pairs
    union
    {
        float16_t whole;
        float2_t pairs[8];
    } src{x};

    // assemble the 16 outputs from 8 converted pairs
    union
    {
        f8x16_ocp_t whole;
        f8x2_ocp_t pairs[8];
    } dst{};

    ck::static_for<0, 8, 1>{}([&](auto pair_idx) {
        dst.pairs[pair_idx] = mxf8_convert_rne<f8x2_ocp_t>(src.pairs[pair_idx], scale);
    });

    return dst.whole;
}
// fp32x16 -> bf8x16, round to nearest even.
// The input is processed as eight fp32x2 -> bf8x2 conversions sharing one scale.
template <>
inline __host__ __device__ bf8x16_ocp_t mxf8_convert_rne<bf8x16_ocp_t, float16_t>(float16_t x,
                                                                                  float scale)
{
    // view the 16 inputs as 8 pairs
    union
    {
        float16_t whole;
        float2_t pairs[8];
    } src{x};

    // assemble the 16 outputs from 8 converted pairs
    union
    {
        bf8x16_ocp_t whole;
        bf8x2_ocp_t pairs[8];
    } dst{};

    ck::static_for<0, 8, 1>{}([&](auto pair_idx) {
        dst.pairs[pair_idx] = mxf8_convert_rne<bf8x2_ocp_t>(src.pairs[pair_idx], scale);
    });

    return dst.whole;
}
// fp32x32 -> fp8x32, round to nearest even.
// The input is processed as two fp32x16 -> fp8x16 conversions sharing one scale.
template <>
inline __host__ __device__ f8x32_ocp_t mxf8_convert_rne<f8x32_ocp_t, float32_t>(float32_t x,
                                                                                float scale)
{
    // view the 32 inputs as 2 halves of 16
    union
    {
        float32_t whole;
        float16_t halves[2];
    } src{x};

    // assemble the 32 outputs from 2 converted halves
    union
    {
        f8x32_ocp_t whole;
        f8x16_ocp_t halves[2];
    } dst{};

    ck::static_for<0, 2, 1>{}([&](auto half_idx) {
        dst.halves[half_idx] = mxf8_convert_rne<f8x16_ocp_t>(src.halves[half_idx], scale);
    });

    return dst.whole;
}
// fp32x32 -> bf8x32, round to nearest even.
// The input is processed as two fp32x16 -> bf8x16 conversions sharing one scale.
template <>
inline __host__ __device__ bf8x32_ocp_t mxf8_convert_rne<bf8x32_ocp_t, float32_t>(float32_t x,
                                                                                  float scale)
{
    // view the 32 inputs as 2 halves of 16
    union
    {
        float32_t whole;
        float16_t halves[2];
    } src{x};

    // assemble the 32 outputs from 2 converted halves
    union
    {
        bf8x32_ocp_t whole;
        bf8x16_ocp_t halves[2];
    } dst{};

    ck::static_for<0, 2, 1>{}([&](auto half_idx) {
        dst.halves[half_idx] = mxf8_convert_rne<bf8x16_ocp_t>(src.halves[half_idx], scale);
    });

    return dst.whole;
}
// fp32 -> fp8 (f8_ocp_t), stochastic rounding, with scaling
template <>
inline __host__ __device__ f8_ocp_t mxf8_convert_sr<f8_ocp_t, float>(float x, float scale)
{
    const auto raw_bits =
        fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret, true>(x, scale);
    return f8_ocp_t{raw_bits};
}
// fp32 -> bf8 (bf8_ocp_t), stochastic rounding, with scaling
template <>
inline __host__ __device__ bf8_ocp_t mxf8_convert_sr<bf8_ocp_t, float>(float x, float scale)
{
    const auto raw_bits =
        fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret, true>(x, scale);
    return bf8_ocp_t{raw_bits};
}
// fp32x2 -> fp8x2 (f8x2_ocp_t), stochastic rounding, one scale for both elements
template <>
inline __host__ __device__ f8x2_ocp_t mxf8_convert_sr<f8x2_ocp_t, float2_t>(float2_t x, float scale)
{
    const auto raw_pair =
        fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret, true>(x, scale);
    return f8x2_ocp_t{raw_pair};
}
// fp32x2 -> bf8x2 (bf8x2_ocp_t), stochastic rounding, one scale for both elements
template <>
inline __host__ __device__ bf8x2_ocp_t mxf8_convert_sr<bf8x2_ocp_t, float2_t>(float2_t x,
                                                                              float scale)
{
    const auto raw_pair =
        fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret, true>(x, scale);
    return bf8x2_ocp_t{raw_pair};
}
// fp32x16 -> fp8x16, stochastic rounding.
// The input is processed as eight fp32x2 -> fp8x2 conversions sharing one scale.
template <>
inline __host__ __device__ f8x16_ocp_t mxf8_convert_sr<f8x16_ocp_t, float16_t>(float16_t x,
                                                                               float scale)
{
    // view the 16 inputs as 8 pairs
    union
    {
        float16_t whole;
        float2_t pairs[8];
    } src{x};

    // assemble the 16 outputs from 8 converted pairs
    union
    {
        f8x16_ocp_t whole;
        f8x2_ocp_t pairs[8];
    } dst{};

    ck::static_for<0, 8, 1>{}([&](auto pair_idx) {
        dst.pairs[pair_idx] = mxf8_convert_sr<f8x2_ocp_t>(src.pairs[pair_idx], scale);
    });

    return dst.whole;
}
// fp32x16 -> bf8x16, stochastic rounding.
// The input is processed as eight fp32x2 -> bf8x2 conversions sharing one scale.
template <>
inline __host__ __device__ bf8x16_ocp_t mxf8_convert_sr<bf8x16_ocp_t, float16_t>(float16_t x,
                                                                                 float scale)
{
    // view the 16 inputs as 8 pairs
    union
    {
        float16_t whole;
        float2_t pairs[8];
    } src{x};

    // assemble the 16 outputs from 8 converted pairs
    union
    {
        bf8x16_ocp_t whole;
        bf8x2_ocp_t pairs[8];
    } dst{};

    ck::static_for<0, 8, 1>{}([&](auto pair_idx) {
        dst.pairs[pair_idx] = mxf8_convert_sr<bf8x2_ocp_t>(src.pairs[pair_idx], scale);
    });

    return dst.whole;
}
// fp32x32 -> fp8x32, stochastic rounding.
// The input is processed as two fp32x16 -> fp8x16 conversions sharing one scale.
template <>
inline __host__ __device__ f8x32_ocp_t mxf8_convert_sr<f8x32_ocp_t, float32_t>(float32_t x,
                                                                               float scale)
{
    // view the 32 inputs as 2 halves of 16
    union
    {
        float32_t whole;
        float16_t halves[2];
    } src{x};

    // assemble the 32 outputs from 2 converted halves
    union
    {
        f8x32_ocp_t whole;
        f8x16_ocp_t halves[2];
    } dst{};

    ck::static_for<0, 2, 1>{}([&](auto half_idx) {
        dst.halves[half_idx] = mxf8_convert_sr<f8x16_ocp_t>(src.halves[half_idx], scale);
    });

    return dst.whole;
}
// fp32x32 -> bf8x32, stochastic rounding.
// The input is processed as two fp32x16 -> bf8x16 conversions sharing one scale.
template <>
inline __host__ __device__ bf8x32_ocp_t mxf8_convert_sr<bf8x32_ocp_t, float32_t>(float32_t x,
                                                                                 float scale)
{
    // view the 32 inputs as 2 halves of 16
    union
    {
        float32_t whole;
        float16_t halves[2];
    } src{x};

    // assemble the 32 outputs from 2 converted halves
    union
    {
        bf8x32_ocp_t whole;
        bf8x16_ocp_t halves[2];
    } dst{};

    ck::static_for<0, 2, 1>{}([&](auto half_idx) {
        dst.halves[half_idx] = mxf8_convert_sr<bf8x16_ocp_t>(src.halves[half_idx], scale);
    });

    return dst.whole;
}
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck::utils {
// Helper union giving access to the same 32 bits either as a float or as an
// unsigned integer; the conversion routines write one member and read the other.
union cvt
{
// the value interpreted as a float
float value_float;
// the same storage interpreted as a raw 32-bit pattern
uint32_t value_bitwise;
};
/**
 * @brief Reads the hasInf flag from DTYPE::dataInfo.
 *
 * @tparam DTYPE Type exposing a static dataInfo object with a hasInf member.
 * @return The value of DTYPE::dataInfo.hasInf.
 */
template <typename DTYPE>
inline bool getDataHasInf()
{
    const bool type_has_inf = DTYPE::dataInfo.hasInf;
    return type_has_inf;
}
// Primary templates for the scaled-value classification queries.
// Only declarations are given here; each data type provides its own explicit
// specialization (e.g. in the f4/f6 utility headers).

// true when (scale, data) represents zero
template <typename T>
__host__ __device__ inline bool is_zero(e8m0_bexp_t const scale, T const data);
// true when (scale, data) represents NaN
template <typename T>
__host__ __device__ inline bool is_nan(e8m0_bexp_t const scale, T const data);
// true when (scale, data) represents infinity
template <typename T>
__host__ __device__ inline bool is_inf(e8m0_bexp_t const scale, T const data);
/**
 * @brief Extracts the raw exponent field of an encoded value.
 *
 * Shifts the mantissa bits (NumericUtils<T>::mant) out and keeps the lowest
 * NumericUtils<T>::exp bits of what remains. No bias is removed.
 *
 * @tparam T Encoded data type described by NumericUtils<T>.
 * @param x  Encoded value.
 * @return The exponent field as an int.
 */
template <typename T>
__host__ __device__ inline int get_exponent_value(T x)
{
    // mask with one bit set for every exponent bit
    const auto exponent_field_mask = (1 << NumericUtils<T>::exp) - 1;

    x >>= NumericUtils<T>::mant; // drop the mantissa
    x &= exponent_field_mask;    // drop everything above the exponent

    return static_cast<int>(x);
}
/**
 * @brief Reports whether an encoded value has an all-zero exponent field.
 *
 * @tparam T Encoded data type described by NumericUtils<T>.
 * @param x  Encoded value.
 * @return true if get_exponent_value<T>(x) is 0.
 */
template <typename T>
__host__ __device__ inline bool is_subnormal(T x)
{
    const int exponent_field = get_exponent_value<T>(x);
    return exponent_field == 0;
}
/**
 * @brief Decodes the significand of an encoded value.
 *
 * Starts from 1.0 (or 0.0 when is_subnormal<T>(x) is true) and adds the
 * weight of every set mantissa bit: the lowest bit of x is worth
 * 2^-NumericUtils<T>::mant and each following bit twice as much.
 *
 * @tparam T Encoded data type described by NumericUtils<T>.
 * @param x  Encoded value.
 * @return The significand as a double.
 */
template <typename T>
__host__ __device__ inline double get_mantissa_value(T x)
{
    // implicit leading one is present only for values with a non-zero exponent field
    double mantissa = is_subnormal<T>(x) ? 0.0 : 1.0;

    // `unsigned int` instead of the non-standard `uint` typedef, which is not
    // available on every toolchain
    for(unsigned int i = 0; i < NumericUtils<T>::mant; i++)
    {
        // consume the mantissa from its least significant bit upwards
        mantissa += std::pow(2, -int32_t((NumericUtils<T>::mant - i))) * (x & 0b1);
        x >>= 1;
    }

    return mantissa;
}
/**
 * @brief Reads NumericUtils<T>::has_inf.
 *
 * @tparam T Encoded data type described by NumericUtils<T>.
 * @return The value of NumericUtils<T>::has_inf.
 */
template <typename T>
__host__ __device__ inline bool get_data_has_inf()
{
    const bool type_has_inf = NumericUtils<T>::has_inf;
    return type_has_inf;
}
/**
 * @brief Decodes an encoded value and applies a power-of-two scale.
 *
 * The result is sign * 2^exponent * significand * 2^(scale_exp - scale bias),
 * where the exponent is 1 - bias when is_subnormal<T>(data) is true and
 * (exponent field - bias) otherwise, and the scale bias is
 * NumericUtils<e8m0_bexp_t>::bias.
 *
 * @tparam T        Encoded data type described by NumericUtils<T>.
 * @param data      Encoded value.
 * @param scale_exp Exponent of the scale, before its bias is removed.
 * @return The decoded and scaled float.
 */
template <typename T>
__host__ __device__ float convert_to_float(T data, int scale_exp)
{
    // (-1)^s, with s taken from the bits above the exponent and mantissa fields
    const float sign_factor =
        std::pow(-1, static_cast<float>(data >> (NumericUtils<T>::exp + NumericUtils<T>::mant)));

    // unbiased exponent of the data element
    const int data_bias = static_cast<int>(NumericUtils<T>::bias);
    const int unbiased_exp =
        is_subnormal<T>(data) ? (1 - data_bias) : (get_exponent_value<T>(data) - data_bias);
    const float exp_factor = std::pow(2, unbiased_exp);

    const float significand = get_mantissa_value<T>(data);

    const float unscaled = sign_factor * exp_factor * significand;

    // 2^(scale_exp - bias of the scale type)
    const float scale_factor = std::pow(
        2, static_cast<float>((scale_exp - static_cast<int>(NumericUtils<e8m0_bexp_t>::bias))));

    return unscaled * scale_factor;
}
// Primary templates for the conversion entry points.
// Only declarations are given here; each data type provides its own explicit
// specialization (e.g. in the f4/f6 utility headers).

// decode (scale, data) into a float
template <typename T>
__host__ __device__ inline float to_float(e8m0_bexp_t const scale, T const data);
// convert a float to T with saturation
template <typename T>
__host__ __device__ T sat_convert_to_type(float value);
// convert a float to T with saturation; seed is forwarded to the stochastic-rounding conversion
template <typename T>
__host__ __device__ T sat_convert_to_type_sr(float value, uint32_t seed);
// Converts a float to the bit pattern of the narrower floating-point type T.
//
// Two regimes:
//  * |value| > NumericLimits<T>::Max(): the value is either returned as the
//    largest finite encoding (with the input's sign) or as an overflow encoding.
//  * otherwise: the float is split into sign/exponent/mantissa, re-biased to T,
//    and rounded using the `midpoint`/`odd` logic below (this looks like
//    round-to-nearest with ties-to-even -- confirm).
//
// NaN inputs are not treated specially here (every comparison on them is false,
// so they fall through to the regular path); callers are presumably expected to
// filter them first -- confirm.
template <typename T>
inline T convert_to_type(float value)
{
using bitwise_type = typename NumericUtils<T>::bitwise_type;
if(std::abs(value) > NumericLimits<T>::Max())
{
float max_value = NumericLimits<T>::Max();
cvt t;
// cppcheck-suppress redundantAssignment
t.value_float = max_value;
uint32_t max_bitwise = t.value_bitwise;
// cppcheck-suppress redundantAssignment
t.value_float = value;
// Sign bit of the input float.
bitwise_type sign =
t.value_bitwise >> (NumericUtils<float>::exp + NumericUtils<float>::mant);
// Exponent field of Max(), moved from float's bias to T's bias.
bitwise_type exp =
((max_bitwise >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask) -
(NumericUtils<float>::bias - NumericUtils<T>::bias);
// Bits of Max() shifted down so its top mantissa bits line up with T's mantissa.
// NOTE(review): unlike mant_prev below, this is not masked to T's mantissa width
// before being OR-ed into the result; it relies on the narrowing to bitwise_type
// and on the overlap with `exp << mant` -- confirm this is intended.
bitwise_type mantissa = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
// Build the float whose T-width mantissa is one less than that of Max()
// (same sign/exponent bits), i.e. the neighbour just below Max().
uint32_t mant_prev = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
mant_prev &= ((1 << NumericUtils<T>::mant) - 1);
mant_prev--;
mant_prev <<= (NumericUtils<float>::mant - NumericUtils<T>::mant);
uint32_t prev_bit =
((max_bitwise >> NumericUtils<float>::mant) << NumericUtils<float>::mant) | mant_prev;
t.value_bitwise = prev_bit;
float prev_val = t.value_float;
// Max() plus half the gap to its lower neighbour: inputs below this bound are
// still returned as Max().
float diff = max_value - prev_val;
float actual_max = max_value + (diff / 2);
if(std::abs(value) < actual_max)
{
return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) |
(exp << NumericUtils<T>::mant) | mantissa;
}
else
{
if(!get_data_has_inf<T>())
{
// All exponent and mantissa bits set; note the sign bit is not included.
return (1 << (NumericUtils<T>::mant + NumericUtils<T>::exp)) - 1;
}
else
{
// Exponent one above Max()'s with a zero mantissa, carrying the input's sign.
exp++;
return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) |
(exp << NumericUtils<T>::mant);
}
}
}
// Regular path: split the float into its fields.
const int mfmt = NumericUtils<float>::mant;
uint32_t x;
x = bit_cast<uint32_t>(value);
uint32_t head, mantissa;
int32_t exponent, bias;
uint32_t sign;
head = x & NumericUtils<float>::head_mask;
mantissa = x & NumericUtils<float>::mant_mask;
exponent = (head >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask;
sign = head >> (NumericUtils<float>::mant + NumericUtils<float>::exp);
bias = NumericUtils<float>::bias;
// Only the all-zero bit pattern (+0.0f) is short-circuited here.
if(x == 0)
{
return 0b0;
}
const int mini_bias = NumericUtils<T>::bias;
// Unbiased exponent used by T when its exponent field is zero.
const int mini_denormal_act_exponent = 1 - mini_bias;
int act_exponent, out_exponent, exponent_diff;
bool is_subnorm = false;
if(exponent == 0)
{
// Input float has a zero exponent field: no implicit leading one is added.
act_exponent = exponent - bias + 1;
exponent_diff = mini_denormal_act_exponent - act_exponent;
is_subnorm = true;
}
else
{
act_exponent = exponent - bias;
if(act_exponent <= mini_denormal_act_exponent)
{
// Too small for a normal T value: remember how far the mantissa must be
// shifted right to be expressed with T's minimum exponent.
exponent_diff = mini_denormal_act_exponent - act_exponent;
is_subnorm = true;
}
else
{
exponent_diff = 0;
}
// Make the implicit leading one explicit.
mantissa += (1UL << mfmt);
}
// Number of low mantissa bits that will be discarded, clamped to keep the shifts in range.
auto shift_amount = (mfmt - NumericUtils<T>::mant + exponent_diff);
shift_amount = (shift_amount >= 64) ? 63 : shift_amount;
// True when the discarded bits are exactly one half of the last kept bit.
bool midpoint = (mantissa & ((1UL << shift_amount) - 1)) == (1UL << (shift_amount - 1));
float min_subnorm = NumericLimits<T>::DataMinSubnorm() * (sign ? -1 : 1);
if(is_subnorm && std::abs(value) < std::abs(min_subnorm))
{
// closer to 0
if(std::abs(value) <= std::abs(min_subnorm - value))
return 0;
else
return 1 | (sign << (NumericUtils<T>::exp + NumericUtils<T>::mant));
}
if(exponent_diff > 0)
mantissa >>= exponent_diff;
// NOTE(review): exponent_diff is only assigned 0 or a difference that is
// non-negative for target biases up to float's, so this branch looks unreachable.
else if(exponent_diff == -1)
mantissa <<= -exponent_diff;
bool implicit_one = mantissa & (1 << mfmt);
out_exponent = (act_exponent + exponent_diff) + mini_bias - (implicit_one ? 0 : 1);
// Bits below T's mantissa width, and the parity of the last bit that is kept.
uint32_t drop_mask = (1UL << (mfmt - NumericUtils<T>::mant)) - 1;
bool odd = mantissa & (1UL << (mfmt - NumericUtils<T>::mant));
// Rounding step: adds a value derived from the dropped bits so that a carry
// propagates into the kept bits when rounding up is required.
mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask;
if(out_exponent == 0)
{
// Rounding carried into the implicit-one position: result has exponent field 1.
if((1UL << mfmt) & mantissa)
{
out_exponent = 1;
}
}
else
{
// Rounding overflowed the mantissa: renormalise and bump the exponent.
if((1UL << (mfmt + 1)) & mantissa)
{
mantissa >>= 1;
out_exponent++;
}
}
// Drop the low bits so the mantissa has T's width.
mantissa >>= (mfmt - NumericUtils<T>::mant);
if(out_exponent == 0 && mantissa == 0)
{
return 0;
}
mantissa &= (1UL << NumericUtils<T>::mant) - 1;
return (sign << (NumericUtils<T>::exp + NumericUtils<T>::mant)) |
(out_exponent << NumericUtils<T>::mant) | mantissa;
}
/**
 * @brief Converts a float to the bit pattern of type T using a caller-supplied
 *        32-bit random value (`seed`) for stochastic rounding.
 *
 * For |value| > NumericLimits<T>::Max() the result is either the largest finite
 * encoding of T (carrying the input's sign) or an overflow encoding; `seed` is
 * compared against a probability threshold to choose between them when T has an
 * infinity. For all other inputs, `seed` is aligned with the bits that will be
 * discarded, added to the mantissa, and the result is truncated.
 *
 * NaN inputs are not treated specially here; callers are presumably expected to
 * filter them beforehand -- confirm.
 *
 * @tparam T    Target encoded type; layout is taken from NumericUtils<T>.
 * @param value Float to convert.
 * @param seed  32-bit random value driving the rounding decision.
 * @return Encoded bit pattern of type T.
 */
template <typename T>
inline T convert_to_type_sr(float value, uint32_t seed)
{
    if(std::abs(value) > NumericLimits<T>::Max())
    {
        float max_value = NumericLimits<T>::Max();

        cvt t;

        // cppcheck-suppress redundantAssignment
        t.value_float        = max_value;
        uint32_t max_bitwise = t.value_bitwise;

        // cppcheck-suppress redundantAssignment
        t.value_float = value;

        // Sign of the input and the exponent field of Max() re-biased for T.
        T sign = t.value_bitwise >> (NumericUtils<float>::exp + NumericUtils<float>::mant);
        T exp  = ((max_bitwise >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask) -
                (NumericUtils<float>::bias - NumericUtils<T>::bias);

        // Float whose T-width mantissa is one below that of Max(): the neighbour
        // just beneath the largest finite value.
        uint32_t mant_prev = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
        mant_prev &= ((1UL << NumericUtils<T>::mant) - 1);
        mant_prev--;

        mant_prev <<= (NumericUtils<float>::mant - NumericUtils<T>::mant);
        uint32_t prev_bit =
            ((max_bitwise >> NumericUtils<float>::mant) << NumericUtils<float>::mant) | mant_prev;

        t.value_bitwise = prev_bit;
        float prev_val  = t.value_float;
        float diff      = max_value - prev_val;

        // Max() plus half the spacing to its lower neighbour.
        float actual_max = max_value + (diff / 2);

        if(std::abs(value) < actual_max)
        {
            double d_max_value  = static_cast<double>(max_value);
            double d_actual_max = static_cast<double>(actual_max);
            double d_value      = static_cast<double>(value);
            double d_is         = std::abs(d_max_value - d_actual_max);
            double d_seed       = static_cast<double>(seed);
            double d_prob = 1.0f - (std::abs(d_value - d_max_value) / d_is); // prob to round down

            double thresh = UINT_MAX * d_prob;

            if(!get_data_has_inf<T>() || d_seed <= thresh)
                // Round down to the largest finite value of T, keeping the sign.
                // The masks must come from NumericUtils<T>, not a fixed element type,
                // so that every instantiation returns its own encoding.
                return sign == 0 ? NumericUtils<T>::data_max_positive_normal_mask
                                 : NumericUtils<T>::data_max_negative_normal_mask;
            else
            {
                exp++;
                return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) // inf
                       | (exp << NumericUtils<T>::mant);
            }
        }
        else
        {
            if(!get_data_has_inf<T>())
                // All exponent and mantissa bits set; the sign bit is not included.
                return (1 << (NumericUtils<T>::mant + NumericUtils<T>::exp)) - 1;
            else
            {
                exp++;
                return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) // inf
                       | (exp << NumericUtils<T>::mant);
            }
        }
    }

    // Regular path: split the float into fields and un-bias the exponent.
    uint32_t f32  = bit_cast<uint32_t>(value);
    auto f32_mant = f32 & NumericUtils<float>::mant_mask;
    auto head     = f32 & NumericUtils<float>::head_mask;
    auto f32_exp  = (head >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask;

    auto sign_bit = head >> (NumericUtils<float>::mant + NumericUtils<float>::exp);
    auto sign     = sign_bit << (NumericUtils<T>::exp + NumericUtils<T>::mant);

    f32_exp = static_cast<int32_t>(f32_exp) - NumericUtils<float>::bias;

    int32_t exp  = f32_exp;
    auto mant    = f32_mant;
    bool subnorm = false;

    // Only the all-zero bit pattern (+0.0f) is short-circuited here.
    if(f32 == 0)
        return 0b0;

    if(exp >= NumericUtils<T>::unbiased_exp_min)
    {
        mant = f32_mant;
    }
    // if the exponent bit is 8, then the subnormal is exactly the same as f32
    else if(exp < NumericUtils<T>::unbiased_exp_min &&
            NumericUtils<T>::exp < NumericUtils<float>::exp)
    {
        subnorm   = true;
        auto diff = static_cast<uint32_t>(NumericUtils<T>::unbiased_exp_min - exp);
        if(diff >= 32)
        {
            // Shifting by 32 or more would be out of range; the value underflows to zero.
            mant     = 0;
            f32_mant = 0;
        }
        else
        {
            // Make the implicit one explicit, then align to T's minimum exponent.
            f32_mant |= static_cast<uint32_t>(1) << NumericUtils<float>::mant;
            f32_mant >>= diff;
        }
        exp  = 0;
        mant = f32_mant;
    }

    uint32_t sr_shift = NumericUtils<T>::sr_shift;

    // For stochastic-rounding we add the aligned random value to the
    // mantissa and then truncate (RTZ).
    mant += seed >> sr_shift;

    // Increment exponent when mantissa overflows due to rounding
    if(mant >= static_cast<uint32_t>(1) << NumericUtils<float>::mant)
        ++exp;
    mant >>= (NumericUtils<float>::mant - NumericUtils<T>::mant);
    mant &= ((1 << NumericUtils<T>::mant) - 1);

    // In the subnormal case `exp` already holds the biased field value.
    auto biased_exp = static_cast<uint32_t>(exp);
    if(!subnorm)
        biased_exp = static_cast<uint32_t>(exp + NumericUtils<T>::bias);
    biased_exp &= ((1 << NumericUtils<T>::exp) - 1);
    auto val = sign | biased_exp << NumericUtils<T>::mant | mant;
    return val;
}
} // namespace ck::utils
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <ck/utility/ignore.hpp>
#include "ck/ck.hpp"
#ifdef CK_CODE_GEN_RTC
using uint8_t = unsigned char;
using uint16_t = unsigned short;
using uint32_t = unsigned int;
#endif
namespace ck { namespace ck {
// Pseudo random number generator // Pseudo random number generator
// version for fp32 // version for fp32
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<float, T>{}, bool> = false> template <typename T, uint32_t seed_t, ck::enable_if_t<std::is_same<float, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t) __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{ {
uint32_t x = *(reinterpret_cast<uint32_t*>(&val)); uint32_t x = *(reinterpret_cast<uint32_t*>(&val));
...@@ -23,7 +30,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = ...@@ -23,7 +30,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
} }
// version for fp16 // version for fp16
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<half_t, T>{}, bool> = false> template <typename T, uint32_t seed_t, ck::enable_if_t<std::is_same<_Float16, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t) __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{ {
uint16_t x = *(reinterpret_cast<uint16_t*>(&val)); uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
...@@ -40,12 +47,12 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = ...@@ -40,12 +47,12 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
// return 0 if data is not fp16 or fp32 // return 0 if data is not fp16 or fp32
template <typename T, template <typename T,
uint32_t seed_t, uint32_t seed_t,
std::enable_if_t<!(std::is_same<float, T>{} || std::is_same<half_t, T>{}), bool> = false> ck::enable_if_t<!(std::is_same<float, T>{} || std::is_same<_Float16, T>{}), bool> = false>
__host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t) __host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t)
{ {
std::ignore = id; ck::ignore = id;
std::ignore = val; ck::ignore = val;
std::ignore = seed; ck::ignore = seed;
return 0; return 0;
} }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/type_convert.hpp"
#include "ck/utility/mxf8_utils.hpp"
// Normalise CK_USE_NATIVE_MX_SUPPORT to a 0/1 value so that it can be tested
// with a plain `#if` further down in this header.
// NOTE(review): when the macro is already defined with a replacement other than
// `1` (for example `-DCK_USE_NATIVE_MX_SUPPORT=0`), the first branch redefines it
// to 1 -- confirm that only the "defined / not defined" distinction is intended.
#ifdef CK_USE_NATIVE_MX_SUPPORT
#define CK_USE_NATIVE_MX_SUPPORT 1
#else
#define CK_USE_NATIVE_MX_SUPPORT 0
#endif
namespace ck {
// Primary template for scaled conversion between X and Y using a shared e8m0
// scale. Only declared here; each supported (Y, X) pair is provided as an
// explicit specialization below.
// When CK_USE_OCP_FP8 is enabled the function is host+device, otherwise host-only.
template <typename Y, typename X>
#if CK_USE_OCP_FP8
__host__ __device__ constexpr Y scaled_type_convert(e8m0_bexp_t scale, X x);
#else
__host__ constexpr Y scaled_type_convert(e8m0_bexp_t scale, X x);
#endif
// Scaled conversion of a single f8_ocp_t element to fp32.
// With CK_MX_FP8_CVT_FAST_PATH the fused helper is used; otherwise the decoded
// scale and the decoded element are multiplied.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float scaled_type_convert<float, f8_ocp_t>(e8m0_bexp_t scale, f8_ocp_t x)
#else
inline __host__ float scaled_type_convert<float, f8_ocp_t>(e8m0_bexp_t scale, f8_ocp_t x)
#endif
{
    const float scale_f32 = type_convert<float>(scale);
#if CK_MX_FP8_CVT_FAST_PATH
    return fp8_impl::cast_to_f32_from_f8_scaled<f8_ocp_t::default_interpret>(scale_f32, x.data);
#else
    return scale_f32 * type_convert<float>(x);
#endif
}
// Scaled conversion of a single bf8_ocp_t element to fp32.
// With CK_MX_FP8_CVT_FAST_PATH the fused helper is used; otherwise the decoded
// scale and the decoded element are multiplied.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float scaled_type_convert<float, bf8_ocp_t>(e8m0_bexp_t scale,
                                                                       bf8_ocp_t x)
#else
inline __host__ float scaled_type_convert<float, bf8_ocp_t>(e8m0_bexp_t scale, bf8_ocp_t x)
#endif
{
    const float scale_f32 = type_convert<float>(scale);
#if CK_MX_FP8_CVT_FAST_PATH
    return fp8_impl::cast_to_f32_from_f8_scaled<bf8_ocp_t::default_interpret>(scale_f32, x.data);
#else
    return scale_f32 * type_convert<float>(x);
#endif
}
// Scaled conversion of two packed f8_ocp_t elements to two fp32 values.
// The fast path converts both lanes in one call; the fallback converts each lane
// through the scalar overload.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float2_t scaled_type_convert<float2_t, f8x2_ocp_t>(e8m0_bexp_t scale,
                                                                              f8x2_ocp_t x)
#else
inline __host__ float2_t scaled_type_convert<float2_t, f8x2_ocp_t>(e8m0_bexp_t scale, f8x2_ocp_t x)
#endif
{
#if CK_MX_FP8_CVT_FAST_PATH
    const float scale_f32 = type_convert<float>(scale);
    return fp8_impl::cast_to_f32x2_from_f8x2_scaled<f8_ocp_t::default_interpret>(
        scale_f32, x.AsType<fp8_impl::fp8x2_storage_t>()[Number<0>{}]);
#else
    const float lane0 = scaled_type_convert<float>(scale, x.AsType<f8_ocp_t>()[Number<0>{}]);
    const float lane1 = scaled_type_convert<float>(scale, x.AsType<f8_ocp_t>()[Number<1>{}]);
    return float2_t{lane0, lane1};
#endif
}
// Scaled conversion of two packed bf8_ocp_t elements to two fp32 values.
// The fast path converts both lanes in one call; the fallback converts each lane
// through the scalar overload.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float2_t scaled_type_convert<float2_t, bf8x2_ocp_t>(e8m0_bexp_t scale,
                                                                               bf8x2_ocp_t x)
#else
inline __host__ float2_t scaled_type_convert<float2_t, bf8x2_ocp_t>(e8m0_bexp_t scale,
                                                                    bf8x2_ocp_t x)
#endif
{
#if CK_MX_FP8_CVT_FAST_PATH
    const float scale_f32 = type_convert<float>(scale);
    return fp8_impl::cast_to_f32x2_from_f8x2_scaled<bf8_ocp_t::default_interpret>(
        scale_f32, x.AsType<fp8_impl::fp8x2_storage_t>()[Number<0>{}]);
#else
    const float lane0 = scaled_type_convert<float>(scale, x.AsType<bf8_ocp_t>()[Number<0>{}]);
    const float lane1 = scaled_type_convert<float>(scale, x.AsType<bf8_ocp_t>()[Number<1>{}]);
    return float2_t{lane0, lane1};
#endif
}
// Scaled conversion of 16 f8_ocp_t elements to 16 fp32 values, performed as
// eight 2-wide conversions over union views of the input and output vectors.
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float16_t scaled_type_convert<float16_t, f8x16_ocp_t>(e8m0_bexp_t scale,
                                                                                 f8x16_ocp_t x)
#else
inline __host__ float16_t scaled_type_convert<float16_t, f8x16_ocp_t>(e8m0_bexp_t scale,
                                                                      f8x16_ocp_t x)
#endif
{
    // Input viewed either as one 16-wide vector or as eight 2-wide vectors.
    union
    {
        f8x16_ocp_t whole;
        f8x2_ocp_t pairs[8];
    } src{x};

    // Output viewed the same way.
    union
    {
        float16_t whole;
        float2_t pairs[8];
    } dst{};

    ck::static_for<0, 8, 1>{}([&](auto idx) {
        dst.pairs[idx] = scaled_type_convert<float2_t, f8x2_ocp_t>(scale, src.pairs[idx]);
    });

    return dst.whole;
}
// Scaled conversion of 16 bf8_ocp_t elements to 16 fp32 values, performed as
// eight 2-wide conversions over union views of the input and output vectors.
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float16_t scaled_type_convert<float16_t, bf8x16_ocp_t>(e8m0_bexp_t scale,
                                                                                  bf8x16_ocp_t x)
#else
inline __host__ float16_t scaled_type_convert<float16_t, bf8x16_ocp_t>(e8m0_bexp_t scale,
                                                                       bf8x16_ocp_t x)
#endif
{
    // Input viewed either as one 16-wide vector or as eight 2-wide vectors.
    union
    {
        bf8x16_ocp_t whole;
        bf8x2_ocp_t pairs[8];
    } src{x};

    // Output viewed the same way.
    union
    {
        float16_t whole;
        float2_t pairs[8];
    } dst{};

    ck::static_for<0, 8, 1>{}([&](auto idx) {
        dst.pairs[idx] = scaled_type_convert<float2_t, bf8x2_ocp_t>(scale, src.pairs[idx]);
    });

    return dst.whole;
}
// Scaled conversion of 32 f8_ocp_t elements to 32 fp32 values, performed as
// two 16-wide conversions over union views of the input and output vectors.
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float32_t scaled_type_convert<float32_t, f8x32_ocp_t>(e8m0_bexp_t scale,
                                                                                 f8x32_ocp_t x)
#else
inline __host__ float32_t scaled_type_convert<float32_t, f8x32_ocp_t>(e8m0_bexp_t scale,
                                                                      f8x32_ocp_t x)
#endif
{
    // Input viewed either as one 32-wide vector or as two 16-wide halves.
    union
    {
        f8x32_ocp_t whole;
        f8x16_ocp_t halves[2];
    } src{x};

    // Output viewed the same way.
    union
    {
        float32_t whole;
        float16_t halves[2];
    } dst{};

    ck::static_for<0, 2, 1>{}([&](auto idx) {
        dst.halves[idx] = scaled_type_convert<float16_t, f8x16_ocp_t>(scale, src.halves[idx]);
    });

    return dst.whole;
}
// Scaled conversion of 32 bf8_ocp_t elements to 32 fp32 values, performed as
// two 16-wide conversions over union views of the input and output vectors.
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float32_t scaled_type_convert<float32_t, bf8x32_ocp_t>(e8m0_bexp_t scale,
                                                                                  bf8x32_ocp_t x)
#else
inline __host__ float32_t scaled_type_convert<float32_t, bf8x32_ocp_t>(e8m0_bexp_t scale,
                                                                       bf8x32_ocp_t x)
#endif
{
    // Input viewed either as one 32-wide vector or as two 16-wide halves.
    union
    {
        bf8x32_ocp_t whole;
        bf8x16_ocp_t halves[2];
    } src{x};

    // Output viewed the same way.
    union
    {
        float32_t whole;
        float16_t halves[2];
    } dst{};

    ck::static_for<0, 2, 1>{}([&](auto idx) {
        dst.halves[idx] = scaled_type_convert<float16_t, bf8x16_ocp_t>(scale, src.halves[idx]);
    });

    return dst.whole;
}
// Scaled conversion of one fp32 value to f8_ocp_t.
// CK_USE_SR_F8_CONVERSION selects the `_sr` helper, otherwise the `_rne` helper is used.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ f8_ocp_t scaled_type_convert<f8_ocp_t, float>(e8m0_bexp_t scale, float x)
#else
inline __host__ f8_ocp_t scaled_type_convert<f8_ocp_t, float>(e8m0_bexp_t scale, float x)
#endif
{
    const float scale_f32 = type_convert<float>(scale);
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<f8_ocp_t>(x, scale_f32);
#else
    return mxf8_convert_rne<f8_ocp_t>(x, scale_f32);
#endif
}
// Scaled conversion of one fp32 value to bf8_ocp_t.
// CK_USE_SR_F8_CONVERSION selects the `_sr` helper, otherwise the `_rne` helper is used.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8_ocp_t scaled_type_convert<bf8_ocp_t, float>(e8m0_bexp_t scale,
                                                                           float x)
#else
inline __host__ bf8_ocp_t scaled_type_convert<bf8_ocp_t, float>(e8m0_bexp_t scale, float x)
#endif
{
    const float scale_f32 = type_convert<float>(scale);
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<bf8_ocp_t>(x, scale_f32);
#else
    return mxf8_convert_rne<bf8_ocp_t>(x, scale_f32);
#endif
}
// Scaled conversion of two fp32 values to two packed f8_ocp_t elements.
// CK_USE_SR_F8_CONVERSION selects the `_sr` helper, otherwise the `_rne` helper is used.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ f8x2_ocp_t scaled_type_convert<f8x2_ocp_t, float2_t>(e8m0_bexp_t scale,
                                                                                float2_t x)
#else
inline __host__ f8x2_ocp_t scaled_type_convert<f8x2_ocp_t, float2_t>(e8m0_bexp_t scale, float2_t x)
#endif
{
    const float scale_f32 = type_convert<float>(scale);
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<f8x2_ocp_t>(x, scale_f32);
#else
    return mxf8_convert_rne<f8x2_ocp_t>(x, scale_f32);
#endif
}
// Scaled conversion of two fp32 values to two packed bf8_ocp_t elements.
// CK_USE_SR_F8_CONVERSION selects the `_sr` helper, otherwise the `_rne` helper is used.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8x2_ocp_t scaled_type_convert<bf8x2_ocp_t, float2_t>(e8m0_bexp_t scale,
                                                                                  float2_t x)
#else
inline __host__ bf8x2_ocp_t scaled_type_convert<bf8x2_ocp_t, float2_t>(e8m0_bexp_t scale,
                                                                       float2_t x)
#endif
{
    const float scale_f32 = type_convert<float>(scale);
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<bf8x2_ocp_t>(x, scale_f32);
#else
    return mxf8_convert_rne<bf8x2_ocp_t>(x, scale_f32);
#endif
}
// Scaled conversion of 16 fp32 values to 16 f8_ocp_t elements.
// CK_USE_SR_F8_CONVERSION selects the `_sr` helper, otherwise the `_rne` helper is used.
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ f8x16_ocp_t
scaled_type_convert<f8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
#else
inline __host__ f8x16_ocp_t scaled_type_convert<f8x16_ocp_t, float16_t>(e8m0_bexp_t scale,
                                                                        float16_t x)
#endif
{
    const float scale_f32 = type_convert<float>(scale);
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<f8x16_ocp_t>(x, scale_f32);
#else
    return mxf8_convert_rne<f8x16_ocp_t>(x, scale_f32);
#endif
}
// Scaled conversion of 16 fp32 values to 16 bf8_ocp_t elements.
// CK_USE_SR_F8_CONVERSION selects the `_sr` helper, otherwise the `_rne` helper is used.
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8x16_ocp_t
scaled_type_convert<bf8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
#else
inline __host__ bf8x16_ocp_t scaled_type_convert<bf8x16_ocp_t, float16_t>(e8m0_bexp_t scale,
                                                                          float16_t x)
#endif
{
    const float scale_f32 = type_convert<float>(scale);
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<bf8x16_ocp_t>(x, scale_f32);
#else
    return mxf8_convert_rne<bf8x16_ocp_t>(x, scale_f32);
#endif
}
// Scaled conversion of 32 fp32 values to 32 f8_ocp_t elements.
// CK_USE_SR_F8_CONVERSION selects the `_sr` helper, otherwise the `_rne` helper is used.
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ f8x32_ocp_t
scaled_type_convert<f8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
#else
inline __host__ f8x32_ocp_t scaled_type_convert<f8x32_ocp_t, float32_t>(e8m0_bexp_t scale,
                                                                        float32_t x)
#endif
{
    const float scale_f32 = type_convert<float>(scale);
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<f8x32_ocp_t>(x, scale_f32);
#else
    return mxf8_convert_rne<f8x32_ocp_t>(x, scale_f32);
#endif
}
// Scaled conversion of 32 fp32 values to 32 bf8_ocp_t elements.
// CK_USE_SR_F8_CONVERSION selects the `_sr` helper, otherwise the `_rne` helper is used.
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8x32_ocp_t
scaled_type_convert<bf8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
#else
inline __host__ bf8x32_ocp_t scaled_type_convert<bf8x32_ocp_t, float32_t>(e8m0_bexp_t scale,
                                                                          float32_t x)
#endif
{
    const float scale_f32 = type_convert<float>(scale);
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<bf8x32_ocp_t>(x, scale_f32);
#else
    return mxf8_convert_rne<bf8x32_ocp_t>(x, scale_f32);
#endif
}
// activate for architectures with native MX support
#if CK_USE_NATIVE_MX_SUPPORT
// Scaled conversion of a single f4_t element to fp32.
// On gfx950 the packed hardware conversion is invoked and the first of the two
// returned lanes is used; elsewhere the software decoder utils::to_float is called.
template <>
inline __host__ __device__ float scaled_type_convert<float, f4_t>(e8m0_bexp_t scale, f4_t x)
{
#if defined(__gfx950__)
    // Two views of the builtin's 2-wide result so that lane 0 can be read as a scalar.
    union
    {
        float lanes[2];
        float2_t packed;
    } converted{};

    const float scale_f32 = type_convert<float>(scale);
    converted.packed      = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, scale_f32, 0);

    return converted.lanes[0];
#else
    return utils::to_float<f4_t>(scale, x);
#endif
}
// Scaled conversion of two packed f4 elements to two fp32 values.
// On gfx950 the packed pair is placed in the first slot of a 32-bit word and
// handed to the hardware conversion; elsewhere each element is unpacked and
// decoded with utils::to_float.
template <>
inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_bexp_t scale,
                                                                          f4x2_t x)
{
#if defined(__gfx950__)
    // 32-bit operand for the builtin; only slot 0 carries data, the rest stay zero.
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } packed_word{};
    packed_word.f4x2_array[0] = x;

    const float scale_f32 = type_convert<float>(scale);
    return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(packed_word.bitwise, scale_f32, 0);
#else
    // Output lane 0 comes from unpack index 1 and lane 1 from unpack index 0.
    return float2_t{
        utils::to_float<f4_t>(
            scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})),
        utils::to_float<f4_t>(
            scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}))};
#endif
}
// Scaled conversion of 32 packed f4 elements to 32 fp32 values.
//
// On gfx950 each of the 16 packed pairs is converted with the hardware builtin.
// Elsewhere every element is unpacked and decoded with utils::to_float; packed
// pair i produces output elements 2*i and 2*i+1.
//
// NOTE(review): within a pair this fallback emits unpack index 0 first, whereas
// the 2-wide f4x2_t overload in this file emits unpack index 1 first -- the
// original ordering is kept here; confirm which one is intended.
template <>
inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m0_bexp_t scale,
                                                                             f4x32_t x)
{
#if defined(__gfx950__)
    // Input viewed as 16 packed pairs.
    union
    {
        f4x32_t f4x32_array;
        f4x2_t fp4x2[16];
    } value{x};
    // 32-bit operand for the builtin; only slot 0 is (re)written per conversion.
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } bitwise_value{};
    float2_t op;
    float32_t ret;
    // TODO: pack in a loop
    bitwise_value.f4x2_array[0] = value.fp4x2[0];
    op                          = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
        bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[0] = op[0];
    ret[1] = op[1];

    bitwise_value.f4x2_array[0] = value.fp4x2[1];
    op                          = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
        bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[2] = op[0];
    ret[3] = op[1];

    bitwise_value.f4x2_array[0] = value.fp4x2[2];
    op                          = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
        bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[4] = op[0];
    ret[5] = op[1];

    bitwise_value.f4x2_array[0] = value.fp4x2[3];
    op                          = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
        bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[6] = op[0];
    ret[7] = op[1];

    bitwise_value.f4x2_array[0] = value.fp4x2[4];
    op                          = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
        bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[8] = op[0];
    ret[9] = op[1];

    bitwise_value.f4x2_array[0] = value.fp4x2[5];
    op                          = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
        bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[10] = op[0];
    ret[11] = op[1];

    bitwise_value.f4x2_array[0] = value.fp4x2[6];
    op                          = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
        bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[12] = op[0];
    ret[13] = op[1];

    bitwise_value.f4x2_array[0] = value.fp4x2[7];
    op                          = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
        bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[14] = op[0];
    ret[15] = op[1];

    bitwise_value.f4x2_array[0] = value.fp4x2[8];
    op                          = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
        bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[16] = op[0];
    ret[17] = op[1];

    bitwise_value.f4x2_array[0] = value.fp4x2[9];
    op                          = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
        bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[18] = op[0];
    ret[19] = op[1];

    bitwise_value.f4x2_array[0] = value.fp4x2[10];
    op                          = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
        bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[20] = op[0];
    ret[21] = op[1];

    bitwise_value.f4x2_array[0] = value.fp4x2[11];
    op                          = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
        bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[22] = op[0];
    ret[23] = op[1];

    bitwise_value.f4x2_array[0] = value.fp4x2[12];
    op                          = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
        bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[24] = op[0];
    ret[25] = op[1];

    bitwise_value.f4x2_array[0] = value.fp4x2[13];
    op                          = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
        bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[26] = op[0];
    ret[27] = op[1];

    bitwise_value.f4x2_array[0] = value.fp4x2[14];
    op                          = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
        bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[28] = op[0];
    ret[29] = op[1];

    bitwise_value.f4x2_array[0] = value.fp4x2[15];
    op                          = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
        bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[30] = op[0];
    ret[31] = op[1];

    return ret;
#else
    // Output viewed either as one 32-wide vector or as 32 scalars.
    union
    {
        float32_t float32_array;
        float float_array[32];
    } float_values{};

    // Input viewed as 16 packed pairs.
    union
    {
        __uint128_t bitwise;
        f4x2_t f4x2_array[16];
        f4x32_t f4x32_array;
    } f4_values{bit_cast<__uint128_t>(x)};

    // Each packed pair fills its own two output slots. The previous unrolled code
    // reused output indices 0..7 for every group of four pairs, which left
    // elements 8..31 at zero and overwrote elements 0..7.
    for(int pair = 0; pair < 16; ++pair)
    {
        float_values.float_array[2 * pair] = utils::to_float<f4_t>(
            scale,
            f4_values.f4x2_array[pair].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(
                Number<0>{}));
        float_values.float_array[2 * pair + 1] = utils::to_float<f4_t>(
            scale,
            f4_values.f4x2_array[pair].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(
                Number<1>{}));
    }

    return float_values.float32_array;
#endif
}
// Scaled conversion of one fp32 value to f4_t.
// CK_USE_SR_F4_CONVERSION selects f4_convert_sr, otherwise f4_convert_rne is used.
template <>
inline __host__ __device__ f4_t scaled_type_convert<f4_t, float>(e8m0_bexp_t scale, float x)
{
    const float scale_f32 = type_convert<float>(scale);
#if CK_USE_SR_F4_CONVERSION
    return f4_convert_sr(x, scale_f32);
#else
    return f4_convert_rne(x, scale_f32);
#endif
}
// Scaled conversion of two fp32 values to two packed f4 elements.
// CK_USE_SR_F4_CONVERSION selects f4_convert_sr, otherwise f4_convert_rne is used.
template <>
inline __host__ __device__ f4x2_t scaled_type_convert<f4x2_t, float2_t>(e8m0_bexp_t scale,
                                                                        float2_t x)
{
    const float scale_f32 = type_convert<float>(scale);
#if CK_USE_SR_F4_CONVERSION
    return f4_convert_sr(x, scale_f32);
#else
    return f4_convert_rne(x, scale_f32);
#endif
}
// Scaled conversion of 32 fp32 values to 32 packed f4 elements.
// CK_USE_SR_F4_CONVERSION selects f4_convert_sr, otherwise f4_convert_rne is used.
template <>
inline __host__ __device__ f4x32_t scaled_type_convert<f4x32_t, float32_t>(e8m0_bexp_t scale,
                                                                           float32_t x)
{
    const float scale_f32 = type_convert<float>(scale);
#if CK_USE_SR_F4_CONVERSION
    return f4_convert_sr(x, scale_f32);
#else
    return f4_convert_rne(x, scale_f32);
#endif
}
/**
 * @brief Converts a 6-bit floating-point value (f6_t) to a 32-bit float,
 * applying the specified scaling factor.
 *
 * On gfx950 the 32-wide hardware conversion is invoked and element 0 of its
 * result is returned; on other targets the software decoder utils::to_float is used.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
 * @param x The f6_t value to be converted.
 * @return The converted 32-bit float representation of the input.
 */
template <>
inline __host__ __device__ float scaled_type_convert<float, f6_t>(e8m0_bexp_t scale, f6_t x)
{
#if defined(__gfx950__)
// NOTE(review): `in{x}` brace-initialises the first union member (the 32-wide
// vector) from a single f6_t; presumably only element 0 is meant to carry data
// -- confirm how f6x32_t is constructed from one element.
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} in{x};
// Result viewed either as a 32-wide vector or as 32 scalars.
union
{
float32_t float_vector;
float float_array[32];
} out{};
out.float_vector =
__builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(in.f6_vector, type_convert<float>(scale));
// Only the first converted element is returned.
return out.float_array[0];
#else
return utils::to_float<f6_t>(scale, x);
#endif
}
/**
 * @brief Scaled conversion of 32 f6 elements (f6x32_t) to 32 floats.
 *
 * gfx950 uses the 32-wide hardware conversion directly. Other targets decode
 * each element with utils::to_float through union views of input and output.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t).
 * @param x The f6x32_t vector to be converted.
 * @return The converted float vector representation of the input.
 */
template <>
inline __host__ __device__ float32_t scaled_type_convert<float32_t, f6x32_t>(e8m0_bexp_t scale,
                                                                             f6x32_t x)
{
#if defined(__gfx950__)
    const float scale_f32 = type_convert<float>(scale);
    return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(x, scale_f32);
#else
    // Input viewed either as one vector or as 32 individual elements.
    union
    {
        f6x32_t whole;
        f6_t elems[32];
    } src{x};

    // Output viewed the same way.
    union
    {
        float32_t whole;
        float elems[32];
    } dst{};

    ck::static_for<0, 32, 1>{}([&](auto idx) {
        dst.elems[idx] = utils::to_float<f6_t>(scale, src.elems[idx]);
    });

    return dst.whole;
#endif
}
/**
 * @brief Scaled conversion of a single 6-bit floating-point value (bf6_t) to float.
 *
 * On gfx950 the value is routed through the 32-wide packed conversion builtin and the
 * first output element is returned. On every other target the work is delegated to
 * utils::to_float<bf6_t>.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
 * @param x     The bf6_t value to be converted.
 * @return The converted 32-bit float representation of the input.
 */
template <>
inline __host__ __device__ float scaled_type_convert<float, bf6_t>(e8m0_bexp_t scale, bf6_t x)
{
#ifdef __gfx950__
    // The builtin consumes and produces 32-element vectors, so the scalar is viewed
    // through unions on both sides.
    // NOTE(review): `{x}` initialises the union's first member (bf6_vector) - confirm this
    // places x in the element that is later read back as float_array[0].
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } packed{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } unpacked{};

    const float scale_f32 = type_convert<float>(scale);
    unpacked.float_vector =
        __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(packed.bf6_vector, scale_f32);
    return unpacked.float_array[0];
#else
    return utils::to_float<bf6_t>(scale, x);
#endif
}
/**
 * @brief Scaled conversion of 32 packed 6-bit floating-point values (bf6x32_t) to 32 floats.
 *
 * gfx950 performs the whole conversion with one builtin call. Other targets convert the
 * 32 elements one at a time through utils::to_float<bf6_t>.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t).
 * @param x     The bf6x32_t vector to be converted.
 * @return The converted vector of 32 floats.
 */
template <>
inline __host__ __device__ float32_t scaled_type_convert<float32_t, bf6x32_t>(e8m0_bexp_t scale,
                                                                              bf6x32_t x)
{
#ifdef __gfx950__
    const float scale_f32 = type_convert<float>(scale);
    return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(x, scale_f32);
#else
    // Unions give per-element access to the input and output vectors.
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } src{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } dst{};

    ck::static_for<0, 32, 1>{}([&](auto lane) {
        dst.float_array[lane] = utils::to_float<bf6_t>(scale, src.bf6_array[lane]);
    });
    return dst.float_vector;
#endif
}
/**
 * @brief Converts a 32-bit float to a 6-bit floating-point value (f6_t), applying the
 *        specified scale.
 *
 * Stochastic rounding (f6_convert_sr) is selected when CK_USE_SR_F6_CONVERSION evaluates
 * to non-zero; otherwise round-to-nearest-even (f6_convert_rne) is used.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
 * @param x     The float value to convert.
 * @return The converted 6-bit floating-point value (f6_t).
 */
template <>
inline __host__ __device__ f6_t scaled_type_convert<f6_t, float>(e8m0_bexp_t scale, float x)
{
    // Both rounding back-ends are handed the scale as a float.
    const float scale_f32 = type_convert<float>(scale);
#if CK_USE_SR_F6_CONVERSION
    return f6_convert_sr(x, scale_f32);
#else
    return f6_convert_rne(x, scale_f32);
#endif
}
/**
 * @brief Converts a vector of 32 floats to a vector of 32 6-bit floating-point values
 *        (f6x32_t), applying the specified scale.
 *
 * Stochastic rounding (f6_convert_sr) is selected when CK_USE_SR_F6_CONVERSION evaluates
 * to non-zero; otherwise round-to-nearest-even (f6_convert_rne) is used.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t).
 * @param x     The float vector to convert.
 * @return The converted vector of 6-bit floating-point values (f6x32_t).
 */
template <>
inline __host__ __device__ f6x32_t scaled_type_convert<f6x32_t, float32_t>(e8m0_bexp_t scale,
                                                                           float32_t x)
{
    // Both rounding back-ends are handed the scale as a float.
    const float scale_f32 = type_convert<float>(scale);
#if CK_USE_SR_F6_CONVERSION
    return f6_convert_sr(x, scale_f32);
#else
    return f6_convert_rne(x, scale_f32);
#endif
}
/**
 * @brief Converts a 32-bit float to a 6-bit floating-point value (bf6_t), applying the
 *        specified scale.
 *
 * Stochastic rounding (bf6_convert_sr) is selected when CK_USE_SR_F6_CONVERSION evaluates
 * to non-zero; otherwise round-to-nearest-even (bf6_convert_rne) is used.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
 * @param x     The float value to convert.
 * @return The converted 6-bit floating-point value (bf6_t).
 */
template <>
inline __host__ __device__ bf6_t scaled_type_convert<bf6_t, float>(e8m0_bexp_t scale, float x)
{
    // Both rounding back-ends are handed the scale as a float.
    const float scale_f32 = type_convert<float>(scale);
#if CK_USE_SR_F6_CONVERSION
    return bf6_convert_sr(x, scale_f32);
#else
    return bf6_convert_rne(x, scale_f32);
#endif
}
/**
 * @brief Converts a vector of 32 floats to a vector of 32 6-bit floating-point values
 *        (bf6x32_t), applying the specified scale.
 *
 * Stochastic rounding (bf6_convert_sr) is selected when CK_USE_SR_F6_CONVERSION evaluates
 * to non-zero; otherwise round-to-nearest-even (bf6_convert_rne) is used.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t).
 * @param x     The float vector to convert.
 * @return The converted 6-bit floating-point vector (bf6x32_t).
 */
template <>
inline __host__ __device__ bf6x32_t scaled_type_convert<bf6x32_t, float32_t>(e8m0_bexp_t scale,
                                                                             float32_t x)
{
    // Both rounding back-ends are handed the scale as a float.
    const float scale_f32 = type_convert<float>(scale);
#if CK_USE_SR_F6_CONVERSION
    return bf6_convert_sr(x, scale_f32);
#else
    return bf6_convert_rne(x, scale_f32);
#endif
}
#endif // #if CK_USE_NATIVE_MX_SUPPORT
} // namespace ck
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#ifndef CK_CODE_GEN_RTC
#include <ostream> #include <ostream>
#endif
#include "ck/utility/integral_constant.hpp" #include "ck/utility/integral_constant.hpp"
#include "ck/utility/type.hpp" #include "ck/utility/type.hpp"
...@@ -900,6 +902,7 @@ using uniform_sequence_gen_t = typename uniform_sequence_gen<NSize, I>::type; ...@@ -900,6 +902,7 @@ using uniform_sequence_gen_t = typename uniform_sequence_gen<NSize, I>::type;
} // namespace ck } // namespace ck
#ifndef CK_CODE_GEN_RTC
template <ck::index_t... Is> template <ck::index_t... Is>
std::ostream& operator<<(std::ostream& os, const ck::Sequence<Is...>) std::ostream& operator<<(std::ostream& os, const ck::Sequence<Is...>)
{ {
...@@ -910,3 +913,4 @@ std::ostream& operator<<(std::ostream& os, const ck::Sequence<Is...>) ...@@ -910,3 +913,4 @@ std::ostream& operator<<(std::ostream& os, const ck::Sequence<Is...>)
os << S::At(S::Size() - ck::Number<1>{}).value << "}"; os << S::At(S::Size() - ck::Number<1>{}).value << "}";
return os; return os;
} }
#endif
...@@ -116,7 +116,8 @@ struct StaticBufferTupleOfVector ...@@ -116,7 +116,8 @@ struct StaticBufferTupleOfVector
// i is offset of S, not X. i should be aligned to X // i is offset of S, not X. i should be aligned to X
template <typename X, template <typename X,
index_t I, index_t I,
typename enable_if<has_same_scalar_type<S, X>::value, bool>::type = false> typename enable_if<has_same_scalar_type<S, X>::value || !is_native_type<S>(),
bool>::type = false>
__host__ __device__ constexpr auto GetAsType(Number<I> i) const __host__ __device__ constexpr auto GetAsType(Number<I> i) const
{ {
constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{}; constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};
...@@ -134,7 +135,8 @@ struct StaticBufferTupleOfVector ...@@ -134,7 +135,8 @@ struct StaticBufferTupleOfVector
// i is offset of S, not X. i should be aligned to X // i is offset of S, not X. i should be aligned to X
template <typename X, template <typename X,
index_t I, index_t I,
typename enable_if<has_same_scalar_type<S, X>::value, bool>::type = false> typename enable_if<has_same_scalar_type<S, X>::value || !is_native_type<S>(),
bool>::type = false>
__host__ __device__ constexpr void SetAsType(Number<I> i, X x) __host__ __device__ constexpr void SetAsType(Number<I> i, X x)
{ {
constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{}; constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP #ifndef CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
#define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP #define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
...@@ -35,10 +35,9 @@ __host__ __device__ constexpr auto to_multi_index(const T& x) ...@@ -35,10 +35,9 @@ __host__ __device__ constexpr auto to_multi_index(const T& x)
// is the alias of the latter. This is because compiler cannot infer the NSize if // is the alias of the latter. This is because compiler cannot infer the NSize if
// using MultiIndex<NSize> // using MultiIndex<NSize>
// TODO: how to fix this? // TODO: how to fix this?
template < template <typename... Ys,
typename... Ys, typename X,
typename X, enable_if_t<!ck::is_integral<X>::value && !ck::is_floating_point<X>::value, bool> = false>
enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> = false>
__host__ __device__ constexpr auto operator+=(Tuple<Ys...>& y, const X& x) __host__ __device__ constexpr auto operator+=(Tuple<Ys...>& y, const X& x)
{ {
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same"); static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
...@@ -47,10 +46,9 @@ __host__ __device__ constexpr auto operator+=(Tuple<Ys...>& y, const X& x) ...@@ -47,10 +46,9 @@ __host__ __device__ constexpr auto operator+=(Tuple<Ys...>& y, const X& x)
return y; return y;
} }
template < template <typename... Ys,
typename... Ys, typename X,
typename X, enable_if_t<!ck::is_integral<X>::value && !ck::is_floating_point<X>::value, bool> = false>
enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> = false>
__host__ __device__ constexpr auto operator-=(Tuple<Ys...>& y, const X& x) __host__ __device__ constexpr auto operator-=(Tuple<Ys...>& y, const X& x)
{ {
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same"); static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
...@@ -59,10 +57,9 @@ __host__ __device__ constexpr auto operator-=(Tuple<Ys...>& y, const X& x) ...@@ -59,10 +57,9 @@ __host__ __device__ constexpr auto operator-=(Tuple<Ys...>& y, const X& x)
return y; return y;
} }
template < template <typename... Xs,
typename... Xs, typename Y,
typename Y, enable_if_t<!ck::is_integral<Y>::value && !ck::is_floating_point<Y>::value, bool> = false>
enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> = false>
__host__ __device__ constexpr auto operator+(const Tuple<Xs...>& x, const Y& y) __host__ __device__ constexpr auto operator+(const Tuple<Xs...>& x, const Y& y)
{ {
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
...@@ -73,10 +70,9 @@ __host__ __device__ constexpr auto operator+(const Tuple<Xs...>& x, const Y& y) ...@@ -73,10 +70,9 @@ __host__ __device__ constexpr auto operator+(const Tuple<Xs...>& x, const Y& y)
return r; return r;
} }
template < template <typename... Xs,
typename... Xs, typename Y,
typename Y, enable_if_t<!ck::is_integral<Y>::value && !ck::is_floating_point<Y>::value, bool> = false>
enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> = false>
__host__ __device__ constexpr auto operator-(const Tuple<Xs...>& x, const Y& y) __host__ __device__ constexpr auto operator-(const Tuple<Xs...>& x, const Y& y)
{ {
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
...@@ -87,10 +83,9 @@ __host__ __device__ constexpr auto operator-(const Tuple<Xs...>& x, const Y& y) ...@@ -87,10 +83,9 @@ __host__ __device__ constexpr auto operator-(const Tuple<Xs...>& x, const Y& y)
return r; return r;
} }
template < template <typename... Xs,
typename... Xs, typename Y,
typename Y, enable_if_t<!ck::is_integral<Y>::value && !ck::is_floating_point<Y>::value, bool> = false>
enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> = false>
__host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y) __host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y)
{ {
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
...@@ -104,7 +99,7 @@ __host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y) ...@@ -104,7 +99,7 @@ __host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y)
// MultiIndex = scalar * MultiIndex // MultiIndex = scalar * MultiIndex
template <typename... Xs, template <typename... Xs,
typename Y, typename Y,
enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false> enable_if_t<ck::is_integral<Y>::value || ck::is_floating_point<Y>::value, bool> = false>
__host__ __device__ constexpr auto operator*(Y a, const Tuple<Xs...>& x) __host__ __device__ constexpr auto operator*(Y a, const Tuple<Xs...>& x)
{ {
constexpr index_t NSize = sizeof...(Xs); constexpr index_t NSize = sizeof...(Xs);
...@@ -117,7 +112,7 @@ __host__ __device__ constexpr auto operator*(Y a, const Tuple<Xs...>& x) ...@@ -117,7 +112,7 @@ __host__ __device__ constexpr auto operator*(Y a, const Tuple<Xs...>& x)
// MultiIndex = MultiIndex * scalar // MultiIndex = MultiIndex * scalar
template <typename... Xs, template <typename... Xs,
typename Y, typename Y,
enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false> enable_if_t<ck::is_integral<Y>::value || ck::is_floating_point<Y>::value, bool> = false>
__host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, Y a) __host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, Y a)
{ {
return a * x; return a * x;
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -32,7 +32,7 @@ struct TupleElementKeyData ...@@ -32,7 +32,7 @@ struct TupleElementKeyData
template <typename T, template <typename T,
typename enable_if<!is_same<remove_cvref_t<T>, TupleElementKeyData>::value, typename enable_if<!is_same<remove_cvref_t<T>, TupleElementKeyData>::value,
bool>::type = false> bool>::type = false>
__host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(std::forward<T>(v)) __host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(ck::forward<T>(v))
{ {
} }
...@@ -67,7 +67,7 @@ get_tuple_element_data_reference(TupleElementKeyData<Key, Data>&& x) ...@@ -67,7 +67,7 @@ get_tuple_element_data_reference(TupleElementKeyData<Key, Data>&& x)
template <typename Key, typename Data> template <typename Key, typename Data>
__host__ __device__ constexpr Data get_tuple_element_data(const TupleElementKeyData<Key, Data>& x) __host__ __device__ constexpr Data get_tuple_element_data(const TupleElementKeyData<Key, Data>& x)
{ {
return std::forward(x.mData); return ck::forward(x.mData);
} }
template <typename Indices, typename... Xs> template <typename Indices, typename... Xs>
...@@ -83,13 +83,13 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElementKeyData<TupleElementKey<I ...@@ -83,13 +83,13 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElementKeyData<TupleElementKey<I
!is_same<remove_cvref_t<Y>, TupleImpl>::value, !is_same<remove_cvref_t<Y>, TupleImpl>::value,
bool>::type = false> bool>::type = false>
__host__ __device__ constexpr TupleImpl(Y&& y) __host__ __device__ constexpr TupleImpl(Y&& y)
: TupleElementKeyData<TupleElementKey<Is>, Xs>(std::forward<Y>(y))... : TupleElementKeyData<TupleElementKey<Is>, Xs>(ck::forward<Y>(y))...
{ {
} }
template <typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false> template <typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr TupleImpl(Ys&&... ys) __host__ __device__ constexpr TupleImpl(Ys&&... ys)
: TupleElementKeyData<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))... : TupleElementKeyData<TupleElementKey<Is>, Xs>(ck::forward<Ys>(ys))...
{ {
static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys), static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys),
"wrong! inconsistent size"); "wrong! inconsistent size");
...@@ -123,14 +123,14 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X ...@@ -123,14 +123,14 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
template <typename Y, template <typename Y,
typename enable_if<sizeof...(Xs) == 1 && !is_same<remove_cvref_t<Y>, Tuple>::value, typename enable_if<sizeof...(Xs) == 1 && !is_same<remove_cvref_t<Y>, Tuple>::value,
bool>::type = false> bool>::type = false>
__host__ __device__ constexpr Tuple(Y&& y) : base(std::forward<Y>(y)) __host__ __device__ constexpr Tuple(Y&& y) : base(ck::forward<Y>(y))
{ {
} }
template <typename... Ys, template <typename... Ys,
typename enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2, bool>::type = typename enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2, bool>::type =
false> false>
__host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward<Ys>(ys)...) __host__ __device__ constexpr Tuple(Ys&&... ys) : base(ck::forward<Ys>(ys)...)
{ {
} }
...@@ -210,7 +210,7 @@ using tuple_element_t = typename tuple_element<I, TTuple>::type; ...@@ -210,7 +210,7 @@ using tuple_element_t = typename tuple_element<I, TTuple>::type;
template <typename... Xs> template <typename... Xs>
__host__ __device__ constexpr auto make_tuple(Xs&&... xs) __host__ __device__ constexpr auto make_tuple(Xs&&... xs)
{ {
return Tuple<remove_cvref_t<Xs>...>(std::forward<Xs>(xs)...); return Tuple<remove_cvref_t<Xs>...>(ck::forward<Xs>(xs)...);
} }
// https://en.cppreference.com/w/cpp/utility/tuple/tie // https://en.cppreference.com/w/cpp/utility/tuple/tie
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "functional4.hpp" #include "functional4.hpp"
#include "tuple.hpp" #include "tuple.hpp"
#ifndef CK_CODE_GEN_RTC
#include "is_detected.hpp" #include "is_detected.hpp"
#endif
namespace ck { namespace ck {
...@@ -29,7 +31,7 @@ __host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>& ...@@ -29,7 +31,7 @@ __host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>&
const Tuple<Y&...>& ty) const Tuple<Y&...>& ty)
{ {
return unpack2( return unpack2(
[&](auto&&... zs) { return Tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; }, [&](auto&&... zs) { return Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
tx, tx,
ty); ty);
} }
...@@ -38,7 +40,7 @@ template <typename... X, typename... Y> ...@@ -38,7 +40,7 @@ template <typename... X, typename... Y>
__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuple<Y...>& ty) __host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuple<Y...>& ty)
{ {
return unpack2( return unpack2(
[&](auto... zs) { return Tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; }, [&](auto... zs) { return Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
tx, tx,
ty); ty);
} }
...@@ -157,13 +159,17 @@ __host__ __device__ constexpr auto TupleReduce(F&& f, const Tuple<Ts...>& tuple) ...@@ -157,13 +159,17 @@ __host__ __device__ constexpr auto TupleReduce(F&& f, const Tuple<Ts...>& tuple)
} }
} }
#ifndef CK_CODE_GEN_RTC
template <typename T> template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple()); using is_tuple = decltype(ck::declval<T&>().IsTuple());
#endif
template <typename... Ts> template <typename... Ts>
__host__ __device__ constexpr auto IsNestedTuple(const Tuple<Ts...>&) __host__ __device__ constexpr auto IsNestedTuple(const Tuple<Ts...>&)
{ {
#ifndef CK_CODE_GEN_RTC
return (is_detected<is_tuple, Ts>::value || ...); return (is_detected<is_tuple, Ts>::value || ...);
#endif
} }
template <index_t depth = 0, typename T> template <index_t depth = 0, typename T>
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/utility/integral_constant.hpp" #include "ck/utility/enable_if.hpp"
#include "ck/utility/enable_if.hpp" #include "ck/utility/integral_constant.hpp"
namespace ck { namespace ck {
#ifdef CK_CODE_GEN_RTC
template <typename X, typename Y> // NOLINTNEXTLINE
struct is_same : public integral_constant<bool, false> #define CK_BUILTIN_TYPE_TRAIT1(name) \
{ template <class T> \
}; struct name : bool_constant<__##name(T)> \
{ \
template <typename X> }
struct is_same<X, X> : public integral_constant<bool, true>
{ // NOLINTNEXTLINE
}; #define CK_BUILTIN_TYPE_TRAIT2(name) \
template <class T, class U> \
template <typename X, typename Y> struct name : bool_constant<__##name(T, U)> \
inline constexpr bool is_same_v = is_same<X, Y>::value; { \
}
template <typename T>
using remove_reference_t = typename std::remove_reference<T>::type; // NOLINTNEXTLINE
#define CK_BUILTIN_TYPE_TRAITN(name) \
template <typename T> template <class... Ts> \
using remove_cv_t = typename std::remove_cv<T>::type; struct name : bool_constant<__##name(Ts...)> \
{ \
template <typename T> }
using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;
CK_BUILTIN_TYPE_TRAIT1(is_class);
template <typename T> CK_BUILTIN_TYPE_TRAIT1(is_pointer);
using remove_pointer_t = typename std::remove_pointer<T>::type; CK_BUILTIN_TYPE_TRAIT1(is_reference);
CK_BUILTIN_TYPE_TRAIT1(is_trivially_copyable);
template <typename T> CK_BUILTIN_TYPE_TRAIT1(is_unsigned);
inline constexpr bool is_pointer_v = std::is_pointer<T>::value; CK_BUILTIN_TYPE_TRAIT2(is_base_of);
template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false> template <class T>
__host__ __device__ constexpr Y bit_cast(const X& x) struct remove_cv
{ {
static_assert(__has_builtin(__builtin_bit_cast), ""); using type = T;
static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type"); };
return __builtin_bit_cast(Y, x); template <class T>
} struct remove_cv<const T> : remove_cv<T>
{
} // namespace ck };
// Strips a volatile qualifier, then defers to remove_cv<T> for whatever remains.
template <class T>
struct remove_cv<volatile T> : remove_cv<T>
{
};
// Handles types that are both const and volatile. Without this specialization the
// <const T> and <volatile T> partial specializations both match `const volatile U`
// and neither is more specialized than the other, so remove_cv<const volatile U>
// would be rejected as ambiguous.
template <class T>
struct remove_cv<const volatile T> : remove_cv<T>
{
};
// remove_reference<T>::type names T with any lvalue or rvalue reference stripped.
// Primary template: T is not a reference, so it is passed through unchanged.
template <class T>
struct remove_reference
{
    using type = T;
};
// Lvalue reference: T& -> T.
template <class T>
struct remove_reference<T&>
{
    using type = T;
};
// Rvalue reference: T&& -> T.
template <class T>
struct remove_reference<T&&>
{
    using type = T;
};
// remove_pointer<T>::type names the pointee type when T is a (possibly cv-qualified)
// pointer, and T itself otherwise.
// Primary template: T is not a pointer.
template <class T>
struct remove_pointer
{
    using type = T;
};
// Plain pointer.
template <class T>
struct remove_pointer<T*>
{
    using type = T;
};
// const-qualified pointer.
template <class T>
struct remove_pointer<T* const>
{
    using type = T;
};
// volatile-qualified pointer.
template <class T>
struct remove_pointer<T* volatile>
{
    using type = T;
};
// const volatile-qualified pointer.
template <class T>
struct remove_pointer<T* const volatile>
{
    using type = T;
};
// Perfect-forwarding helper: casts its argument to T&&, so the result is an lvalue
// reference when T is an lvalue reference type and an rvalue reference otherwise.
// Overload taking an lvalue argument.
template <class T>
constexpr T&& forward(typename remove_reference<T>::type& arg) noexcept
{
    return static_cast<T&&>(arg);
}
// Overload taking an rvalue argument.
template <class T>
constexpr T&& forward(typename remove_reference<T>::type&& arg) noexcept
{
    return static_cast<T&&>(arg);
}
// is_const<T>::value is true only when T itself carries a top-level const qualifier.
// Primary template: not const.
template <class T>
struct is_const : integral_constant<bool, false>
{
};
// Partial specialization selected for const-qualified types.
template <class T>
struct is_const<const T> : integral_constant<bool, true>
{
};
// Variable-template shorthand for is_const<T>::value.
template <class T>
inline constexpr bool is_const_v = is_const<T>::value;
// Variable-template shorthand for is_reference<T>::value.
template <class T>
inline constexpr bool is_reference_v = is_reference<T>::value;
// remove_const<T>::type names T with a top-level const qualifier removed.
// Primary template: nothing to strip.
template <class T>
struct remove_const
{
    using type = T;
};
// const T -> T.
template <class T>
struct remove_const<const T>
{
    using type = T;
};
// Alias shorthand for remove_const<T>::type.
template <class T>
using remove_const_t = typename remove_const<T>::type;
// Variable-template shorthand for is_class<T>::value.
template <typename T>
inline constexpr bool is_class_v = is_class<T>::value;
// Variable-template shorthand for is_trivially_copyable<T>::value.
template <typename T>
inline constexpr bool is_trivially_copyable_v = is_trivially_copyable<T>::value;
// Stand-in for std::declval, built from two helper overloads. All three functions are
// declarations only (no definition appears here).
// The simpler single-declaration form was left commented out:
// template <typename T>
// T&& declval() noexcept;
// Preferred helper: taking int makes it an exact match for the literal 0 used below.
// It only participates when the default U = T&& names a valid type.
template <class T, class U = T&&>
U private_declval(int);
// Fallback helper: taking long makes it a worse match for the literal 0, so it is only
// chosen when the overload above drops out (T&& cannot be formed); it returns T as-is.
template <class T>
T private_declval(long);
// Return type is whichever helper overload resolution selects for private_declval<T>(0).
template <class T>
auto declval() noexcept -> decltype(private_declval<T>(0));
// Maps any list of types to void.
template <class...>
using void_t = void;
#else
#include <utility>
#include <type_traits>
using std::declval;
using std::forward;
using std::is_base_of;
using std::is_class;
using std::is_class_v;
using std::is_const_v;
using std::is_pointer;
using std::is_reference;
using std::is_reference_v;
using std::is_trivially_copyable;
using std::is_trivially_copyable_v;
using std::is_unsigned;
using std::remove_const_t;
using std::remove_cv;
using std::remove_pointer;
using std::remove_reference;
using std::void_t;
#endif
// is_same<X, Y>::value is true only when X and Y name exactly the same type
// (cv-qualifiers and references included).
// Primary template: the two types differ.
template <class X, class Y>
struct is_same : integral_constant<bool, false>
{
};
// Partial specialization selected when both arguments are the same type.
template <class X>
struct is_same<X, X> : integral_constant<bool, true>
{
};
// is_floating_point<X>::value is true for exactly float, double and long double.
// Only the unqualified types are listed, so cv-qualified variants report false.
// Primary template: not a floating-point type.
template <class X>
struct is_floating_point : integral_constant<bool, false>
{
};
template <>
struct is_floating_point<float> : integral_constant<bool, true>
{
};
template <>
struct is_floating_point<double> : integral_constant<bool, true>
{
};
template <>
struct is_floating_point<long double> : integral_constant<bool, true>
{
};
// is_integral<X>::value is true for the built-in integer, character and bool types
// listed below. Only the unqualified types are listed, so cv-qualified variants
// report false.
// Primary template: not an integral type.
template <class X>
struct is_integral : integral_constant<bool, false>
{
};

// Signed integer types.
template <>
struct is_integral<signed char> : integral_constant<bool, true>
{
};
template <>
struct is_integral<short> : integral_constant<bool, true>
{
};
template <>
struct is_integral<int> : integral_constant<bool, true>
{
};
template <>
struct is_integral<long> : integral_constant<bool, true>
{
};
template <>
struct is_integral<long long> : integral_constant<bool, true>
{
};

// Unsigned integer types.
template <>
struct is_integral<unsigned char> : integral_constant<bool, true>
{
};
template <>
struct is_integral<unsigned short> : integral_constant<bool, true>
{
};
template <>
struct is_integral<unsigned int> : integral_constant<bool, true>
{
};
template <>
struct is_integral<unsigned long> : integral_constant<bool, true>
{
};
template <>
struct is_integral<unsigned long long> : integral_constant<bool, true>
{
};

// Character types.
template <>
struct is_integral<char> : integral_constant<bool, true>
{
};
template <>
struct is_integral<wchar_t> : integral_constant<bool, true>
{
};
template <>
struct is_integral<char16_t> : integral_constant<bool, true>
{
};
template <>
struct is_integral<char32_t> : integral_constant<bool, true>
{
};

// bool.
template <>
struct is_integral<bool> : integral_constant<bool, true>
{
};
// Variable-template shorthand for is_same<X, Y>::value.
template <class X, class Y>
inline constexpr bool is_same_v = is_same<X, Y>::value;
// Variable-template shorthand for is_base_of<X, Y>::value.
template <class X, class Y>
inline constexpr bool is_base_of_v = is_base_of<X, Y>::value;
// Variable-template shorthand for is_unsigned<T>::value.
template <class T>
inline constexpr bool is_unsigned_v = is_unsigned<T>::value;
// Alias shorthand for remove_reference<T>::type.
template <typename T>
using remove_reference_t = typename remove_reference<T>::type;
// Alias shorthand for remove_cv<T>::type.
template <typename T>
using remove_cv_t = typename remove_cv<T>::type;
// Strips the reference first and then the top-level cv-qualifiers.
template <typename T>
using remove_cvref_t = remove_cv_t<remove_reference_t<T>>;
// Alias shorthand for remove_pointer<T>::type.
template <class T>
using remove_pointer_t = typename remove_pointer<T>::type;
// Variable-template shorthand for is_pointer<T>::value.
template <class T>
inline constexpr bool is_pointer_v = is_pointer<T>::value;
// Reinterprets the object representation of x as a value of type Y.
// Y must be given explicitly; X is deduced from the argument. The enable_if removes
// this overload from consideration when the two types differ in size.
template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
__host__ __device__ constexpr Y bit_cast(const X& x)
{
// Fail the build if the compiler does not provide the builtin used below.
static_assert(__has_builtin(__builtin_bit_cast), "");
// Repeats the size requirement already enforced by the enable_if above.
static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type");
return __builtin_bit_cast(Y, x);
}
} // namespace ck
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
#include "ck/utility/f8_utils.hpp" #include "ck/utility/f8_utils.hpp"
#include "ck/utility/mxf4_utils.hpp"
#include "ck/utility/mxf6_utils.hpp"
#include "ck/utility/random_gen.hpp" #include "ck/utility/random_gen.hpp"
#include "ck/utility/array.hpp" #include "ck/utility/array.hpp"
#include "ck/utility/amd_inline_asm.hpp"
#include "ck/utility/type.hpp"
namespace ck { namespace ck {
// Define the common macro for gfx94x models // Define the common macro for MI300 models
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
#define __gfx94__ #define __gfx94__
#endif #endif
namespace {
namespace details {

// Host overload: adds two half2_t values component by component (.x and .y).
[[maybe_unused]] __host__ half2_t pk_add_f16(const half2_t& x, const half2_t& y)
{
    half2_t sum = x;
    sum.x += y.x;
    sum.y += y.y;
    return sum;
}

// Device overload: forwards to amd_assembly_pk_add_f16.
[[maybe_unused]] __device__ half2_t pk_add_f16(const half2_t& x, const half2_t& y)
{
    return amd_assembly_pk_add_f16(x, y);
}

} // namespace details
} // namespace
// Primary template for bf16 conversion using RTN.
// Declaration only: Y is the destination type and X the source type; the supported
// conversions are supplied as explicit specializations.
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
// Convert fp32 to bf16 with RTN if higher precision is needed.
// NaN inputs map to the fixed pattern 0x7FC0. Every other input keeps the upper 16 bits
// of its fp32 encoding after rounding: 0x7FFF plus the lowest kept bit is added first,
// so values exactly halfway between two results round toward the one whose lowest kept
// bit is 0.
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
{
    // NaN is the only value that compares unequal to itself.
    if(x != x)
    {
        return uint16_t(0x7FC0);
    }

    // bit_cast yields the same bits a float/uint32_t union would, without reading an
    // inactive union member (undefined in C++ and not usable in constant evaluation).
    const uint32_t fp32_bits = bit_cast<uint32_t>(x);

    // Lowest bit that survives the truncation to 16 bits; it breaks exact ties.
    const uint32_t first_bf16_mantisa_bit = ((fp32_bits >> 16) & 1);
    constexpr uint32_t rounding_bias      = uint32_t((1 << 15) - 1);

    return uint16_t((fp32_bits + first_bf16_mantisa_bit + rounding_bias) >> 16);
}
// Convert fp16 to bf16 with RTN: the half value is first widened to fp32 and then
// handed to the bf16_convert_rtn overload selected for a float argument.
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
{
    return bf16_convert_rtn<bhalf_t>(static_cast<float>(x));
}
// Convert X to Y, both X and Y are non-const data types. // Convert X to Y, both X and Y are non-const data types.
template <typename Y, template <typename Y,
typename X, typename X,
std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false> ck::enable_if_t<!(ck::is_const_v<Y> || ck::is_const_v<X>), bool> = false>
__host__ __device__ constexpr Y type_convert(X x) __host__ __device__ constexpr Y type_convert(X x)
{ {
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>); static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>);
return static_cast<Y>(x); return static_cast<Y>(x);
} }
...@@ -28,13 +87,13 @@ __host__ __device__ constexpr Y type_convert(X x) ...@@ -28,13 +87,13 @@ __host__ __device__ constexpr Y type_convert(X x)
// Convert X to Y, either X or Y is a const data type. // Convert X to Y, either X or Y is a const data type.
template <typename Y, template <typename Y,
typename X, typename X,
std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false> ck::enable_if_t<ck::is_const_v<Y> || ck::is_const_v<X>, bool> = false>
__host__ __device__ constexpr Y type_convert(X x) __host__ __device__ constexpr Y type_convert(X x)
{ {
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>); static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>);
using NonConstY = std::remove_const_t<Y>; using NonConstY = ck::remove_const_t<Y>;
using NonConstX = std::remove_const_t<X>; using NonConstX = ck::remove_const_t<X>;
return static_cast<Y>(type_convert<NonConstY, NonConstX>(x)); return static_cast<Y>(type_convert<NonConstY, NonConstX>(x));
} }
...@@ -51,17 +110,15 @@ inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t ...@@ -51,17 +110,15 @@ inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t
return u.fp32; return u.fp32;
} }
// convert fp32 to bfp16 // convert fp32 to bfp16, round to nearest even
template <> template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x) inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
{ {
union #if CK_USE_RNE_BF16_CONVERSION
{ return bf16_convert_rtn<bhalf_t>(x);
float fp32; #else
uint32_t int32;
} u = {x};
return uint16_t(u.int32 >> 16); return uint16_t(u.int32 >> 16);
#endif
} }
// convert bfp16 to fp16 via fp32 // convert bfp16 to fp16 via fp32
...@@ -100,11 +157,23 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_ ...@@ -100,11 +157,23 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return type_convert<bhalf_t>(x_fp32); return type_convert<bhalf_t>(x_fp32);
} }
// Convert int to f8_ocp_t: the integer is first converted to
// f8_ocp_t::data_type via type_convert, and the result is used to
// brace-initialize the f8_ocp_t wrapper.
template <>
inline __host__ __device__ constexpr f8_ocp_t type_convert<f8_ocp_t, int>(int x)
{
    using raw_t = f8_ocp_t::data_type;
    return f8_ocp_t{type_convert<raw_t>(x)};
}
// Convert int to bf8_ocp_t: the integer is first converted to
// bf8_ocp_t::data_type via type_convert, and the result is used to
// brace-initialize the bf8_ocp_t wrapper.
template <>
inline __host__ __device__ constexpr bf8_ocp_t type_convert<bf8_ocp_t, int>(int x)
{
    using raw_t = bf8_ocp_t::data_type;
    return bf8_ocp_t{type_convert<raw_t>(x)};
}
// Convert X to Y // Convert X to Y
template <typename Y, typename X> template <typename Y, typename X>
__host__ __device__ constexpr Y type_convert_sp(X x) __host__ __device__ constexpr Y type_convert_sp(X x)
{ {
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>); static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>);
return static_cast<Y>(x); return static_cast<Y>(x);
} }
...@@ -163,10 +232,14 @@ __host__ __device__ constexpr Y f8_convert_sr(X x); ...@@ -163,10 +232,14 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
// convert fp32 to fp8 with stochastic rounding // convert fp32 to fp8 with stochastic rounding
template <> template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x) inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, float>(float x)
{ {
constexpr int seed = 1254739; constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x); #ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#endif
#if defined(__gfx94__) #if defined(__gfx94__)
union union
{ {
...@@ -189,36 +262,47 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x) ...@@ -189,36 +262,47 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
return utils:: return utils::
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, cast_to_f8<float, f8_fnuz_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
rng); x, rng);
#endif #endif
} }
// convert fp16 to fp8 with stochastic rounding // convert fp16 to fp8 with stochastic rounding
template <> template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x) inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, half_t>(half_t x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
// convert to float and use native conversion // convert to float and use native conversion
return f8_convert_sr<f8_t>(type_convert<float>(x)); return f8_convert_sr<f8_fnuz_t>(type_convert<float>(x));
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 1254739; constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils:: #else
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<size_t>(&x), x);
x, rng); #endif
return utils::cast_to_f8<half_t,
f8_fnuz_t,
negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif #endif
} }
// convert fp32 to bf8 with stochastic rounding // convert fp32 to bf8 with stochastic rounding
template <> template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x) inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, float>(float x)
{ {
constexpr int seed = 1254739; constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x); #ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#endif
#if defined(__gfx94__) #if defined(__gfx94__)
union union
{ {
...@@ -240,28 +324,37 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x) ...@@ -240,28 +324,37 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
return utils:: return utils::cast_to_f8<float,
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( bf8_fnuz_t,
x, rng); negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif #endif
} }
// convert fp16 to bf8 with stochastic rounding // convert fp16 to bf8 with stochastic rounding
template <> template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x) inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, half_t>(half_t x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
// convert to float and use native conversion // convert to float and use native conversion
return f8_convert_sr<bf8_t>(type_convert<float>(x)); return f8_convert_sr<bf8_fnuz_t>(type_convert<float>(x));
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 1254739; constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils:: #else
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<size_t>(&x), x);
x, rng); #endif
return utils::cast_to_f8<half_t,
bf8_fnuz_t,
negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif #endif
} }
...@@ -271,7 +364,7 @@ __host__ __device__ constexpr Y f8_convert_rne(X x); ...@@ -271,7 +364,7 @@ __host__ __device__ constexpr Y f8_convert_rne(X x);
// convert fp32 to fp8 with rounding to nearest even // convert fp32 to fp8 with rounding to nearest even
template <> template <>
inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x) inline __host__ __device__ f8_fnuz_t f8_convert_rne<f8_fnuz_t, float>(float x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
union union
...@@ -296,32 +389,34 @@ inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x) ...@@ -296,32 +389,34 @@ inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0; constexpr uint32_t rng = 0;
return utils:: return utils::
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x, cast_to_f8<float, f8_fnuz_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
rng); x, rng);
#endif #endif
} }
// convert fp16 to fp8 with rounding to nearest even // convert fp16 to fp8 with rounding to nearest even
template <> template <>
inline __host__ __device__ f8_t f8_convert_rne<f8_t, half_t>(half_t x) inline __host__ __device__ f8_fnuz_t f8_convert_rne<f8_fnuz_t, half_t>(half_t x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
// convert to float and use native conversion // convert to float and use native conversion
return f8_convert_rne<f8_t>(type_convert<float>(x)); return f8_convert_rne<f8_fnuz_t>(type_convert<float>(x));
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0; constexpr uint32_t rng = 0;
return utils:: return utils::cast_to_f8<half_t,
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( f8_fnuz_t,
x, rng); negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif #endif
} }
// convert fp32 to bf8 with rounding to nearest even // convert fp32 to bf8 with rounding to nearest even
template <> template <>
inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x) inline __host__ __device__ bf8_fnuz_t f8_convert_rne<bf8_fnuz_t, float>(float x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
union union
...@@ -345,44 +440,59 @@ inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x) ...@@ -345,44 +440,59 @@ inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0; constexpr uint32_t rng = 0;
return utils:: return utils::cast_to_f8<float,
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( bf8_fnuz_t,
x, rng); negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif #endif
} }
// convert fp16 to bf8 with rounding to nearest even // convert fp16 to bf8 with rounding to nearest even
template <> template <>
inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, half_t>(half_t x) inline __host__ __device__ bf8_fnuz_t f8_convert_rne<bf8_fnuz_t, half_t>(half_t x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
// convert to float and use native conversion // convert to float and use native conversion
return f8_convert_rne<bf8_t>(type_convert<float>(x)); return f8_convert_rne<bf8_fnuz_t>(type_convert<float>(x));
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0; constexpr uint32_t rng = 0;
return utils:: return utils::cast_to_f8<half_t,
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( bf8_fnuz_t,
x, rng); negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif #endif
} }
// convert fp32 to fp8 // convert fp32 to fp8
template <> template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x) inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, float>(float x)
{ {
#if CK_USE_SR_F8_CONVERSION #if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<f8_t>(x); return f8_convert_sr<f8_fnuz_t>(x);
#else #else
return f8_convert_rne<f8_t>(x); return f8_convert_rne<f8_fnuz_t>(x);
#endif
}
// convert fp32 to fp8
template <>
inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, float>(float x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<f8_ocp_t>(x);
#else
return f8_convert_rne<f8_ocp_t>(x);
#endif #endif
} }
// convert fp8 to fp32 // convert fp8 to fp32
template <> template <>
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x) inline __host__ __device__ float type_convert<float, f8_fnuz_t>(f8_fnuz_t x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
float fval; float fval;
...@@ -392,30 +502,95 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x) ...@@ -392,30 +502,95 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
return fval; return fval;
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<f8_t, float, negative_zero_nan>(x); return utils::cast_from_f8<f8_fnuz_t, float, negative_zero_nan>(x);
#endif #endif
} }
template <> template <>
inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x) inline __host__ __device__ float2_t type_convert<float2_t, f8x2_fnuz_t>(f8x2_fnuz_t x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
const auto i16val = bit_cast<uint16_t>(x); const auto i16val = bit_cast<uint16_t>(x);
return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0); return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0);
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
const auto f8x2_v = vector_type<f8_t, 2>(x); const auto f8x2_v = vector_type<f8_fnuz_t, 2>(x);
vector_type<float, 2> f32x2_v; vector_type<float, 2> f32x2_v;
f32x2_v.template AsType<float>()(Number<0>{}) = f32x2_v.template AsType<float>()(Number<0>{}) =
utils::cast_from_f8<f8_t, float, negative_zero_nan>( utils::cast_from_f8<f8_fnuz_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_t>()[Number<0>{}]); f8x2_v.template AsType<f8_fnuz_t>()[Number<0>{}]);
f32x2_v.template AsType<float>()(Number<1>{}) = f32x2_v.template AsType<float>()(Number<1>{}) =
utils::cast_from_f8<f8_t, float, negative_zero_nan>( utils::cast_from_f8<f8_fnuz_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_t>()[Number<1>{}]); f8x2_v.template AsType<f8_fnuz_t>()[Number<1>{}]);
return f32x2_v.template AsType<float2_t>()[Number<0>{}]; return f32x2_v.template AsType<float2_t>()[Number<0>{}];
#endif #endif
} }
// Convert a pair of OCP fp8 values to a pair of fp32 values.
// When CK_OCP_FP8_CVT_FAST_PATH is enabled the pair is viewed as a single
// fp8_impl::fp8x2_storage_t and converted with one call to
// fp8_impl::cast_to_f32x2_from_f8x2; otherwise each element is converted
// separately with fp8_impl::cast_from_f8 using f8_ocp_t's wm/we parameters.
template <>
inline __host__ __device__ float2_t type_convert<float2_t, f8x2_ocp_t>(f8x2_ocp_t x)
{
#if CK_OCP_FP8_CVT_FAST_PATH
    const auto packed_pair = x.AsType<fp8_impl::fp8x2_storage_t>()[Number<0>{}];
    return fp8_impl::cast_to_f32x2_from_f8x2<f8_ocp_t::default_interpret>(packed_pair);
#else
    // NOTE(review): the trailing `false` template argument is presumably the
    // FNUZ selector of cast_from_f8 - confirm against fp8_impl.
    const auto first = fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(
        x.AsType<fp8_storage_t>()[Number<0>{}]);
    const auto second = fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(
        x.AsType<fp8_storage_t>()[Number<1>{}]);
    return float2_t{first, second};
#endif
}
// Unpack a pk_i4_t (one byte holding two 4-bit values) into two fp32 values.
// Each nibble is read as an unsigned value in [0, 15] and then offset by -8,
// so the results lie in [-8, 7].
// Element order of the result:
//   CK_USE_PK4_LAYOUT_SHUFFLE defined   -> {high nibble, low nibble}
//   CK_USE_PK4_LAYOUT_SHUFFLE undefined -> {low nibble, high nibble}
template <>
inline __host__ __device__ float2_t type_convert<float2_t, pk_i4_t>(pk_i4_t x)
{
    uint8_t x_u8 = ck::bit_cast<uint8_t>(x);

    float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
    float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;

#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
    float2_t res = {x_h, x_l};
#else
    // Must be a plain #else: an #elif without a controlling expression is
    // ill-formed and breaks the build whenever the macro is not defined.
    float2_t res = {x_l, x_h};
#endif
    return res;
}
// Unpack a pk_i4_t (one byte holding two 4-bit values) into two fp16 values
// without going through an integer-to-float conversion.
//
// Each nibble n is placed in the low mantissa bits of an fp16 lane whose
// remaining bits are 0x6400 (fp16 1024.0), giving the value 1024 + n.
// Adding 0xE408 (fp16 -1032.0 = -(1024 + 8)) then yields n - 8.
// (Assumes details::pk_add_f16 is a lane-wise fp16 addition - confirm.)
//
// Lane placement (lane 0 = bits 0..15, lane 1 = bits 16..31 of the word):
//   CK_USE_PK4_LAYOUT_SHUFFLE defined   -> lane 0 = high nibble, lane 1 = low nibble
//   CK_USE_PK4_LAYOUT_SHUFFLE undefined -> lane 0 = low nibble,  lane 1 = high nibble
template <>
inline __host__ __device__ half2_t type_convert<half2_t, pk_i4_t>(pk_i4_t x)
{
    const uint8_t packed = ck::bit_cast<uint8_t>(x);

#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
    const uint32_t nibbles = ((packed & 0xf0) >> 4) | ((packed & 0x0f) << 16);
#else
    const uint32_t nibbles = (packed & 0xf) | ((packed & 0xf0) << 12);
#endif

    const int exponent_bits = 0x64006400; // two lanes of fp16 1024.0
    const int neg_offset    = 0xE408E408; // two lanes of fp16 -1032.0, i.e. -(1024 + 8)

    const int biased = nibbles | exponent_bits;

    return details::pk_add_f16(bit_cast<half2_t>(biased), bit_cast<half2_t>(neg_offset));
}
// Unpack a pk_i4_t (one byte holding two 4-bit values) into two bf16 values.
// Each nibble is read as an unsigned value in [0, 15], offset by -8 in fp32,
// and then converted to bhalf_t with type_convert.
// Element order of the result:
//   CK_USE_PK4_LAYOUT_SHUFFLE defined   -> {high nibble, low nibble}
//   CK_USE_PK4_LAYOUT_SHUFFLE undefined -> {low nibble, high nibble}
template <>
inline __host__ __device__ bhalf2_t type_convert<bhalf2_t, pk_i4_t>(pk_i4_t x)
{
    const uint8_t packed = ck::bit_cast<uint8_t>(x);

    const float low_f32  = (packed & 0x0f) - 8.f;
    const float high_f32 = ((packed & 0xf0) >> 4) - 8.f;

    const bhalf_t low_bf16  = type_convert<bhalf_t>(low_f32);
    const bhalf_t high_bf16 = type_convert<bhalf_t>(high_f32);

#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
    return bhalf2_t{high_bf16, low_bf16};
#else
    return bhalf2_t{low_bf16, high_bf16};
#endif
}
template <> template <>
inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x) inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
{ {
...@@ -428,42 +603,64 @@ inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x) ...@@ -428,42 +603,64 @@ inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
// convert fp16 to fp8 // convert fp16 to fp8
template <> template <>
inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x) inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, half_t>(half_t x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<f8_fnuz_t>(x);
#else
return f8_convert_rne<f8_fnuz_t>(x);
#endif
}
// convert fp16 to fp8
template <>
inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, half_t>(half_t x)
{ {
#if CK_USE_SR_F8_CONVERSION #if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<f8_t>(x); return f8_convert_sr<f8_ocp_t>(x);
#else #else
return f8_convert_rne<f8_t>(x); return f8_convert_rne<f8_ocp_t>(x);
#endif #endif
} }
// convert fp8 to fp16 // convert fp8 to fp16
template <> template <>
inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x) inline __host__ __device__ half_t type_convert<half_t, f8_fnuz_t>(f8_fnuz_t x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
// use native conversion to float and convert to fp16 // use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x)); return type_convert<half_t>(type_convert<float>(x));
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x); return utils::cast_from_f8<f8_fnuz_t, half_t, negative_zero_nan>(x);
#endif
}
// convert fp32 to bf8
template <>
inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, float>(float x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<bf8_fnuz_t>(x);
#else
return f8_convert_rne<bf8_fnuz_t>(x);
#endif #endif
} }
// convert fp32 to bf8 // convert fp32 to bf8
template <> template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x) inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, float>(float x)
{ {
#if CK_USE_SR_F8_CONVERSION #if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<bf8_t>(x); return f8_convert_sr<bf8_ocp_t>(x);
#else #else
return f8_convert_rne<bf8_t>(x); return f8_convert_rne<bf8_ocp_t>(x);
#endif #endif
} }
// convert bf8 to fp32 // convert bf8 to fp32
template <> template <>
inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x) inline __host__ __device__ float type_convert<float, bf8_fnuz_t>(bf8_fnuz_t x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
float fval; float fval;
...@@ -473,107 +670,1336 @@ inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x) ...@@ -473,107 +670,1336 @@ inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
return fval; return fval;
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<bf8_t, float, negative_zero_nan>(x); return utils::cast_from_f8<bf8_fnuz_t, float, negative_zero_nan>(x);
#endif #endif
} }
// convert fp16 to bf8 // convert fp16 to bf8
template <> template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x) inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, half_t>(half_t x)
{ {
#if CK_USE_SR_F8_CONVERSION #if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<bf8_t>(x); return f8_convert_sr<bf8_fnuz_t>(x);
#else #else
return f8_convert_rne<bf8_t>(x); return f8_convert_rne<bf8_fnuz_t>(x);
#endif
}
// convert fp16 to bf8
template <>
inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, half_t>(half_t x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<bf8_ocp_t>(x);
#else
return f8_convert_rne<bf8_ocp_t>(x);
#endif #endif
} }
// convert bf8 to fp16 // convert bf8 to fp16
template <> template <>
inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x) inline __host__ __device__ half_t type_convert<half_t, bf8_fnuz_t>(bf8_fnuz_t x)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
// use native conversion to float and convert to fp16 // use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x)); return type_convert<half_t>(type_convert<float>(x));
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x); return utils::cast_from_f8<bf8_fnuz_t, half_t, negative_zero_nan>(x);
#endif #endif
} }
template <typename Y, typename X, std::size_t NumElems> // convert fp32 to fp4 with rounding to nearest even
inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y, inline __host__ __device__ f4_t f4_convert_rne(float x, float scale = 1.0f)
const std::array<X, NumElems>& x)
{ {
for(std::size_t i = 0; i < NumElems; i++) #if defined(__gfx950__)
union
{ {
y[i] = type_convert<Y>(x[i]); uint32_t bitwise;
} f4_t f4_array[4];
} value{0};
value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x, x, scale, 0);
return value.f4_array[0];
#else
return utils::sat_convert_to_type<f4_t>(x / scale);
#endif
} }
template <typename Y, typename X, index_t NumElems> // convert vector of 2 fp32 to vector of 2 fp4 with rne
inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array<X, NumElems>& x) inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f)
{ {
for(std::size_t i = 0; i < NumElems; i++) #if defined(__gfx950__)
union
{ {
y[i] = type_convert<Y>(x[i]); uint32_t bitwise;
} f4x2_t f4x2_array[4];
} value{0};
value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[0], x[1], scale, 0);
return value.f4x2_array[0];
#else
union
{
uint32_t bitwise;
f4x2_t f4x2_array[4];
} value{0};
uint8_t l = utils::sat_convert_to_type<f4_t>(x[1] / scale);
uint8_t h = utils::sat_convert_to_type<f4_t>(x[0] / scale);
value.bitwise = (h << 4) | l;
return value.f4x2_array[0];
#endif
} }
// Declare a template function for bf16 conversion using RTN // convert vector of 32 fp32 to vector of 32 fp4 with rne
template <typename Y, typename X> inline __host__ __device__ f4x32_t f4_convert_rne(float32_t x, float scale = 1.0f)
__host__ __device__ constexpr Y bf16_convert_rtn(X x); {
#if defined(__gfx950__)
union
{
__uint128_t bitwise;
f4x2_t f4x2_array[16];
f4x32_t f4x32_array;
} f4_values{}, tmp_values{};
// TODO: pack in a loop
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[0], x[1], scale, 0);
f4_values.f4x2_array[0] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[2], x[3], scale, 0);
f4_values.f4x2_array[1] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[4], x[5], scale, 0);
f4_values.f4x2_array[2] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[6], x[7], scale, 0);
f4_values.f4x2_array[3] = tmp_values.f4x2_array[0];
// Convert fp32 to bf16 with RTN if higher precision is needed tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[8], x[9], scale, 0);
f4_values.f4x2_array[4] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[10], x[11], scale, 0);
f4_values.f4x2_array[5] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[12], x[13], scale, 0);
f4_values.f4x2_array[6] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[14], x[15], scale, 0);
f4_values.f4x2_array[7] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[16], x[17], scale, 0);
f4_values.f4x2_array[8] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[18], x[19], scale, 0);
f4_values.f4x2_array[9] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[20], x[21], scale, 0);
f4_values.f4x2_array[10] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[22], x[23], scale, 0);
f4_values.f4x2_array[11] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[24], x[25], scale, 0);
f4_values.f4x2_array[12] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[26], x[27], scale, 0);
f4_values.f4x2_array[13] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[28], x[29], scale, 0);
f4_values.f4x2_array[14] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[30], x[31], scale, 0);
f4_values.f4x2_array[15] = tmp_values.f4x2_array[0];
return f4_values.f4x32_array;
#else
union
{
__uint128_t bitwise;
f4x2_t f4x2_array[16];
f4x32_t f4x32_array;
} f4_values{};
// TODO: pack in a loop
auto tmp = utils::sat_convert_to_type<f4_t>(x[0] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[1] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[2] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[3] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[4] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[5] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[6] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[7] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[8] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[9] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[10] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[11] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[12] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[13] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[14] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[15] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[16] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[17] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[18] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[19] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[20] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[21] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[22] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[23] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[24] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[25] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[26] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[27] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[28] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[29] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[30] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[31] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
return f4_values.f4x32_array;
#endif
}
// Convert one fp32 value to fp4 using stochastic rounding.
//   x     - value to convert
//   scale - scaling factor; the host/fallback path converts x / scale,
//           the gfx950 path hands scale to the hardware conversion builtin
// The random bits for rounding come from prand_generator, fed with the
// address of the local copy of x and its value.
inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f)
{
    constexpr int kSeed = 1254739;
    uint32_t random_bits = prand_generator<float, kSeed>(reinterpret_cast<uintptr_t>(&x), x);

#if defined(__gfx950__)
    // Result word of the packed conversion, viewed as individual f4_t slots.
    union
    {
        uint32_t bitwise;
        f4_t f4_array[4];
    } result{0};

    // The builtin takes a float2: x goes in the first element, the second
    // element is zero-initialized.
    union
    {
        float float_array[2];
        float2_t float2_array;
    } input_pair{{x}};

    // NOTE(review): the final argument 0 presumably selects the destination
    // byte of the result word - confirm against the builtin's documentation.
    result.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
        result.bitwise, input_pair.float2_array, random_bits, scale, 0);

    return result.f4_array[0];
#else
    return utils::sat_convert_to_type_sr<f4_t>(x / scale, random_bits);
#endif
}
// Convert a pair of fp32 values to a packed pair of fp4 values using
// stochastic rounding.
//   x     - the two values to convert
//   scale - scaling factor; the fallback path converts x[i] / scale,
//           the gfx950 path hands scale to the hardware conversion builtin
// A single random word (derived from the address of the local copy of x and
// from x[0]) is used for both elements.
inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
{
    constexpr int kSeed = 1254739;
    uint32_t random_bits = prand_generator<float, kSeed>(reinterpret_cast<uintptr_t>(&x), x[0]);

    // 32-bit result word, viewed as f4x2_t slots; slot 0 is what gets returned.
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } packed{0};

#if defined(__gfx950__)
    // NOTE(review): the final argument 0 presumably selects the destination
    // byte of the result word - confirm against the builtin's documentation.
    packed.bitwise =
        __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(packed.bitwise, x, random_bits, scale, 0);
#else
    // x[1] lands in the low nibble, x[0] in the high nibble.
    const uint8_t low_nibble  = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, random_bits);
    const uint8_t high_nibble = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, random_bits);
    packed.bitwise            = (high_nibble << 4) | low_nibble;
#endif

    return packed.f4x2_array[0];
}
// convert vector of 32 fp32 to vector of 32 fp4 with sr
inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f)
{
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
#if defined(__gfx950__)
union
{
__uint128_t bitwise;
f4x2_t f4x2_array[16];
f4x32_t f4x32_array;
} f4_values{0}, tmp_values{0};
union
{
float2_t floatx2_array[16];
float32_t floatx32_array;
} float_values{{0}};
// TODO: pack in a loop
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[0], rng, scale, 0);
f4_values.f4x2_array[0] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[1], rng, scale, 0);
f4_values.f4x2_array[1] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[2], rng, scale, 0);
f4_values.f4x2_array[2] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[3], rng, scale, 0);
f4_values.f4x2_array[3] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[4], rng, scale, 0);
f4_values.f4x2_array[4] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[5], rng, scale, 0);
f4_values.f4x2_array[5] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[6], rng, scale, 0);
f4_values.f4x2_array[6] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[7], rng, scale, 0);
f4_values.f4x2_array[7] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[8], rng, scale, 0);
f4_values.f4x2_array[8] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[9], rng, scale, 0);
f4_values.f4x2_array[9] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[10], rng, scale, 0);
f4_values.f4x2_array[10] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[11], rng, scale, 0);
f4_values.f4x2_array[11] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[12], rng, scale, 0);
f4_values.f4x2_array[12] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[13], rng, scale, 0);
f4_values.f4x2_array[13] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[14], rng, scale, 0);
f4_values.f4x2_array[14] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[15], rng, scale, 0);
f4_values.f4x2_array[15] = tmp_values.f4x2_array[0];
return f4_values.f4x32_array;
#else
union
{
__uint128_t bitwise;
f4x2_t f4x2_array[16];
f4x32_t f4x32_array;
} f4_values{0};
// TODO: pack in a loop
auto tmp = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[2] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[3] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[4] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[5] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[6] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[7] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[8] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[9] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[10] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[11] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[12] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[13] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[14] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[15] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[16] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[17] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[18] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[19] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[20] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[21] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[22] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[23] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[24] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[25] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[26] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[27] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[28] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[29] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[30] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[31] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
return f4_values.f4x32_array;
#endif
}
// Convert a single fp32 value to fp4.
// The rounding helper is selected by the preprocessor: f4_convert_sr is used when
// CK_USE_SR_F4_CONVERSION evaluates to non-zero, f4_convert_rne otherwise.
// Only `x` is forwarded, so any further helper parameters take their defaults.
template <>
inline __host__ __device__ f4_t type_convert<f4_t, float>(float x)
{
#if CK_USE_SR_F4_CONVERSION
    return f4_convert_sr(x);
#else
    return f4_convert_rne(x);
#endif
}
// Convert a pair of fp32 values (float2_t) to a pair of fp4 values (f4x2_t).
// f4_convert_rne is used unless CK_USE_SR_F4_CONVERSION evaluates to non-zero,
// in which case f4_convert_sr performs the conversion instead.
template <>
inline __host__ __device__ f4x2_t type_convert<f4x2_t, float2_t>(float2_t x)
{
#if !(CK_USE_SR_F4_CONVERSION)
    return f4_convert_rne(x);
#else
    return f4_convert_sr(x);
#endif
}
// Convert 32 fp32 values (float32_t) to 32 fp4 values (f4x32_t).
// f4_convert_rne is used unless CK_USE_SR_F4_CONVERSION evaluates to non-zero,
// in which case f4_convert_sr performs the conversion instead.
template <>
inline __host__ __device__ f4x32_t type_convert<f4x32_t, float32_t>(float32_t x)
{
#if !(CK_USE_SR_F4_CONVERSION)
    return f4_convert_rne(x);
#else
    return f4_convert_sr(x);
#endif
}
// Convert a single fp4 value to fp32 using a scale of 1.
template <>
inline __host__ __device__ float type_convert<float, f4_t>(f4_t x)
{
#if defined(__gfx950__)
    // The builtin yields two floats; the union lets us read back element 0 only.
    union
    {
        float float_array[2];
        float2_t float2_array;
    } float_values{};
    float scale               = 1.0f;
    float_values.float2_array = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, scale, 0);
    return float_values.float_array[0];
#else
    // Software path: Binary_1() is the e8m0 scale passed to the generic decoder.
    return utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif
}
// When the exponent bits are not all 1s, then the value is zero, normal, // convert vector of 2 fp4 to vector of 2 fp32
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus template <>
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). inline __host__ __device__ float2_t type_convert<float2_t, f4x2_t>(f4x2_t x)
// This causes the bfloat16's mantissa to be incremented by 1 if the 16 {
// least significant bits of the float mantissa are greater than 0x8000, #if defined(__gfx950__)
// or if they are equal to 0x8000 and the least significant bit of the union
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when {
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already uint32_t bitwise;
// has the value 0x7f, then incrementing it causes it to become 0x00 and f4x2_t f4x2_array[4];
// the exponent is incremented by one, which is the next higher FP value } value{};
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal value.f4x2_array[0] = x;
// with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up float scale = 1.0f;
// to a normal value with an exponent of 0x01 and a mantissa of 0x00. return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, scale, 0);
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, #else
// incrementing it causes it to become an exponent of 0xFF and a mantissa float2_t ret{
// of 0x00, which is Inf, the next higher value to the unrounded value. utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
bool flag0 = ~u.int32 & 0x7f800000; x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})),
utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
// When all of the exponent bits are 1, the value is Inf or NaN. x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}))};
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero return ret;
// mantissa bit. Quiet NaN is indicated by the most significant mantissa #endif
// bit being 1. Signaling NaN is indicated by the most significant }
// mantissa bit being 0 but some other bit(s) being 1. If any of the
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bfloat16's mantissa bits are all 0.
bool flag1 = !flag0 && (u.int32 & 0xffff);
u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN
// Convert a vector of 32 fp4 values to a vector of 32 fp32 values using a scale of 1.
// The input is viewed as 16 f4x2_t pairs; pair p produces output elements 2*p and 2*p+1.
template <>
inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
{
#if defined(__gfx950__)
    union
    {
        f4x32_t f4x32_array;
        f4x2_t fp4x2[16];
    } value{x};
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } bitwise_value{};
    float32_t ret;
    float scale = 1.0f;
    ck::static_for<0, 16, 1>{}([&](auto i) {
        const index_t pair = i;
        // Stage the current pair in the first f4x2_t slot of the 32-bit operand;
        // the remaining slots stay zero from the value-initialization above.
        bitwise_value.f4x2_array[0] = value.fp4x2[pair];
        const float2_t op =
            __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, scale, 0);
        ret[2 * pair]     = op[0];
        ret[2 * pair + 1] = op[1];
    });
    return ret;
#else
    union
    {
        float32_t float32_array;
        float float_array[32];
    } float_values{};
    union
    {
        __uint128_t bitwise;
        f4x2_t f4x2_array[16];
        f4x32_t f4x32_array;
    } f4_values{bit_cast<__uint128_t>(x)};
    // Every pair must land in its own output slots (2*p, 2*p+1). The previous
    // hand-unrolled code reused indices 0..7 for pairs 4..15, which left outputs
    // 8..31 at zero and overwrote outputs 0..7.
    ck::static_for<0, 16, 1>{}([&](auto i) {
        const index_t pair = i;
        float_values.float_array[2 * pair] = utils::to_float<f4_t>(
            NumericLimits<e8m0_bexp_t>::Binary_1(),
            f4_values.f4x2_array[pair].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(
                Number<0>{}));
        float_values.float_array[2 * pair + 1] = utils::to_float<f4_t>(
            NumericLimits<e8m0_bexp_t>::Binary_1(),
            f4_values.f4x2_array[pair].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(
                Number<1>{}));
    });
    return float_values.float32_array;
#endif
}
// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed /**
* @brief Converts a float to a 6-bit float type (f6_t) using round-to-nearest-even.
*
* Divides the input by the specified scale, then saturates and converts it
* to the 6-bit floating-point format (f6_t).
*
* @param x The input float value.
* @param scale A scaling factor applied to `x` before conversion.
* @return The converted f6_t value.
*/
inline __host__ __device__ f6_t f6_convert_rne(float x, float scale = 1.0f)
{
#if defined(__gfx950__)
float16_t in1{x};
float16_t in2{};
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} out{};
out.f6_vector = __builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(in1, in2, scale);
return out.f6_array[0];
#else
return utils::sat_convert_to_type<f6_t>(x / scale);
#endif
}
/**
 * @brief Converts a 32-element single-precision float array into a packed 6-bit representation.
 *
 * Each input float is divided by the provided scale value and converted with
 * rounding to nearest / even into 6 bits of precision. On gfx950 the input is
 * handed to the hardware builtin as two 16-float halves; elsewhere each element
 * goes through utils::sat_convert_to_type<f6_t>.
 *
 * @param x A vector of 32 floats stored in float32_t.
 * @param scale A scaling factor for each float before conversion.
 * @return An f6x32_t object storing the compressed 6-bit representation.
 */
inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
    static_assert(sizeof(float32_t) == 2 * sizeof(float16_t),
                  "float32_t must consist of exactly two float16_t halves");
    float16_t* in1 = reinterpret_cast<float16_t*>(&x);
    // Upper half starts one float16_t past the lower half. Offsetting `&x` itself
    // would step in units of whole float32_t objects and read out of bounds.
    float16_t* in2 = in1 + 1;
    return __builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(*in1, *in2, scale);
#else
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.f6_array[i] = utils::sat_convert_to_type<f6_t>(in.float_array[i] / scale);
    });
    return out.f6_vector;
#endif
}
/**
 * @brief Stochastic-rounding conversion of a single float to f6_t.
 *
 * The random bits come from prand_generator, fed with the address of the
 * by-value argument and its value. On gfx950 the input occupies element 0 of a
 * 32-float vector handed to the hardware builtin and element 0 of the result is
 * returned; otherwise `x / scale` goes through utils::sat_convert_to_type_sr.
 *
 * @param x The input float value.
 * @param scale A scaling factor applied to `x` before conversion.
 * @return The converted f6_t value.
 */
inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    uint32_t random_bits = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx950__)
    union
    {
        float32_t float_vector;
        float float_array[32];
    } source{x};
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } packed{};
    packed.f6_vector =
        __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(source.float_vector, random_bits, scale);
    return packed.f6_array[0];
#else
    return utils::sat_convert_to_type_sr<f6_t>(x / scale, random_bits);
#endif
}
/**
 * @brief Stochastic-rounding conversion of 32 floats into the packed 6-bit f6x32_t form.
 *
 * One random word is generated from the address of the by-value argument and
 * its first element; the same word is used for all 32 elements. On gfx950 the
 * whole vector is converted by the hardware builtin; otherwise each element is
 * divided by `scale` and passed to utils::sat_convert_to_type_sr<f6_t>.
 *
 * @param x A vector of 32 floats stored in float32_t.
 * @param scale A scaling factor for each float before conversion.
 * @return An f6x32_t object storing the compressed 6-bit representation.
 */
inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    // Element-wise view of the input; also supplies the value used to seed the generator.
    union
    {
        float32_t float_vector;
        float float_array[32];
    } source{x};
    uint32_t random_bits =
        prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), source.float_array[0]);
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, random_bits, scale);
#else
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } packed{};
    ck::static_for<0, 32, 1>{}([&](auto idx) {
        packed.f6_array[idx] =
            utils::sat_convert_to_type_sr<f6_t>(source.float_array[idx] / scale, random_bits);
    });
    return packed.f6_vector;
#endif
}
/**
 * @brief Specializes the type conversion template for converting a float into the 6-bit float type
 * (f6_t).
 *
 * When CK_USE_SR_F6_CONVERSION evaluates to non-zero the conversion uses
 * stochastic rounding (f6_convert_sr); otherwise it uses round-to-nearest-even
 * (f6_convert_rne). Only `x` is forwarded, so the helper's default scale applies.
 *
 * @param x Input float value to be converted.
 * @return The converted f6_t value.
 */
template <>
inline __host__ __device__ f6_t type_convert<f6_t, float>(float x)
{
#if CK_USE_SR_F6_CONVERSION
    return f6_convert_sr(x);
#else
    return f6_convert_rne(x);
#endif
}
/**
 * @brief Specializes the type conversion template for converting a vector of 32 floats into the
 * vector of 32 6-bit float types (f6x32_t).
 *
 * When CK_USE_SR_F6_CONVERSION evaluates to non-zero the conversion uses
 * stochastic rounding (f6_convert_sr); otherwise it uses round-to-nearest-even
 * (f6_convert_rne). Only `x` is forwarded, so the helper's default scale applies.
 *
 * @param x Input float vector to be converted.
 * @return The converted f6x32_t vector.
 */
template <>
inline __host__ __device__ f6x32_t type_convert<f6x32_t, float32_t>(float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
    return f6_convert_sr(x);
#else
    return f6_convert_rne(x);
#endif
}
/**
 * @brief Specializes the type conversion template for converting the 6-bit float type (f6_t) to
 * float.
 *
 * Interprets an f6_t value as a float using NumericLimits<e8m0_bexp_t>::Binary_1()
 * as the scale. On gfx950 the value is placed in element 0 of a 32-element
 * vector, converted by the hardware builtin, and element 0 of the result is
 * returned; otherwise utils::to_float<f6_t> is used.
 *
 * @param x The 6-bit float (f6_t) value to be converted.
 * @return The corresponding float representation.
 */
template <>
inline __host__ __device__ float type_convert<float, f6_t>(f6_t x)
{
#if defined(__gfx950__)
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
        in.f6_vector, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
    return out.float_array[0];
#else
    return utils::to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif
}
/**
 * @brief Converts a packed vector of 32 f6_t values (f6x32_t) to a vector of 32 floats.
 *
 * NumericLimits<e8m0_bexp_t>::Binary_1() is used as the scale. On gfx950 the
 * whole vector is converted by the hardware builtin; otherwise each element is
 * decoded with utils::to_float<f6_t>.
 *
 * @param x The vector of 32 6-bit float (f6x32_t) values to be converted.
 * @return The corresponding vector of 32 floats.
 */
template <>
inline __host__ __device__ float32_t type_convert<float32_t, f6x32_t>(f6x32_t x)
{
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
        x, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
#else
    // Element-wise views of the packed input and the float result.
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } packed{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } decoded{};
    ck::static_for<0, 32, 1>{}([&](auto idx) {
        decoded.float_array[idx] =
            utils::to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), packed.f6_array[idx]);
    });
    return decoded.float_vector;
#endif
}
/**
 * @brief Round-to-nearest-even conversion of a single float to bf6_t.
 *
 * On gfx950 the value is placed in element 0 of a 16-float vector, paired with
 * a zero-initialized second vector, converted by the hardware builtin, and
 * element 0 of the packed result is returned. On other targets `x / scale` is
 * passed to utils::sat_convert_to_type<bf6_t>.
 *
 * @param x The float value to be converted.
 * @param scale The scaling factor applied to the input before conversion.
 * @return The converted bf6_t value.
 */
inline __host__ __device__ bf6_t bf6_convert_rne(float x, float scale = 1.0f)
{
#if defined(__gfx950__)
    float16_t first_half{x};
    float16_t second_half{};
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } packed{};
    packed.bf6_vector =
        __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(first_half, second_half, scale);
    return packed.bf6_array[0];
#else
    return utils::sat_convert_to_type<bf6_t>(x / scale);
#endif
}
/**
 * @brief Converts a vector of 32 floats to the vector of 32 6-bit BF6 types using
 * round-to-nearest-even.
 *
 * Divides the input by the specified scale, then saturates and converts
 * it to a 6-bit BF6 floating-point format. On gfx950 the input is handed to the
 * hardware builtin as two 16-float halves; elsewhere each element goes through
 * utils::sat_convert_to_type<bf6_t>.
 *
 * @param x The float vector to be converted.
 * @param scale The scaling factor applied to the input before conversion.
 * @return The converted bf6x32_t vector.
 */
inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
    static_assert(sizeof(float32_t) == 2 * sizeof(float16_t),
                  "float32_t must consist of exactly two float16_t halves");
    float16_t* in1 = reinterpret_cast<float16_t*>(&x);
    // Upper half starts one float16_t past the lower half. Offsetting `&x` itself
    // would step in units of whole float32_t objects and read out of bounds.
    float16_t* in2 = in1 + 1;
    return __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(*in1, *in2, scale);
#else
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.bf6_array[i] = utils::sat_convert_to_type<bf6_t>(in.float_array[i] / scale);
    });
    return out.bf6_vector;
#endif
}
/**
 * @brief Stochastic-rounding conversion of a single float to bf6_t.
 *
 * The random bits come from prand_generator, fed with the address of the
 * by-value argument and its value. On gfx950 the input occupies element 0 of a
 * 32-float vector handed to the hardware builtin and element 0 of the result is
 * returned; otherwise `x / scale` goes through utils::sat_convert_to_type_sr.
 *
 * @param x The float value to be converted.
 * @param scale The scaling factor applied to the input before conversion.
 * @return The converted bf6_t value.
 */
inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    uint32_t random_bits = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx950__)
    union
    {
        float32_t float_vector;
        float float_array[32];
    } source{x};
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } packed{};
    packed.bf6_vector =
        __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(source.float_vector, random_bits, scale);
    return packed.bf6_array[0];
#else
    return utils::sat_convert_to_type_sr<bf6_t>(x / scale, random_bits);
#endif
}
/**
 * @brief Stochastic-rounding conversion of 32 floats into the packed bf6x32_t form.
 *
 * One random word is generated from the address of the by-value argument and
 * its first element; the same word is used for all 32 elements. On gfx950 the
 * whole vector is converted by the hardware builtin; otherwise each element is
 * divided by `scale` and passed to utils::sat_convert_to_type_sr<bf6_t>.
 *
 * @param x The float vector to be converted.
 * @param scale The scaling factor applied to the input before conversion.
 * @return The converted bf6x32_t vector.
 */
inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    // Element-wise view of the input; also supplies the value used to seed the generator.
    union
    {
        float32_t float_vector;
        float float_array[32];
    } source{x};
    uint32_t random_bits =
        prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), source.float_array[0]);
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, random_bits, scale);
#else
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } packed{};
    ck::static_for<0, 32, 1>{}([&](auto idx) {
        packed.bf6_array[idx] =
            utils::sat_convert_to_type_sr<bf6_t>(source.float_array[idx] / scale, random_bits);
    });
    return packed.bf6_vector;
#endif
}
/**
 * @brief float -> bf6_t specialization of type_convert.
 *
 * Round-to-nearest-even (bf6_convert_rne) is used unless
 * CK_USE_SR_F6_CONVERSION evaluates to non-zero, in which case stochastic
 * rounding (bf6_convert_sr) is used. Only `x` is forwarded, so the helper's
 * default scale applies.
 *
 * @param x Input float value to convert.
 * @return Converted bf6_t value.
 */
template <>
inline __host__ __device__ bf6_t type_convert<bf6_t, float>(float x)
{
#if !(CK_USE_SR_F6_CONVERSION)
    return bf6_convert_rne(x);
#else
    return bf6_convert_sr(x);
#endif
}
/**
 * @brief float32_t -> bf6x32_t specialization of type_convert (32 elements at once).
 *
 * Round-to-nearest-even (bf6_convert_rne) is used unless
 * CK_USE_SR_F6_CONVERSION evaluates to non-zero, in which case stochastic
 * rounding (bf6_convert_sr) is used. Only `x` is forwarded, so the helper's
 * default scale applies.
 *
 * @param x Input float vector to convert.
 * @return Converted bf6x32_t vector.
 */
template <>
inline __host__ __device__ bf6x32_t type_convert<bf6x32_t, float32_t>(float32_t x)
{
#if !(CK_USE_SR_F6_CONVERSION)
    return bf6_convert_rne(x);
#else
    return bf6_convert_sr(x);
#endif
}
/**
 * @brief bf6_t -> float specialization of type_convert.
 *
 * NumericLimits<e8m0_bexp_t>::Binary_1() is used as the scale. On gfx950 the
 * value is placed in element 0 of a 32-element vector, converted by the
 * hardware builtin, and element 0 of the result is returned; otherwise
 * utils::to_float<bf6_t> is used.
 *
 * @param x The bf6_t value to convert.
 * @return The float representation of the given bf6_t value.
 */
template <>
inline __host__ __device__ float type_convert<float, bf6_t>(bf6_t x)
{
#if defined(__gfx950__)
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } packed{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } decoded{};
    decoded.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
        packed.bf6_vector, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
    return decoded.float_array[0];
#else
    return utils::to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif
}
/**
 * @brief bf6x32_t -> float32_t specialization of type_convert (32 elements at once).
 *
 * NumericLimits<e8m0_bexp_t>::Binary_1() is used as the scale. On gfx950 the
 * whole vector is converted by the hardware builtin; otherwise each element is
 * decoded with utils::to_float<bf6_t>.
 *
 * @param x The bf6x32_t value to convert.
 * @return The float representation of the given vector.
 */
template <>
inline __host__ __device__ float32_t type_convert<float32_t, bf6x32_t>(bf6x32_t x)
{
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
        x, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
#else
    // Element-wise views of the packed input and the float result.
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } packed{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } decoded{};
    ck::static_for<0, 32, 1>{}([&](auto idx) {
        decoded.float_array[idx] =
            utils::to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), packed.bf6_array[idx]);
    });
    return decoded.float_vector;
#endif
}
#ifndef CK_CODE_GEN_RTC
/**
 * @brief Element-wise conversion between two std::array objects of equal length.
 *
 * Each element of `x` is converted with type_convert<Y> and stored at the same
 * index in `y`. Compiled only when CK_CODE_GEN_RTC is not defined.
 *
 * @param y Destination array, overwritten element by element.
 * @param x Source array.
 */
template <typename Y, typename X, size_t NumElems>
inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
                                              const std::array<X, NumElems>& x)
{
    size_t idx = 0;
    while(idx < NumElems)
    {
        y[idx] = type_convert<Y>(x[idx]);
        ++idx;
    }
}
#endif
/**
 * @brief Element-wise conversion between two Array objects of equal length.
 *
 * Each element of `x` is converted with type_convert<Y> and stored at the same
 * index in `y`.
 *
 * @param y Destination array, overwritten element by element.
 * @param x Source array.
 */
template <typename Y, typename X, index_t NumElems>
inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array<X, NumElems>& x)
{
    // The loop index uses the same type as NumElems so the bound check does not
    // mix signed and unsigned operands.
    for(index_t i = 0; i < NumElems; i++)
    {
        y[i] = type_convert<Y>(x[i]);
    }
}
} // namespace ck } // namespace ck
# ck_tile [Back to the main page](../../README.md)
# Composable Kernel Tile
## concept
`ck_tile` provides a programming model with templated abstractions to enable users to implement performance-critical kernels for machine learning workloads. It introduces the following basic concepts to help users build their own operators:
- tensor coordinate transformation: this is the core concept of the layout/index transform abstraction at both compile time and run time.
...@@ -44,5 +45,8 @@ our implementation of different device operators. ...@@ -44,5 +45,8 @@ our implementation of different device operators.
**[ops/epilogue]**
epilogue part of our kernel. We may extend this epilogue part to let users build their own customized epilogues.
**[ref]**
reference implementation of cpu or gpu. This folder is supposed to include a specific header on demand.
## examples
Currently we put all ck_tile related examples under the [/example/ck_tile](/example/ck_tile/) folder. Please check each example's subfolder.
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include "ck_tile/core/algorithm/coordinate_transform.hpp" #include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/indexing_adaptor.hpp" #include "ck_tile/core/algorithm/indexing_adaptor.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp" #include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/algorithm/static_encoding_pattern.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp" #include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp" #include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
...@@ -26,6 +27,7 @@ ...@@ -26,6 +27,7 @@
#include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/int8.hpp" #include "ck_tile/core/numeric/int8.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/numeric/math.hpp"
...@@ -52,6 +54,8 @@ ...@@ -52,6 +54,8 @@
#include "ck_tile/core/tensor/tile_elementwise.hpp" #include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/tensor/tile_window.hpp" #include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp" #include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/transpose_tile.hpp"
#include "ck_tile/core/tensor/update_tile.hpp" #include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/functional.hpp"
...@@ -62,6 +66,7 @@ ...@@ -62,6 +66,7 @@
#include "ck_tile/core/utility/philox_rand.hpp" #include "ck_tile/core/utility/philox_rand.hpp"
#include "ck_tile/core/utility/random.hpp" #include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/utility/reduce_operator.hpp" #include "ck_tile/core/utility/reduce_operator.hpp"
#include "ck_tile/core/utility/static_counter.hpp"
#include "ck_tile/core/utility/to_sequence.hpp" #include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp" #include "ck_tile/core/utility/transpose_vectors.hpp"
#include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/core/utility/type_traits.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
namespace ck_tile {
/**
 * @brief Enumeration describing static tile distribution patterns.
 *
 * Used as the last template argument of TileDistributionEncodingPattern2D. The
 * specializations in this file all split the Y (outer) tile dimension into three
 * factors <Y0, Y1, Y2>; the enumerators differ in which factor holds the warps,
 * which holds the rows covered by one wavefront, and which holds the iteration
 * count.
 */
enum struct tile_distribution_pattern
{
/**
 * @brief Thread raked pattern.
 *
 * Y0 = number of warps, Y1 = rows covered by one wavefront in one iteration,
 * Y2 = number of iterations (innermost Y factor).
 */
thread_raked,
/**
 * @brief Warp raked pattern.
 *
 * Y0 = number of warps, Y1 = number of iterations (middle Y factor),
 * Y2 = rows covered by one wavefront.
 */
warp_raked,
/**
 * @brief Block raked pattern - aka linear.
 *
 * Y0 = number of iterations (outermost Y factor), Y1 = number of warps,
 * Y2 = rows covered by one wavefront.
 */
block_raked,
};
// Empty base type with no members. Every TileDistributionEncodingPattern2D
// template (primary and partial specializations) in this file derives from it.
struct TileDistributionEncodingPattern
{
};
/**
 * @brief Class creating 2D static tile distribution with different load/store patterns.
 *
 * @note We always assume that Tile is YPerTile x XPerTile where X dim (rightmost)
 *       is contiguous and we can do vector load on this dimension.
 *
 * @note The primary template below is intentionally empty. The factory functions
 *       (Make2DStaticTileDistribution / MakeShuffled2DStaticTileDistribution) are
 *       only provided by the partial specializations, one per
 *       tile_distribution_pattern enumerator.
 *
 * @tparam BlockSize            Number of threads in a workgroup.
 * @tparam YPerTile             The tile size of outer/leftmost dimension.
 * @tparam XPerTile             The tile size of inner/rightmost dimension (contiguous).
 * @tparam VecSize              The vector access size.
 * @tparam DistributionPattern  The enumeration describing used access pattern.
 */
template <index_t BlockSize,
index_t YPerTile,
index_t XPerTile,
index_t VecSize,
tile_distribution_pattern DistributionPattern>
struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern
{
};
// Thread raked
//
// Y is factored as <Y0, Y1, Y2> = <warps, rows per wavefront pass, iterations>,
// X is factored as <X0, X1>     = <threads along X, vector width>.
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
struct TileDistributionEncodingPattern2D<BlockSize,
                                         YPerTile,
                                         XPerTile,
                                         VecSize,
                                         tile_distribution_pattern::thread_raked>
    : public TileDistributionEncodingPattern
{
    // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
    static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");

    static constexpr index_t warp_size = get_warp_size();
    static constexpr index_t num_warps = BlockSize / get_warp_size();

    static constexpr index_t X1 = VecSize;       // elements per vector access
    static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim

    // # of rows in Y dim accessed by single wavefront in one iteration
    static constexpr index_t Y1 = warp_size / X0;
    static_assert(X0 * Y1 == warp_size, "X0 * Y1 must cover whole wavefront!");

    static constexpr index_t Y0 = num_warps;
    // YPerWarp = YPerTile / Y0;
    // Y2 = YPerWarp / Y1;
    static constexpr index_t Y2 = YPerTile / (Y1 * Y0); // # of iters within wavefront

    // The message names the factors that are actually multiplied in the condition.
    static_assert(X0 * Y1 * Y0 == BlockSize, "X0 * Y1 * Y0 must cover whole workgroup!");
    static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");

    // Builds the distribution for a tile whose dimensions are ordered (Y, X).
    CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
    {
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<0>, sequence<1, 0>>,
                                       sequence<1, 2>,
                                       sequence<2, 1>>{});
    }

    // Same factors, but with the two hidden dimensions listed as (X, Y); the
    // index sequences are remapped so they still select Y0 / (Y1, X0) / (X1, Y2).
    CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
    {
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
                                       tuple<sequence<2>, sequence<2, 1>>,
                                       tuple<sequence<0>, sequence<1, 0>>,
                                       sequence<1, 2>,
                                       sequence<1, 2>>{});
    }
};
// Warp raked
//
// Y is factored as <Y0, Y1, Y2> = <warps, iterations, rows per wavefront pass>,
// X is factored as <X0, X1>     = <threads along X, vector width>.
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
struct TileDistributionEncodingPattern2D<BlockSize,
                                         YPerTile,
                                         XPerTile,
                                         VecSize,
                                         tile_distribution_pattern::warp_raked>
    : public TileDistributionEncodingPattern
{
    static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");

    static constexpr index_t warp_size = get_warp_size();
    static constexpr index_t num_warps = BlockSize / get_warp_size();

    static constexpr index_t X1 = VecSize;        // elements per vector access
    static constexpr index_t X0 = XPerTile / X1;  // # of threads in X dim
    static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
    static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");

    static constexpr index_t Y0 = num_warps;
    // The message names the factors that are actually multiplied in the condition.
    static_assert(X0 * Y2 * Y0 == BlockSize, "X0 * Y2 * Y0 must cover whole workgroup!");

    static constexpr index_t Y1 = YPerTile / (Y2 * Y0); // # of iters within wavefront
    static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");

    // Builds the distribution for a tile whose dimensions are ordered (Y, X).
    CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
    {
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<0>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<1, 1>>{});
    }

    // Same factors, but with the two hidden dimensions listed as (X, Y); the
    // index sequences are remapped so they still select Y0 / (Y2, X0) / (X1, Y1).
    CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
    {
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
                                       tuple<sequence<2>, sequence<2, 1>>,
                                       tuple<sequence<0>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<1, 1>>{});
    }
};
// Block raked
//
// Y is factored as <Y0, Y1, Y2> = <iterations, warps, rows per wavefront pass>,
// X is factored as <X0, X1>     = <threads along X, vector width>.
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
struct TileDistributionEncodingPattern2D<BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern::block_raked>
: public TileDistributionEncodingPattern
{
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
static constexpr index_t warp_size = get_warp_size();
static constexpr index_t num_warps = BlockSize / get_warp_size();
// X1: elements per vector access.
static constexpr index_t X1 = VecSize;
static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
// Y1: one slot per warp of the workgroup.
static constexpr index_t Y1 = num_warps;
static_assert(X0 * Y2 * Y1 == BlockSize, "X0 * Y2 * Y1 must cover whole workgroup!");
static constexpr index_t Y0 = YPerTile / (Y2 * Y1); // # of iters
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
// Builds the distribution for a tile whose hidden dimensions are listed as (Y, X).
// NOTE(review): the four index sequences presumably follow the usual
// tile_distribution_encoding order (P major, P minor, Y major, Y minor) - confirm
// against tile_distribution_encoding.hpp.
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
// Same factors, but with the two hidden dimensions listed as (X, Y); the
// major indices are swapped (1 <-> 2) relative to the non-shuffled variant.
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
tuple<sequence<2>, sequence<2, 1>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 0>>{});
}
};
} // namespace ck_tile
...@@ -621,6 +621,65 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0) ...@@ -621,6 +621,65 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
} }
// Emits `s_waitcnt lgkmcnt(cnt)` with a "memory" clobber, so the compiler does not
// move memory accesses across the wait.
// NOTE(review): `cnt` is bound with the "n" (immediate) constraint, so it
// presumably has to fold to a compile-time constant at every call site - confirm.
CK_TILE_DEVICE void lds_load_fence(index_t cnt = 0)
{
asm volatile("s_waitcnt lgkmcnt(%0)" : : "n"(cnt) : "memory");
}
// Predicated atomic-add functor. Only declared here; behaviour is supplied by
// specializations for particular (scalar_type, N) pairs.
template <typename scalar_type, index_t N, bool pre_nop = false>
struct buffer_atomic_add_if;
// Specialization for two packed bf16 values (one 32-bit payload).
// `pre_nop` is not referenced by this specialization.
template <bool pre_nop>
struct buffer_atomic_add_if<bf16_t, 2, pre_nop>
{
// Issues `global_atomic_pk_add_bf16` only on lanes whose `flag` satisfies 1 <= flag.
//   value    - 4-byte payload (enforced by the static_assert), reinterpreted as float
//              so it can be bound to a "v" register operand.
//   res      - buffer resource; only its first two components (res.xy) are passed to
//              the instruction. NOTE(review): presumably the 64-bit base address - confirm.
//   v_offset - per-lane offset operand.
//   s_offset - unused (parameter is unnamed).
//   i_offset - bound with the "n" (immediate) constraint and emitted as `offset:`.
//   flag     - per-lane predicate, defaults to 1.
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
static_assert(sizeof(T) == 4);
// Remember the current exec mask; it is written back by the last asm line.
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
// 1) v_cmpx_le_u32 rewrites exec from the comparison (1 <= flag),
// 2) the atomic add runs under that narrowed mask,
// 3) s_mov_b64 restores exec from save_exec.
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"global_atomic_pk_add_bf16 %0, %1, %2 offset:%3\n"
"s_mov_b64 exec %5"
:
: "v"(v_offset),
"v"(bit_cast<mbuf_t>(value)),
"s"(res.xy),
"n"(i_offset),
"v"(flag),
"s"(save_exec)
: "memory");
}
};
// Unpredicated atomic-add functor. Only declared here; behaviour is supplied by
// specializations for particular (scalar_type, N) pairs.
template <typename scalar_type, index_t N, bool pre_nop = false>
struct buffer_atomic_add;
// Specialization for two packed bf16 values (one 32-bit payload).
// `pre_nop` is not referenced by this specialization.
template <bool pre_nop>
struct buffer_atomic_add<bf16_t, 2, pre_nop>
{
// Issues `global_atomic_pk_add_bf16` without touching the exec mask.
//   value    - 4-byte payload (enforced by the static_assert), reinterpreted as float
//              so it can be bound to a "v" register operand.
//   res      - buffer resource; only res.xy is passed to the instruction.
//   v_offset - per-lane offset operand.
//   s_offset - unused (parameter is unnamed).
//   i_offset - bound with the "n" (immediate) constraint and emitted as `offset:`.
//   last arg - unused flag, kept so the call shape matches buffer_atomic_add_if.
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag = 1*/)
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
asm volatile("global_atomic_pk_add_bf16 %0, %1, %2 offset:%3"
:
: "v"(v_offset), "v"(bit_cast<mbuf_t>(value)), "s"(res.xy), "n"(i_offset)
: "memory");
}
};
namespace impl { namespace impl {
// below type indicate the data type used for buffer load inline asm // below type indicate the data type used for buffer load inline asm
// clang-format off // clang-format off
...@@ -810,6 +869,11 @@ CK_TILE_DEVICE void buffer_store_fence(index_t cnt = 0) ...@@ -810,6 +869,11 @@ CK_TILE_DEVICE void buffer_store_fence(index_t cnt = 0)
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
} }
// Emits `s_waitcnt vmcnt(cnt)` with a "memory" clobber.
// The `auto` return type deduces to void because the body returns nothing.
// NOTE(review): `cnt` is bound with the "n" (immediate) constraint, so it
// presumably has to fold to a compile-time constant at every call site - confirm.
CK_TILE_DEVICE auto async_load_fence_raw(index_t cnt = 0)
{
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
// buffer load i8 // buffer load i8
CK_TILE_DEVICE_EXTERN int8_t CK_TILE_DEVICE_EXTERN int8_t
llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc, llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
...@@ -1239,8 +1303,8 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe ...@@ -1239,8 +1303,8 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
static_assert( static_assert(
(std::is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (std::is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same<T, fp16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, bf16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same<T, bf16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, int32_t>::value && (std::is_same<T, int32_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
...@@ -2378,6 +2442,45 @@ CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer<T, N>& src_thread_ ...@@ -2378,6 +2442,45 @@ CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer<T, N>& src_thread_
#endif #endif
} }
// Raw (inline-asm backed) atomic add of a thread buffer into a wave buffer.
//
// Element offsets and the element-space size are scaled by sizeof(T) to bytes, a
// wave buffer resource is built for p_dst_wave, and the work is forwarded to
//   - buffer_atomic_add_if (predicated on dst_thread_element_valid) when
//     oob_conditional_check is true, or
//   - buffer_atomic_add (flag fixed to 1) otherwise.
// `coherence` is accepted for signature parity but is not used in this body.
template <typename T,
          index_t N,
          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
          bool oob_conditional_check          = true,
          bool pre_nop                        = false>
CK_TILE_DEVICE void amd_buffer_atomic_add_raw(const thread_buffer<T, N>& src_thread_data,
                                              T* p_dst_wave,
                                              const index_t dst_thread_element_offset,
                                              const index_t dst_linear_element_offset,
                                              const bool dst_thread_element_valid,
                                              const index_t dst_element_space_size,
                                              bool_constant<pre_nop> = {})
{
    const int32x4_t resource =
        make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));

    // Byte offsets derived from the element offsets.
    const index_t thread_byte_offset = dst_thread_element_offset * sizeof(T);
    const index_t linear_byte_offset = dst_linear_element_offset * sizeof(T);

    if constexpr(!oob_conditional_check)
    {
        // No validity predicate requested: always perform the add.
        buffer_atomic_add<T, N, pre_nop>{}(
            src_thread_data, resource, thread_byte_offset, 0, linear_byte_offset, 1);
    }
    else
    {
        // Lanes with an invalid element are masked off inside the functor.
        buffer_atomic_add_if<T, N, pre_nop>{}(src_thread_data,
                                              resource,
                                              thread_byte_offset,
                                              0,
                                              linear_byte_offset,
                                              dst_thread_element_valid);
    }
}
// buffer_atomic_max requires: // buffer_atomic_max requires:
// 1) p_dst_wave must point to global memory // 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer. // 2) p_dst_wave must be a wavewise pointer.
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -12,18 +12,37 @@ ...@@ -12,18 +12,37 @@
namespace ck_tile { namespace ck_tile {
enum struct address_space_enum template <typename, bool>
struct safe_underlying_type;
// Selected when the bool argument is true (T is an enumeration):
// exposes the enumeration's underlying integer type.
template <typename T>
struct safe_underlying_type<T, true>
{
    using type = typename std::underlying_type<T>::type;
};

// Selected when the bool argument is false (T is not an enumeration):
// std::underlying_type is never instantiated and `type` is void.
template <typename T>
struct safe_underlying_type<T, false>
{
    using type = void;
};

// Convenience alias that picks the specialization from std::is_enum<T>.
template <typename T>
using safe_underlying_type_t = typename safe_underlying_type<T, std::is_enum<T>::value>::type;
enum struct address_space_enum : std::uint16_t
{ {
generic, generic = 0,
global, global,
lds, lds,
sgpr, sgpr,
vgpr, constant,
vgpr
}; };
enum struct memory_operation_enum enum struct memory_operation_enum : std::uint16_t
{ {
set, set = 0,
atomic_add, atomic_add,
atomic_max, atomic_max,
add add
...@@ -73,6 +92,24 @@ CK_TILE_DEVICE void block_sync_lds() ...@@ -73,6 +92,24 @@ CK_TILE_DEVICE void block_sync_lds()
#endif #endif
} }
// Waits for outstanding loads to drop to `cnt`, then executes a barrier.
//   - when __gfx12__ is defined: s_wait_loadcnt cnt; s_barrier_signal -1; s_barrier_wait -1
//   - otherwise:                 s_waitcnt vmcnt(cnt); s_barrier
// Both variants carry a "memory" clobber.
// NOTE(review): `cnt` is bound with the "n" (immediate) constraint, so it
// presumably has to fold to a compile-time constant at every call site - confirm.
CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
{
#ifdef __gfx12__
asm volatile("s_wait_loadcnt %0 \n"
"s_barrier_signal -1 \n"
"s_barrier_wait -1"
:
: "n"(cnt)
: "memory");
#else
asm volatile("s_waitcnt vmcnt(%0) \n"
"s_barrier"
:
: "n"(cnt)
: "memory");
#endif
}
CK_TILE_DEVICE void block_sync_lds_direct_load() CK_TILE_DEVICE void block_sync_lds_direct_load()
{ {
asm volatile("\ asm volatile("\
...@@ -91,4 +128,30 @@ CK_TILE_DEVICE void s_nop(index_t cnt = 0) ...@@ -91,4 +128,30 @@ CK_TILE_DEVICE void s_nop(index_t cnt = 0)
#endif #endif
} }
// Pointer qualifier: expands to an address_space attribute whose numeric argument
// is address_space_enum::constant converted to the enum's underlying integer type.
// The cast helpers in this file refer to that space as "Constant" address space (4).
#define CK_CONSTANT_ADDRESS_SPACE \
__attribute__((address_space( \
static_cast<safe_underlying_type_t<address_space_enum>>(address_space_enum::constant))))
// Converts a pointer qualified with CK_CONSTANT_ADDRESS_SPACE into a plain T*.
// Device-only.
template <typename T>
__device__ T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* p)
{
// Cast a pointer in the "Constant" address space (4) to the "Generic" address space (0).
// Only a C-style pointer cast seems to compile here, so the old-style-cast
// diagnostic is silenced just around the cast and restored afterwards.
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
return (T*)(p); // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}
// Converts a plain T* into a pointer qualified with CK_CONSTANT_ADDRESS_SPACE.
// Unlike cast_pointer_to_generic_address_space, this one is callable from both
// host and device code.
template <typename T>
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p)
{
// Cast a pointer in the "Generic" address space (0) to the "Constant" address space (4).
// Only a C-style pointer cast seems to compile here, so the old-style-cast
// diagnostic is silenced just around the cast and restored afterwards.
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
return (T CK_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}
} // namespace ck_tile } // namespace ck_tile
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck_tile/core/numeric/vector_type.hpp" #include "ck_tile/core/numeric/vector_type.hpp"
...@@ -8,16 +8,75 @@ ...@@ -8,16 +8,75 @@
namespace ck_tile { namespace ck_tile {
CK_TILE_HOST_DEVICE bf16_t add_bf16_t(const bf16_t& a, const bf16_t& b) template <typename T, typename ComputeType>
CK_TILE_HOST_DEVICE T add(const T& a, const T& b)
{ {
return type_convert<bf16_t>(type_convert<float>(a) + type_convert<float>(b)); return type_convert<T>(type_convert<ComputeType>(a) + type_convert<ComputeType>(b));
} }
CK_TILE_HOST_DEVICE bf16x2_t add_bf16x2_t(const bf16x2_t& a, const bf16x2_t& b) CK_TILE_HOST_DEVICE bf16x2_t add_bf16x2_t(const bf16x2_t& a, const bf16x2_t& b)
{ {
bf16x2_t rtn; bf16x2_t rtn;
rtn[0] = add_bf16_t(a[0], b[0]); rtn[0] = add<bf16_t, float>(a[0], b[0]);
rtn[1] = add_bf16_t(a[1], b[1]); rtn[1] = add<bf16_t, float>(a[1], b[1]);
return rtn;
}
// Lane-wise sum of two 4-wide bf16 vectors. Each lane is added through
// add<bf16_t, float>, i.e. converted to float, summed, and converted back.
CK_TILE_HOST_DEVICE bf16x4_t add_bf16x4_t(const bf16x4_t& a, const bf16x4_t& b)
{
    bf16x4_t sum;
    for(int lane = 0; lane < 4; ++lane)
    {
        sum[lane] = add<bf16_t, float>(a[lane], b[lane]);
    }
    return sum;
}
// Lane-wise sum of two 4-wide fp8 vectors. Each lane is added through
// add<fp8_t, float>, i.e. converted to float, summed, and converted back.
CK_TILE_HOST_DEVICE fp8x4_t add_fp8x4_t(const fp8x4_t& a, const fp8x4_t& b)
{
    fp8x4_t sum;
    for(int lane = 0; lane < 4; ++lane)
    {
        sum[lane] = add<fp8_t, float>(a[lane], b[lane]);
    }
    return sum;
}
// Lane-wise sum of two 8-wide fp8 vectors. Each lane is added through
// add<fp8_t, float>, i.e. converted to float, summed, and converted back.
CK_TILE_HOST_DEVICE fp8x8_t add_fp8x8_t(const fp8x8_t& a, const fp8x8_t& b)
{
    fp8x8_t sum;
    for(int lane = 0; lane < 8; ++lane)
    {
        sum[lane] = add<fp8_t, float>(a[lane], b[lane]);
    }
    return sum;
}
// Lane-wise sum of two 4-wide bf8 vectors. Each lane is added through
// add<bf8_t, float>, i.e. converted to float, summed, and converted back.
CK_TILE_HOST_DEVICE bf8x4_t add_bf8x4_t(const bf8x4_t& a, const bf8x4_t& b)
{
    bf8x4_t sum;
    for(int lane = 0; lane < 4; ++lane)
    {
        sum[lane] = add<bf8_t, float>(a[lane], b[lane]);
    }
    return sum;
}
CK_TILE_HOST_DEVICE bf8x8_t add_bf8x8_t(const bf8x8_t& a, const bf8x8_t& b)
{
bf8x8_t rtn;
rtn[0] = add<bf8_t, float>(a[0], b[0]);
rtn[1] = add<bf8_t, float>(a[1], b[1]);
rtn[2] = add<bf8_t, float>(a[2], b[2]);
rtn[3] = add<bf8_t, float>(a[3], b[3]);
rtn[4] = add<bf8_t, float>(a[4], b[4]);
rtn[5] = add<bf8_t, float>(a[5], b[5]);
rtn[6] = add<bf8_t, float>(a[6], b[6]);
rtn[7] = add<bf8_t, float>(a[7], b[7]);
return rtn; return rtn;
} }
...@@ -59,6 +118,192 @@ CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x) ...@@ -59,6 +118,192 @@ CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x)
} while(cur_v.u32 != old_v); } while(cur_v.u32 != old_v);
} }
// Atomic lane-wise add of a bf16x4_t into *p_dst, built on a 64-bit compare-and-swap.
// The loop recomputes the sum from the freshly observed word whenever the CAS
// reports that memory no longer held the expected value.
template <>
CK_TILE_DEVICE void atomic_add<bf16x4_t>(bf16x4_t* p_dst, bf16x4_t const& x)
{
    // The destination pointer, viewed as a 64-bit word pointer or a vector pointer.
    union DstPtrView
    {
        uint64_t* as_word;
        bf16x4_t* as_vec;
    };

    // One 64-bit payload, viewed as raw bits or as four bf16 lanes.
    union PayloadView
    {
        uint64_t bits;
        bf16x4_t lanes;
    };

    DstPtrView dst;
    dst.as_vec = p_dst;

    // Plain (non-atomic) read that seeds the first CAS attempt.
    PayloadView observed;
    observed.bits = *dst.as_word;

    PayloadView desired;
    for(;;)
    {
        const uint64_t expected = observed.bits;

        desired.lanes = add_bf16x4_t(observed.lanes, x);

        // atomicCAS returns the word that was in memory before the attempt.
        observed.bits = atomicCAS(dst.as_word, expected, desired.bits);
        if(observed.bits == expected)
        {
            break;
        }
    }
}
// Atomic lane-wise add of an fp8x4_t into *p_dst, built on a 32-bit compare-and-swap.
// The loop recomputes the sum from the freshly observed word whenever the CAS
// reports that memory no longer held the expected value.
template <>
CK_TILE_DEVICE void atomic_add<fp8x4_t>(fp8x4_t* p_dst, const fp8x4_t& x)
{
    // The destination pointer, viewed as a 32-bit word pointer or a vector pointer.
    union DstPtrView
    {
        uint32_t* as_word;
        fp8x4_t* as_vec;
    };

    // One 32-bit payload, viewed as raw bits or as four fp8 lanes.
    union PayloadView
    {
        uint32_t bits;
        fp8x4_t lanes;
    };

    DstPtrView dst;
    dst.as_vec = p_dst;

    // Plain (non-atomic) read that seeds the first CAS attempt.
    PayloadView observed;
    observed.bits = *dst.as_word;

    PayloadView desired;
    for(;;)
    {
        const uint32_t expected = observed.bits;

        desired.lanes = add_fp8x4_t(observed.lanes, x);

        // atomicCAS returns the word that was in memory before the attempt.
        observed.bits = atomicCAS(dst.as_word, expected, desired.bits);
        if(observed.bits == expected)
        {
            break;
        }
    }
}
// Atomic lane-wise add of a bf8x4_t into *p_dst, built on a 32-bit compare-and-swap.
// The loop recomputes the sum from the freshly observed word whenever the CAS
// reports that memory no longer held the expected value.
template <>
CK_TILE_DEVICE void atomic_add<bf8x4_t>(bf8x4_t* p_dst, const bf8x4_t& x)
{
    // The destination pointer, viewed as a 32-bit word pointer or a vector pointer.
    union DstPtrView
    {
        uint32_t* as_word;
        bf8x4_t* as_vec;
    };

    // One 32-bit payload, viewed as raw bits or as four bf8 lanes.
    union PayloadView
    {
        uint32_t bits;
        bf8x4_t lanes;
    };

    DstPtrView dst;
    dst.as_vec = p_dst;

    // Plain (non-atomic) read that seeds the first CAS attempt.
    PayloadView observed;
    observed.bits = *dst.as_word;

    PayloadView desired;
    for(;;)
    {
        const uint32_t expected = observed.bits;

        desired.lanes = add_bf8x4_t(observed.lanes, x);

        // atomicCAS returns the word that was in memory before the attempt.
        observed.bits = atomicCAS(dst.as_word, expected, desired.bits);
        if(observed.bits == expected)
        {
            break;
        }
    }
}
//
// Atomic add for fp8x8_t
//
// Lane-wise add of an fp8x8_t into *p_dst, built on a 64-bit compare-and-swap.
// The loop recomputes the sum from the freshly observed word whenever the CAS
// reports that memory no longer held the expected value.
template <>
CK_TILE_DEVICE void atomic_add<fp8x8_t>(fp8x8_t* p_dst, fp8x8_t const& x)
{
    // The destination pointer, viewed as a 64-bit word pointer or a vector pointer.
    union DstPtrView
    {
        uint64_t* as_word;
        fp8x8_t* as_vec;
    };

    // One 64-bit payload, viewed as raw bits or as eight fp8 lanes.
    union PayloadView
    {
        uint64_t bits;
        fp8x8_t lanes;
    };

    DstPtrView dst;
    dst.as_vec = p_dst;

    // Plain (non-atomic) read that seeds the first CAS attempt.
    PayloadView observed;
    observed.bits = *dst.as_word;

    PayloadView desired;
    for(;;)
    {
        const uint64_t expected = observed.bits;

        desired.lanes = add_fp8x8_t(observed.lanes, x);

        // atomicCAS returns the word that was in memory before the attempt.
        observed.bits = atomicCAS(dst.as_word, expected, desired.bits);
        if(observed.bits == expected)
        {
            break;
        }
    }
}
//
// Atomic add for bf8x8_t
//
// Lane-wise add of a bf8x8_t into *p_dst, built on a 64-bit compare-and-swap.
// The loop recomputes the sum from the freshly observed word whenever the CAS
// reports that memory no longer held the expected value.
template <>
CK_TILE_DEVICE void atomic_add<bf8x8_t>(bf8x8_t* p_dst, bf8x8_t const& x)
{
    // The destination pointer, viewed as a 64-bit word pointer or a vector pointer.
    union DstPtrView
    {
        uint64_t* as_word;
        bf8x8_t* as_vec;
    };

    // One 64-bit payload, viewed as raw bits or as eight bf8 lanes.
    union PayloadView
    {
        uint64_t bits;
        bf8x8_t lanes;
    };

    DstPtrView dst;
    dst.as_vec = p_dst;

    // Plain (non-atomic) read that seeds the first CAS attempt.
    PayloadView observed;
    observed.bits = *dst.as_word;

    PayloadView desired;
    for(;;)
    {
        const uint64_t expected = observed.bits;

        desired.lanes = add_bf8x8_t(observed.lanes, x);

        // atomicCAS returns the word that was in memory before the attempt.
        observed.bits = atomicCAS(dst.as_word, expected, desired.bits);
        if(observed.bits == expected)
        {
            break;
        }
    }
}
template <typename T, index_t N> template <typename T, index_t N>
CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x) CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
{ {
...@@ -66,8 +311,10 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x) ...@@ -66,8 +311,10 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
(std::is_same<T, uint32_t>::value && (N == 1)) || (std::is_same<T, uint32_t>::value && (N == 1)) ||
(std::is_same<T, float>::value && (N == 1 || N == 2)) || (std::is_same<T, float>::value && (N == 1 || N == 2)) ||
(std::is_same<T, double>::value && (N == 1 || N == 2)) || (std::is_same<T, double>::value && (N == 1 || N == 2)) ||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4)), (std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
"wrong! not implemented"); (std::is_same<T, fp8_t>::value && (N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf8_t>::value && (N == 4 || N == 8 || N == 16)),
"The granularity of the thread buffer is unsupported on the hardware!");
constexpr auto I0 = number<0>{}; constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{}; constexpr auto I1 = number<1>{};
...@@ -118,9 +365,45 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x) ...@@ -118,9 +365,45 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
} }
else if constexpr(N == 4) else if constexpr(N == 4)
{ {
atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst), x.template get_as<bf16x2_t>()[I0]); atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst), x.template get_as<bf16x4_t>()[I0]);
atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst) + 1, }
x.template get_as<bf16x2_t>()[I1]); else if constexpr(N == 8)
{
atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst), x.template get_as<bf16x4_t>()[I0]);
atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst) + 1,
x.template get_as<bf16x4_t>()[I1]);
}
}
else if constexpr(std::is_same<T, fp8_t>::value)
{
if constexpr(N == 4)
{
atomic_add(c_style_pointer_cast<fp8x4_t*>(p_dst), x.template get_as<fp8x4_t>()[I0]);
}
if constexpr(N == 8)
{
atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst), x.template get_as<fp8x8_t>()[I0]);
}
if constexpr(N == 16)
{
atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst), x.template get_as<fp8x8_t>()[I0]);
atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst) + 1, x.template get_as<fp8x8_t>()[I1]);
}
}
else if constexpr(std::is_same<T, bf8_t>::value)
{
if constexpr(N == 4)
{
atomic_add(c_style_pointer_cast<bf8x4_t*>(p_dst), x.template get_as<bf8x4_t>()[I0]);
}
if constexpr(N == 8)
{
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst), x.template get_as<bf8x8_t>()[I0]);
}
if constexpr(N == 16)
{
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst), x.template get_as<bf8x8_t>()[I0]);
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst) + 1, x.template get_as<bf8x8_t>()[I1]);
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment