Unverified Commit 2a30cfdd authored by arai713, committed by GitHub

Merge branch 'develop' into codegen-enable-hiprtc

parents 9533a172 78195ccc
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_FUNCTIONAL4_HPP
#define CK_FUNCTIONAL4_HPP
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
@@ -48,4 +48,9 @@ __host__ __device__ constexpr auto operator%(integral_constant<TX, X>, integral_
return integral_constant<decltype(X % Y), X % Y>{};
}
template <bool B>
using bool_constant = integral_constant<bool, B>;
using true_type = bool_constant<true>;
using false_type = bool_constant<false>;
} // namespace ck
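// Editor's usage sketch (illustration only; assumes ck::is_same from
// ck/utility/type.hpp is visible here). The aliases above mirror
// std::bool_constant / std::true_type so RTC builds can avoid <type_traits>:
static_assert(ck::bool_constant<(8 % 2 == 0)>::value, "8 is even");
static_assert(ck::is_same<ck::bool_constant<true>, ck::true_type>::value,
              "true_type is an alias for bool_constant<true>");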
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/integral_constant.hpp"
namespace ck {
namespace detail {
template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
struct detector
{
using value_t = ck::false_type;
using value_t = integral_constant<bool, false>;
using type = Default;
};
template <class Default, template <class...> class Op, class... Args>
struct detector<Default, ck::void_t<Op<Args...>>, Op, Args...>
{
using value_t = ck::true_type;
using value_t = integral_constant<bool, true>;
using type = Op<Args...>;
};
} // namespace detail
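// Editor's sketch (hypothetical example_* names) of how detail::detector is
// consumed: when Op<Args...> is well-formed, the partial specialization is
// selected and value_t becomes true_type; otherwise the primary template
// yields false_type and the Default type.
template <template <class...> class Op, class... Args>
using example_is_detected = typename detail::detector<void, void, Op, Args...>::value_t;

template <class T>
using example_has_value_type = typename T::value_type; // probe expression

// example_is_detected<example_has_value_type, int> is false_type, since
// int::value_type is ill-formed; a container type would give true_type.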
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifdef _HIPCC_RTC_
#define CK_CODE_GEN_RTC
#endif
#ifndef __HIPCC_RTC__
#ifndef CK_CODE_GEN_RTC
#include <ostream>
#endif
#endif
#include "ck/utility/common_header.hpp"
......
@@ -28,6 +35,7 @@ constexpr LoopScheduler make_default_loop_scheduler()
} // namespace ck
#ifndef __HIPCC_RTC__
#ifndef CK_CODE_GEN_RTC
inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s)
{
switch(s)
......
@@ -39,3 +47,4 @@ inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s)
return os;
}
#endif
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
@@ -10,6 +10,10 @@
#include "type.hpp"
#include "tuple.hpp"
#ifdef CK_CODE_GEN_RTC
#ifndef INT32_MAX
#define INT32_MAX 2147483647 // hipRTC compiles without <cstdint>, so define the macro explicitly
#endif
#endif
namespace ck {
// magic number division
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifdef _HIPCC_RTC_
#define CK_CODE_GEN_RTC
#endif
#ifndef __HIP_DEVICE_COMPILE__
#include <cmath>
#endif
......
@@ -20,7 +24,7 @@ extern "C" __device__ float __ocml_native_recip_f32(float);
#ifndef __HIPCC_RTC__
// math functions for the host, some are implemented by calling C++ std functions
#ifndef CK_CODE_GEN_RTC
static inline __host__ float abs(float x) { return std::abs(x); };
static inline __host__ double abs(double x) { return std::abs(x); };
......
@@ -81,7 +85,7 @@ static inline __host__ bool isnan(half_t x)
return (xx & 0x7FFF) > 0x7C00;
};
static inline __host__ bool isnan(f8_t x) { return (x & 0x80); };
static inline __host__ bool isnan(f8_t x) { return ck::fp8_is_nan(x); };
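// Editor's note: the previous check (x & 0x80) tested only the sign bit and
// so misclassified every negative value as NaN; fp8_is_nan inspects the
// actual OCP encodings (for E4M3, exponent and mantissa all ones: 0x7F/0xFF).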
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static inline __host__ bool isnan(int4_t x)
......
@@ -461,7 +465,6 @@ inline __host__ double expm1<double>(double x)
return std::expm1(x);
}
#endif
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
static inline __device__ float abs(float x) { return ::abs(x); };
......
@@ -533,7 +536,7 @@ static inline __device__ bool isnan(half_t x)
return (xx & 0x7FFF) > 0x7C00;
};
static inline __device__ bool isnan(f8_t x) { return (x & 0x80); };
static inline __device__ bool isnan(f8_t x) { return ck::fp8_is_nan(x); };
static inline __device__ half_t sqrt(half_t x)
{
......
@@ -613,7 +616,7 @@ inline __device__ int8_t neg<int8_t>(int8_t x)
template <>
inline __device__ half_t neg<half_t>(half_t x)
{
return __hneg(x);
return __hneg(static_cast<__half>(x));
};
template <typename T>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"
namespace ck::utils {
template <>
__host__ __device__ inline bool is_nan<f4_t>(e8m0_bexp_t const scale,
f4_t const dataBytes [[maybe_unused]])
{
// no need to check for data as it does not have NaN representation
return scale == NumericLimits<e8m0_bexp_t>::QuietNaN();
}
// ocp_e2m1_mxfp4 has no infinity representation, so this always returns false
template <>
__host__ __device__ inline bool is_inf<f4_t>(e8m0_bexp_t const scale [[maybe_unused]],
f4_t const data [[maybe_unused]])
{
// no inf representation for ocp_e2m1_mxfp4
return false;
}
template <>
__host__ __device__ inline bool is_zero<f4_t>(e8m0_bexp_t const scale, f4_t const data)
{
if(is_nan<f4_t>(scale, data))
return false;
// no need to check for scale as it does not have a 0 representation
f4_t result = (data & 0b00001111) & NumericUtils<f4_t>::set_sign_mask;
return result == 0b0;
}
template <>
__host__ __device__ inline float to_float<f4_t>(e8m0_bexp_t const scale, f4_t const data)
{
if(is_nan<f4_t>(scale, data))
return std::numeric_limits<float>::quiet_NaN();
if(is_zero<f4_t>(scale, data))
return 0.0f;
f4_t prepared_data = data & 0b00001111;
int scale_exp = get_exponent_value<e8m0_bexp_t>(scale);
return convert_to_float<f4_t>(prepared_data, scale_exp);
}
template <>
__host__ __device__ inline f4_t sat_convert_to_type<f4_t>(float value)
{
cvt t;
t.value_float = value;
uint32_t sign = t.value_bitwise >> 31;
if(std::isnan(value))
{
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
: NumericUtils<f4_t>::data_max_positive_normal_mask;
}
if(std::abs(value) > NumericLimits<f4_t>::Max()) // covers inf case as well
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
: NumericUtils<f4_t>::data_max_positive_normal_mask;
f4_t res = convert_to_type<f4_t>(value);
if(std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
NumericLimits<f4_t>::DataMinSubnorm())
return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
: NumericUtils<f4_t>::positive_zero_mask;
return res;
}
template <>
__host__ __device__ inline f4_t sat_convert_to_type_sr<f4_t>(float value, uint32_t seed)
{
cvt t;
t.value_float = value;
uint32_t sign = t.value_bitwise >> 31;
if(std::isnan(value))
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
: NumericUtils<f4_t>::data_max_positive_normal_mask;
if(std::abs(value) > NumericLimits<f4_t>::Max()) // covers inf case as well
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
: NumericUtils<f4_t>::data_max_positive_normal_mask;
f4_t res = convert_to_type_sr<f4_t>(value, seed);
if(std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
NumericLimits<f4_t>::DataMinSubnorm())
return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
: NumericUtils<f4_t>::positive_zero_mask;
return res;
}
} // namespace ck::utils
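// Editor's usage sketch (illustration only; assumes this header and
// "ck/utility/data_type.hpp" are both included): round-trip a float through
// the saturating f4 conversion and decode it with an e8m0 scale of 1.
inline float example_f4_round_trip(float x)
{
    using namespace ck;
    using namespace ck::utils;
    // values beyond the f4 range clamp to the max normal; tiny values flush
    // to signed zero
    f4_t q = sat_convert_to_type<f4_t>(x);
    // Binary_1 is the e8m0 encoding of 1.0f, i.e. 2^0
    return to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), q);
}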
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"
namespace ck::utils {
/**
* @brief Checks if an f6_t value is NaN based on the provided scale.
*
* For f6_t data, NaN cannot be represented directly. Instead, this function
* determines NaN by checking if the scale is set to a quiet NaN.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
* @param dataBytes The f6_t value to check (unused in this implementation).
* @return true if the scale indicates a NaN value, false otherwise.
*/
template <>
__host__ __device__ inline bool is_nan<f6_t>(e8m0_bexp_t const scale,
f6_t const dataBytes [[maybe_unused]])
{
// no need to check for data as it does not have NaN representation
return scale.is_nan();
}
/**
 * @brief Checks if a bf6_t value is NaN based on the provided scale.
*
* For bf6_t data, NaN cannot be represented directly. Instead, this function
* determines NaN by checking if the scale is set to a quiet NaN.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
* @param dataBytes The bf6_t value to check (unused in this implementation).
* @return true if the scale indicates a NaN value, false otherwise.
*/
template <>
__host__ __device__ inline bool is_nan<bf6_t>(e8m0_bexp_t const scale,
bf6_t const dataBytes [[maybe_unused]])
{
// no need to check for data as it does not have NaN representation
return scale.is_nan();
}
/**
* @brief Checks if an f6_t value is infinite.
*
* Because f6_t does not support infinite values, this function always returns false.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
* @param data The f6_t value to check.
* @return Always false, as infinity is not represented in f6_t.
*/
template <>
__host__ __device__ inline bool is_inf<f6_t>(e8m0_bexp_t const scale [[maybe_unused]],
f6_t const data [[maybe_unused]])
{
// no inf representation for fp6
return false;
}
/**
 * @brief Checks if a bf6_t value is infinite.
*
* Because bf6_t does not support infinite values, this function always returns false.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
* @param data The bf6_t value to check.
* @return Always false, as infinity is not represented in bf6_t.
*/
template <>
__host__ __device__ inline bool is_inf<bf6_t>(e8m0_bexp_t const scale [[maybe_unused]],
bf6_t const data [[maybe_unused]])
{
// no inf representation for bf6
return false;
}
/**
* @brief Checks whether an f6_t value is zero.
*
* If the specified f6_t is NaN, this function returns false.
* Otherwise, it masks out the sign bits and checks if the remaining bits
* are zero.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
* @param data The f6_t value to check.
* @return true if the value is zero; otherwise false.
*/
template <>
__host__ __device__ inline bool is_zero<f6_t>(e8m0_bexp_t const scale, f6_t const data)
{
if(is_nan<f6_t>(scale, data))
return false;
// no need to check for scale as it does not have a 0 representation
f6_t result = (data & 0b00111111) & NumericUtils<f6_t>::set_sign_mask;
return result == 0b0;
}
/**
 * @brief Checks whether a bf6_t value is zero.
*
* If the specified bf6_t is NaN, this function returns false.
* Otherwise, it masks out the sign bits and checks if the remaining bits
* are zero.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
* @param data The bf6_t value to check.
* @return true if the value is zero; otherwise false.
*/
template <>
__host__ __device__ inline bool is_zero<bf6_t>(e8m0_bexp_t const scale, bf6_t const data)
{
if(is_nan<bf6_t>(scale, data))
return false;
// no need to check for scale as it does not have a 0 representation
bf6_t result = (data & 0b00111111) & NumericUtils<bf6_t>::set_sign_mask;
return result == 0b0;
}
/**
* @brief Converts an f6_t value to a float based on an e8m0_bexp_t scale factor.
*
* Checks if the f6_t value is NaN or zero before performing the conversion.
* Applies the exponent from the scale to compute the final float result.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
* @param data The f6_t value to convert.
* @return The converted float value.
*/
template <>
__host__ __device__ inline float to_float<f6_t>(e8m0_bexp_t const scale, f6_t const data)
{
if(is_nan<f6_t>(scale, data))
return std::numeric_limits<float>::quiet_NaN();
if(is_zero<f6_t>(scale, data))
return 0.0f;
f6_t prepared_data = data & 0b00111111;
int scale_exp = get_exponent_value<e8m0_bexp_t>(scale);
return convert_to_float<f6_t>(prepared_data, scale_exp);
}
/**
 * @brief Converts a bf6_t value to a float based on an e8m0_bexp_t scale factor.
*
* Checks if the bf6_t value is NaN or zero before performing the conversion.
* Applies the exponent from the scale to compute the final float result.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
* @param data The bf6_t value to convert.
* @return The converted float value.
*/
template <>
__host__ __device__ inline float to_float<bf6_t>(e8m0_bexp_t const scale, bf6_t const data)
{
if(is_nan<bf6_t>(scale, data))
return std::numeric_limits<float>::quiet_NaN();
if(is_zero<bf6_t>(scale, data))
return 0.0f;
bf6_t prepared_data = data & 0b00111111;
int scale_exp = get_exponent_value<e8m0_bexp_t>(scale);
return convert_to_float<bf6_t>(prepared_data, scale_exp);
}
/**
* @brief Converts a float to f6_t with saturation.
*
* If the input is NaN or exceeds the representable range for f6_t, returns
* the corresponding max normal mask. Handles subnormal cases by returning
* zero with the appropriate sign.
*
* @param value The float value to be converted.
* @return The saturated f6_t value.
*/
template <>
__host__ __device__ inline f6_t sat_convert_to_type<f6_t>(float value)
{
cvt t;
t.value_float = value;
uint32_t sign = t.value_bitwise >> 31;
if(std::isnan(value))
{
return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
: NumericUtils<f6_t>::data_max_positive_normal_mask;
}
if(std::abs(value) > NumericLimits<f6_t>::Max()) // covers inf case as well
return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
: NumericUtils<f6_t>::data_max_positive_normal_mask;
f6_t res = convert_to_type<f6_t>(value);
if(std::abs(to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
NumericLimits<f6_t>::DataMinSubnorm())
return sign ? NumericUtils<f6_t>::negative_zero_mask
: NumericUtils<f6_t>::positive_zero_mask;
return res;
}
/**
* @brief Converts a float to bf6_t with saturation.
*
* If the input is NaN or exceeds the representable range for bf6_t, returns
* the corresponding max normal mask. Handles subnormal cases by returning
* zero with the appropriate sign.
*
* @param value The float value to be converted.
* @return The saturated bf6_t value.
*/
template <>
__host__ __device__ inline bf6_t sat_convert_to_type<bf6_t>(float value)
{
cvt t;
t.value_float = value;
uint32_t sign = t.value_bitwise >> 31;
if(std::isnan(value))
{
return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
: NumericUtils<bf6_t>::data_max_positive_normal_mask;
}
if(std::abs(value) > NumericLimits<bf6_t>::Max()) // covers inf case as well
return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
: NumericUtils<bf6_t>::data_max_positive_normal_mask;
bf6_t res = convert_to_type<bf6_t>(value);
if(std::abs(to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
NumericLimits<bf6_t>::DataMinSubnorm())
return sign ? NumericUtils<bf6_t>::negative_zero_mask
: NumericUtils<bf6_t>::positive_zero_mask;
return res;
}
/**
* @brief Converts a float to f6_t with saturation and stochastic rounding.
*
* If the input is NaN or exceeds the representable range for f6_t, returns
* the corresponding max normal mask. Handles subnormal cases by returning
* zero with the appropriate sign.
*
 * @param value The float value to be converted.
 * @param seed Seed for stochastic rounding.
 * @return The saturated f6_t value.
*/
template <>
__host__ __device__ inline f6_t sat_convert_to_type_sr<f6_t>(float value, uint32_t seed)
{
cvt t;
t.value_float = value;
uint32_t sign = t.value_bitwise >> 31;
if(std::isnan(value))
return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
: NumericUtils<f6_t>::data_max_positive_normal_mask;
if(std::abs(value) > NumericLimits<f6_t>::Max()) // covers inf case as well
return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
: NumericUtils<f6_t>::data_max_positive_normal_mask;
f6_t res = convert_to_type_sr<f6_t>(value, seed);
if(std::abs(to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
NumericLimits<f6_t>::DataMinSubnorm())
return sign ? NumericUtils<f6_t>::negative_zero_mask
: NumericUtils<f6_t>::positive_zero_mask;
return res;
}
/**
 * @brief Converts a float to bf6_t with saturation and stochastic rounding.
 *
 * If the input is NaN or exceeds the representable range for bf6_t, returns
 * the corresponding max normal mask. Handles subnormal cases by returning
 * zero with the appropriate sign.
 *
 * @param value The float value to be converted.
 * @param seed Seed for stochastic rounding.
 * @return The saturated bf6_t value.
*/
template <>
__host__ __device__ inline bf6_t sat_convert_to_type_sr<bf6_t>(float value, uint32_t seed)
{
cvt t;
t.value_float = value;
uint32_t sign = t.value_bitwise >> 31;
if(std::isnan(value))
return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
: NumericUtils<bf6_t>::data_max_positive_normal_mask;
if(std::abs(value) > NumericLimits<bf6_t>::Max()) // covers inf case as well
return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
: NumericUtils<bf6_t>::data_max_positive_normal_mask;
bf6_t res = convert_to_type_sr<bf6_t>(value, seed);
if(std::abs(to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
NumericLimits<bf6_t>::DataMinSubnorm())
return sign ? NumericUtils<bf6_t>::negative_zero_mask
: NumericUtils<bf6_t>::positive_zero_mask;
return res;
}
} // namespace ck::utils
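// Editor's usage sketch (illustration only): quantize the same float to f6
// and bf6 and decode both with a scale of 1.0f; the two formats trade one
// exponent bit for one mantissa bit, so the decoded values can differ.
inline void example_f6_round_trip()
{
    using namespace ck;
    using namespace ck::utils;
    f6_t a   = sat_convert_to_type<f6_t>(1.75f);
    bf6_t b  = sat_convert_to_type<bf6_t>(1.75f);
    float fa = to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), a);
    float fb = to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), b);
    (void)fa;
    (void)fb;
}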
#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"
#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__
#define CK_MX_FP8_CVT_FAST_PATH 1
#else
#define CK_MX_FP8_CVT_FAST_PATH 0
#endif
namespace ck {
namespace fp8_impl {
#if CK_MX_FP8_CVT_FAST_PATH
template <ck_fp8_interpretation_t interpret>
static __device__ float cast_to_f32_from_f8_scaled(float scale, fp8_storage_t v)
{
union
{
unsigned int i32val;
unsigned char i8val[4];
} val;
val.i8val[0] = v;
static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
"Only OCP interpretations are supported");
if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
{
return __builtin_amdgcn_cvt_scalef32_f32_fp8(val.i32val, scale, 0);
}
else
{
return __builtin_amdgcn_cvt_scalef32_f32_bf8(val.i32val, scale, 0);
}
}
template <ck_fp8_interpretation_t interpret>
static __device__ float2_t cast_to_f32x2_from_f8x2_scaled(float scale, fp8x2_storage_t v)
{
const auto i16val = bit_cast<uint16_t>(v);
static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
"Only OCP interpretations are supported");
if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
{
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp8(i16val, scale, 0);
}
else
{
return __builtin_amdgcn_cvt_scalef32_pk_f32_bf8(i16val, scale, 0);
}
}
template <ck_fp8_interpretation_t interpret, bool stochastic_rounding = false>
static __device__ fp8_storage_t cast_to_f8_from_f32_scaled(float v,
unsigned int rng = 0,
float scale = 1.0f)
{
fp8_storage_t i8data;
union
{
float fval;
unsigned int i32val;
} val;
union
{
uint32_t ival;
vector_type<int16_t, 2>::type v2i16;
fp8_storage_t v4i8[4];
} ret{};
val.fval = v;
if constexpr(stochastic_rounding)
{
ret.ival =
(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
? __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, val.fval, rng, scale, 0)
: __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, val.fval, rng, scale, 0);
i8data = ret.v4i8[0];
}
else
{
// RNE CVT
// llvm.amdgcn.cvt.scalef32.pk.fp8.f32
// v2i16 old_vdst, float srcA, float srcB, float scale, bool dst_lo_hi_sel
if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
{
// If fval / scale > max fp8, returns Nan
ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32(/*old_vdst*/ ret.v2i16,
val.fval,
val.fval,
scale,
/*dst_lo_hi_sel*/ false);
}
else
{
// If fval / scale > max bf8, returns Inf
ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32(/*old_vdst*/ ret.v2i16,
val.fval,
val.fval,
scale,
/*dst_lo_hi_sel*/ false);
}
i8data = ret.v4i8[0];
}
return i8data;
}
template <ck_fp8_interpretation_t interpret, bool stochastic_rounding = false>
static __device__ fp8x2_storage_t cast_to_f8_from_f32_scaled(float2_t v,
unsigned int rng = 0,
float scale = 1.0f)
{
union
{
uint32_t ival;
vector_type<int16_t, 2>::type v2i16;
StaticallyIndexedArray<fp8x2_storage_t, 2> v2f8x2;
} ret{};
if constexpr(stochastic_rounding)
{
fp8x2_storage_t f8x2;
if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
{
ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[0], rng, scale, 0);
f8x2[0] = ret.v2f8x2(Number<0>{})[0];
ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[1], rng, scale, 0);
f8x2[1] = ret.v2f8x2(Number<0>{})[0];
}
else
{
ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[0], rng, scale, 0);
f8x2[0] = ret.v2f8x2(Number<0>{})[0];
ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[1], rng, scale, 0);
f8x2[1] = ret.v2f8x2(Number<0>{})[0];
}
return f8x2;
}
else
{
// RNE CVT
// llvm.amdgcn.cvt.scalef32.pk.fp8.f32
// v2i16 old_vdst, float srcA, float srcB, float scale, bool dst_lo_hi_sel
if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
{
// If fval / scale > max fp8, returns Nan
ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32(/*old_vdst*/ ret.v2i16,
v[0],
v[1],
scale,
/*dst_lo_hi_sel*/ false);
}
else
{
// If fval / scale > max bf8, returns Inf
ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32(/*old_vdst*/ ret.v2i16,
v[0],
v[1],
scale,
/*dst_lo_hi_sel*/ false);
}
return ret.v2f8x2(Number<0>{});
}
}
#endif // CK_MX_FP8_CVT_FAST_PATH
#if CK_MX_FP8_CVT_FAST_PATH
/**
* \brief convert float to @p fp8_storage_t with scaling
*
* This version is used when the fast path (MX FP8 hardware) is available
*
* \tparam interp interpretation of fp8
* \param f float number
* \param scale scaling factor
* \return fp8_storage_t
*/
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8_scaled(const float f, float scale)
{
__is_interpret_supported(interp);
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
constexpr int seed = 1254739;
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
}
return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
}
/**
* \brief convert 2xfloat to @p 2xfp8_storage_t with scaling
*
* This version is used when the fast path (MX FP8 hardware) is available
*
* \tparam interp interpretation of fp8
* \param f 2xfloat
* \param scale scaling factor
* \return 2xfp8_storage_t
*/
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const float2_t f,
float scale)
{
__is_interpret_supported(interp);
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
constexpr int seed = 1254739;
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
}
return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
}
#else
/**
* \brief convert float to @p fp8_storage_t with scaling
*
* This version is used when the fast path (MX FP8 hardware) is not available
*
* \tparam interp interpretation of fp8
* \param f float number
* \param scale scaling factor
* \return fp8_storage_t
*/
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8_scaled(const float f, float scale)
{
static_assert(interp == ck_fp8_interpretation_t::CK_E4M3_OCP ||
interp == ck_fp8_interpretation_t::CK_E5M2_OCP,
"Only OCP interpretations are supported");
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
constexpr int seed = 1254739;
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
}
if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
{
return cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f / scale, rng);
}
else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
{
return cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f / scale, rng);
}
else
{
__hip_assert(false && "FP8 type is not supported by current target device");
return 0;
}
}
/**
 * \brief convert two floats to @p 2xfp8_storage_t with scaling
*
* This version is used when the fast path (MX FP8 hardware) is not available
*
* \tparam interp interpretation of fp8
* \param f 2xfloat
* \param scale scaling factor
* \return 2xfp8_storage_t
*/
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const float2_t f,
float scale)
{
static_assert(interp == ck_fp8_interpretation_t::CK_E4M3_OCP ||
interp == ck_fp8_interpretation_t::CK_E5M2_OCP,
"Only OCP interpretations are supported");
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
constexpr int seed = 1254739;
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
}
if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
{
return {cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f[0] / scale, rng),
cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f[1] / scale, rng)};
}
else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
{
return {cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f[0] / scale, rng),
cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f[1] / scale, rng)};
}
else
{
__hip_assert(false && "FP8 type is not supported by current target device");
return 0;
}
}
#endif // CK_MX_FP8_CVT_FAST_PATH
} // namespace fp8_impl
// Declare a template function for fp8 conversion using SR
template <typename Y, typename X>
__host__ __device__ constexpr Y mxf8_convert_sr(X x, float scale);
// Declare a template function for fp8 conversion using RNE
template <typename Y, typename X>
__host__ __device__ constexpr Y mxf8_convert_rne(X x, float scale);
// convert fp32 to fp8 with rounding to nearest even
template <>
inline __host__ __device__ f8_ocp_t mxf8_convert_rne<f8_ocp_t, float>(float x, float scale)
{
return f8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret>(x, scale)};
}
// convert fp32 to bf8 with rounding to nearest even
template <>
inline __host__ __device__ bf8_ocp_t mxf8_convert_rne<bf8_ocp_t, float>(float x, float scale)
{
return bf8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret>(x, scale)};
}
// convert fp32x2 to fp8x2 with rounding to nearest even
template <>
inline __host__ __device__ f8x2_ocp_t mxf8_convert_rne<f8x2_ocp_t, float2_t>(float2_t x,
float scale)
{
return f8x2_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret>(x, scale)};
}
// convert fp32x2 to bf8x2 with rounding to nearest even
template <>
inline __host__ __device__ bf8x2_ocp_t mxf8_convert_rne<bf8x2_ocp_t, float2_t>(float2_t x,
float scale)
{
return bf8x2_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret>(x, scale)};
}
// convert fp32x16 to fp8x16 with rounding to nearest even
template <>
inline __host__ __device__ f8x16_ocp_t mxf8_convert_rne<f8x16_ocp_t, float16_t>(float16_t x,
float scale)
{
union
{
float16_t float_1x16;
float2_t float_2x8[8];
} in{x};
union
{
f8x16_ocp_t fp8_1x16;
f8x2_ocp_t fp8_2x8[8];
} out{};
ck::static_for<0, 8, 1>{}(
[&](auto i) { out.fp8_2x8[i] = mxf8_convert_rne<f8x2_ocp_t>(in.float_2x8[i], scale); });
return out.fp8_1x16;
}
// convert fp32x16 to bf8x16 with rounding to nearest even
template <>
inline __host__ __device__ bf8x16_ocp_t mxf8_convert_rne<bf8x16_ocp_t, float16_t>(float16_t x,
float scale)
{
union
{
float16_t float_1x16;
float2_t float_2x8[8];
} in{x};
union
{
bf8x16_ocp_t bf8_1x16;
bf8x2_ocp_t bf8_2x8[8];
} out{};
ck::static_for<0, 8, 1>{}(
[&](auto i) { out.bf8_2x8[i] = mxf8_convert_rne<bf8x2_ocp_t>(in.float_2x8[i], scale); });
return out.bf8_1x16;
}
// convert fp32x32 to fp8x32 with rounding to nearest even
template <>
inline __host__ __device__ f8x32_ocp_t mxf8_convert_rne<f8x32_ocp_t, float32_t>(float32_t x,
float scale)
{
union
{
float32_t float_1x32;
float16_t float_16x2[2];
} in{x};
union
{
f8x32_ocp_t fp8_1x32;
f8x16_ocp_t fp8_16x2[2];
} out{};
ck::static_for<0, 2, 1>{}(
[&](auto i) { out.fp8_16x2[i] = mxf8_convert_rne<f8x16_ocp_t>(in.float_16x2[i], scale); });
return out.fp8_1x32;
}
// convert fp32x32 to bf8x32 with rounding to nearest even
template <>
inline __host__ __device__ bf8x32_ocp_t mxf8_convert_rne<bf8x32_ocp_t, float32_t>(float32_t x,
float scale)
{
union
{
float32_t float_1x32;
float16_t float_16x2[2];
} in{x};
union
{
bf8x32_ocp_t bf8_1x32;
bf8x16_ocp_t bf8_16x2[2];
} out{};
ck::static_for<0, 2, 1>{}(
[&](auto i) { out.bf8_16x2[i] = mxf8_convert_rne<bf8x16_ocp_t>(in.float_16x2[i], scale); });
return out.bf8_1x32;
}
// convert fp32 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_ocp_t mxf8_convert_sr<f8_ocp_t, float>(float x, float scale)
{
return f8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret, true>(x, scale)};
}
// convert fp32 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_ocp_t mxf8_convert_sr<bf8_ocp_t, float>(float x, float scale)
{
return bf8_ocp_t{
fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret, true>(x, scale)};
}
// convert fp32x2 to fp8x2 with stochastic rounding
template <>
inline __host__ __device__ f8x2_ocp_t mxf8_convert_sr<f8x2_ocp_t, float2_t>(float2_t x, float scale)
{
return f8x2_ocp_t{
fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret, true>(x, scale)};
}
// convert fp32x2 to bf8x2 with stochastic rounding
template <>
inline __host__ __device__ bf8x2_ocp_t mxf8_convert_sr<bf8x2_ocp_t, float2_t>(float2_t x,
float scale)
{
return bf8x2_ocp_t{
fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret, true>(x, scale)};
}
// convert fp32x16 to fp8x16 with stochastic rounding
template <>
inline __host__ __device__ f8x16_ocp_t mxf8_convert_sr<f8x16_ocp_t, float16_t>(float16_t x,
float scale)
{
union
{
float16_t float_1x16;
float2_t float_2x8[8];
} in{x};
union
{
f8x16_ocp_t fp8_1x16;
f8x2_ocp_t fp8_2x8[8];
} out{};
ck::static_for<0, 8, 1>{}(
[&](auto i) { out.fp8_2x8[i] = mxf8_convert_sr<f8x2_ocp_t>(in.float_2x8[i], scale); });
return out.fp8_1x16;
}
// convert fp32x16 to bf8x16 with stochastic rounding
template <>
inline __host__ __device__ bf8x16_ocp_t mxf8_convert_sr<bf8x16_ocp_t, float16_t>(float16_t x,
float scale)
{
union
{
float16_t float_1x16;
float2_t float_2x8[8];
} in{x};
union
{
bf8x16_ocp_t bf8_1x16;
bf8x2_ocp_t bf8_2x8[8];
} out{};
ck::static_for<0, 8, 1>{}(
[&](auto i) { out.bf8_2x8[i] = mxf8_convert_sr<bf8x2_ocp_t>(in.float_2x8[i], scale); });
return out.bf8_1x16;
}
// convert fp32x32 to fp8x32 with stochastic rounding
template <>
inline __host__ __device__ f8x32_ocp_t mxf8_convert_sr<f8x32_ocp_t, float32_t>(float32_t x,
float scale)
{
union
{
float32_t float_1x32;
float16_t float_16x2[2];
} in{x};
union
{
f8x32_ocp_t fp8_1x32;
f8x16_ocp_t fp8_16x2[2];
} out{};
ck::static_for<0, 2, 1>{}(
[&](auto i) { out.fp8_16x2[i] = mxf8_convert_sr<f8x16_ocp_t>(in.float_16x2[i], scale); });
return out.fp8_1x32;
}
// convert fp32x32 to bf8x32 with stochastic rounding
template <>
inline __host__ __device__ bf8x32_ocp_t mxf8_convert_sr<bf8x32_ocp_t, float32_t>(float32_t x,
float scale)
{
union
{
float32_t float_1x32;
float16_t float_16x2[2];
} in{x};
union
{
bf8x32_ocp_t bf8_1x32;
bf8x16_ocp_t bf8_16x2[2];
} out{};
ck::static_for<0, 2, 1>{}(
[&](auto i) { out.bf8_16x2[i] = mxf8_convert_sr<bf8x16_ocp_t>(in.float_16x2[i], scale); });
return out.bf8_1x32;
}
} // namespace ck
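// Editor's usage sketch (illustration only): the entry points above quantize
// fp32 scalars and vectors to OCP fp8/bf8, dividing the input by the
// already-decoded float scale; the RNE and SR variants share one signature.
inline void example_mxf8_convert()
{
    float scale = 2.0f; // decoded e8m0 scale, i.e. 2^(stored_exp - 127)
    ck::f8_ocp_t a  = ck::mxf8_convert_rne<ck::f8_ocp_t>(1.0f, scale);  // encodes 1.0f / 2.0f
    ck::bf8_ocp_t b = ck::mxf8_convert_sr<ck::bf8_ocp_t>(1.0f, scale);  // stochastic rounding
    (void)a;
    (void)b;
}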
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck::utils {
union cvt
{
float value_float;
uint32_t value_bitwise;
};
template <typename DTYPE>
inline bool getDataHasInf()
{
return DTYPE::dataInfo.hasInf;
}
template <typename T>
__host__ __device__ inline bool is_zero(e8m0_bexp_t const scale, T const data);
template <typename T>
__host__ __device__ inline bool is_nan(e8m0_bexp_t const scale, T const data);
template <typename T>
__host__ __device__ inline bool is_inf(e8m0_bexp_t const scale, T const data);
template <typename T>
__host__ __device__ inline int get_exponent_value(T x)
{
x >>= NumericUtils<T>::mant;
x &= ((1 << NumericUtils<T>::exp) - 1);
return static_cast<int>(x);
}
template <typename T>
__host__ __device__ inline bool is_subnormal(T x)
{
return get_exponent_value<T>(x) == 0;
}
template <typename T>
__host__ __device__ inline double get_mantissa_value(T x)
{
    double mantissa = is_subnormal<T>(x) ? 0.0 : 1.0;
    for(uint32_t i = 0; i < NumericUtils<T>::mant; i++)
{
mantissa += std::pow(2, -int32_t((NumericUtils<T>::mant - i))) * (x & 0b1);
x >>= 1;
}
return mantissa;
}
template <typename T>
__host__ __device__ inline bool get_data_has_inf()
{
return NumericUtils<T>::has_inf;
}
template <typename T>
__host__ __device__ float convert_to_float(T data, int scale_exp)
{
float d_sign =
std::pow(-1, static_cast<float>(data >> (NumericUtils<T>::exp + NumericUtils<T>::mant)));
float d_exp;
if(is_subnormal<T>(data))
d_exp = std::pow(2, 1 - static_cast<int>(NumericUtils<T>::bias));
else
d_exp = std::pow(2, get_exponent_value<T>(data) - static_cast<int>(NumericUtils<T>::bias));
float d_mant = get_mantissa_value<T>(data);
float data_value = d_sign * d_exp * d_mant;
float scale_value = std::pow(
2, static_cast<float>((scale_exp - static_cast<int>(NumericUtils<e8m0_bexp_t>::bias))));
return data_value * scale_value;
}
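// Editor's note: convert_to_float implements
//   value = (-1)^sign * 2^(exp - bias) * mantissa * 2^(scale_exp - 127)
// with the mantissa read as 1.m for normals and 0.m for subnormals (whose
// exponent term is pinned at 2^(1 - bias)). Worked e2m1 example (f4_t:
// 1 sign, 2 exponent, 1 mantissa bit, bias 1) with scale_exp = 127:
//   data = 0b0101 -> sign 0, exponent 0b10 = 2, mantissa 1.5
//   value = +1 * 2^(2 - 1) * 1.5 * 2^0 = 3.0f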
template <typename T>
__host__ __device__ inline float to_float(e8m0_bexp_t const scale, T const data);
template <typename T>
__host__ __device__ T sat_convert_to_type(float value);
template <typename T>
__host__ __device__ T sat_convert_to_type_sr(float value, uint32_t seed);
template <typename T>
inline T convert_to_type(float value)
{
using bitwise_type = typename NumericUtils<T>::bitwise_type;
if(std::abs(value) > NumericLimits<T>::Max())
{
float max_value = NumericLimits<T>::Max();
cvt t;
// cppcheck-suppress redundantAssignment
t.value_float = max_value;
uint32_t max_bitwise = t.value_bitwise;
// cppcheck-suppress redundantAssignment
t.value_float = value;
bitwise_type sign =
t.value_bitwise >> (NumericUtils<float>::exp + NumericUtils<float>::mant);
bitwise_type exp =
((max_bitwise >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask) -
(NumericUtils<float>::bias - NumericUtils<T>::bias);
bitwise_type mantissa = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
uint32_t mant_prev = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
mant_prev &= ((1 << NumericUtils<T>::mant) - 1);
mant_prev--;
mant_prev <<= (NumericUtils<float>::mant - NumericUtils<T>::mant);
uint32_t prev_bit =
((max_bitwise >> NumericUtils<float>::mant) << NumericUtils<float>::mant) | mant_prev;
t.value_bitwise = prev_bit;
float prev_val = t.value_float;
float diff = max_value - prev_val;
float actual_max = max_value + (diff / 2);
if(std::abs(value) < actual_max)
{
return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) |
(exp << NumericUtils<T>::mant) | mantissa;
}
else
{
if(!get_data_has_inf<T>())
{
return (1 << (NumericUtils<T>::mant + NumericUtils<T>::exp)) - 1;
}
else
{
exp++;
return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) |
(exp << NumericUtils<T>::mant);
}
}
}
const int mfmt = NumericUtils<float>::mant;
uint32_t x;
x = bit_cast<uint32_t>(value);
uint32_t head, mantissa;
int32_t exponent, bias;
uint32_t sign;
head = x & NumericUtils<float>::head_mask;
mantissa = x & NumericUtils<float>::mant_mask;
exponent = (head >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask;
sign = head >> (NumericUtils<float>::mant + NumericUtils<float>::exp);
bias = NumericUtils<float>::bias;
if(x == 0)
{
return 0b0;
}
const int mini_bias = NumericUtils<T>::bias;
const int mini_denormal_act_exponent = 1 - mini_bias;
int act_exponent, out_exponent, exponent_diff;
bool is_subnorm = false;
if(exponent == 0)
{
act_exponent = exponent - bias + 1;
exponent_diff = mini_denormal_act_exponent - act_exponent;
is_subnorm = true;
}
else
{
act_exponent = exponent - bias;
if(act_exponent <= mini_denormal_act_exponent)
{
exponent_diff = mini_denormal_act_exponent - act_exponent;
is_subnorm = true;
}
else
{
exponent_diff = 0;
}
mantissa += (1UL << mfmt);
}
auto shift_amount = (mfmt - NumericUtils<T>::mant + exponent_diff);
shift_amount = (shift_amount >= 64) ? 63 : shift_amount;
bool midpoint = (mantissa & ((1UL << shift_amount) - 1)) == (1UL << (shift_amount - 1));
float min_subnorm = NumericLimits<T>::DataMinSubnorm() * (sign ? -1 : 1);
if(is_subnorm && std::abs(value) < std::abs(min_subnorm))
{
// closer to 0
if(std::abs(value) <= std::abs(min_subnorm - value))
return 0;
else
return 1 | (sign << (NumericUtils<T>::exp + NumericUtils<T>::mant));
}
if(exponent_diff > 0)
mantissa >>= exponent_diff;
else if(exponent_diff == -1)
mantissa <<= -exponent_diff;
bool implicit_one = mantissa & (1 << mfmt);
out_exponent = (act_exponent + exponent_diff) + mini_bias - (implicit_one ? 0 : 1);
uint32_t drop_mask = (1UL << (mfmt - NumericUtils<T>::mant)) - 1;
bool odd = mantissa & (1UL << (mfmt - NumericUtils<T>::mant));
mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask;
if(out_exponent == 0)
{
if((1UL << mfmt) & mantissa)
{
out_exponent = 1;
}
}
else
{
if((1UL << (mfmt + 1)) & mantissa)
{
mantissa >>= 1;
out_exponent++;
}
}
mantissa >>= (mfmt - NumericUtils<T>::mant);
if(out_exponent == 0 && mantissa == 0)
{
return 0;
}
mantissa &= (1UL << NumericUtils<T>::mant) - 1;
return (sign << (NumericUtils<T>::exp + NumericUtils<T>::mant)) |
(out_exponent << NumericUtils<T>::mant) | mantissa;
}
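// Editor's note: the tail of convert_to_type is round-to-nearest-even: the
// midpoint test asks whether the dropped bits are exactly half an ULP of the
// target type, and ties go to the value whose kept LSB is even. Worked e2m1
// example: the representable magnitudes around 1.0f are 1.0f (mantissa 0) and
// 1.5f (mantissa 1), so convert_to_type<f4_t>(1.25f) is a tie and rounds to
// 1.0f, while convert_to_type<f4_t>(1.3f) rounds to 1.5f.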
template <typename T>
inline T convert_to_type_sr(float value, uint32_t seed)
{
if(std::abs(value) > NumericLimits<T>::Max())
{
float max_value = NumericLimits<T>::Max();
cvt t;
// cppcheck-suppress redundantAssignment
t.value_float = max_value;
        uint32_t max_bitwise = t.value_bitwise;
// cppcheck-suppress redundantAssignment
t.value_float = value;
T sign = t.value_bitwise >> (NumericUtils<float>::exp + NumericUtils<float>::mant);
T exp = ((max_bitwise >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask) -
(NumericUtils<float>::bias - NumericUtils<T>::bias);
uint32_t mant_prev = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
mant_prev &= ((1UL << NumericUtils<T>::mant) - 1);
mant_prev--;
mant_prev <<= (NumericUtils<float>::mant - NumericUtils<T>::mant);
uint32_t prev_bit =
((max_bitwise >> NumericUtils<float>::mant) << NumericUtils<float>::mant) | mant_prev;
t.value_bitwise = prev_bit;
float prev_val = t.value_float;
float diff = max_value - prev_val;
float actual_max = max_value + (diff / 2);
if(std::abs(value) < actual_max)
{
double d_max_value = static_cast<double>(max_value);
double d_actual_max = static_cast<double>(actual_max);
double d_value = static_cast<double>(value);
double d_is = std::abs(d_max_value - d_actual_max);
double d_seed = static_cast<double>(seed);
double d_prob = 1.0f - (std::abs(d_value - d_max_value) / d_is); // prob to round down
double thresh = UINT_MAX * d_prob;
        if(!get_data_has_inf<T>() || d_seed <= thresh)
            // round down: clamp to the largest representable magnitude of T
            return sign == 0 ? NumericUtils<T>::data_max_positive_normal_mask
                             : NumericUtils<T>::data_max_negative_normal_mask;
else
{
exp++;
return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) // inf
| (exp << NumericUtils<T>::mant);
}
}
else
{
if(!get_data_has_inf<T>())
return (1 << (NumericUtils<T>::mant + NumericUtils<T>::exp)) - 1;
else
{
exp++;
return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) // inf
| (exp << NumericUtils<T>::mant);
}
}
}
uint32_t f32 = bit_cast<uint32_t>(value);
auto f32_mant = f32 & NumericUtils<float>::mant_mask;
auto head = f32 & NumericUtils<float>::head_mask;
auto f32_exp = (head >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask;
auto sign_bit = head >> (NumericUtils<float>::mant + NumericUtils<float>::exp);
auto sign = sign_bit << (NumericUtils<T>::exp + NumericUtils<T>::mant);
f32_exp = static_cast<int32_t>(f32_exp) - NumericUtils<float>::bias;
int32_t exp = f32_exp;
auto mant = f32_mant;
bool subnorm = false;
if(f32 == 0)
return 0b0;
if(exp >= NumericUtils<T>::unbiased_exp_min)
{
mant = f32_mant;
}
    // when the target type's exponent field is as wide as f32's (8 bits),
    // subnormals already match f32 exactly; only narrower-exponent types need
    // the alignment shift below
else if(exp < NumericUtils<T>::unbiased_exp_min &&
NumericUtils<T>::exp < NumericUtils<float>::exp)
{
subnorm = true;
auto diff = static_cast<uint32_t>(NumericUtils<T>::unbiased_exp_min - exp);
if(diff >= 32)
{
mant = 0;
f32_mant = 0;
}
else
{
f32_mant |= static_cast<uint32_t>(1) << NumericUtils<float>::mant;
f32_mant >>= diff;
}
exp = 0;
mant = f32_mant;
}
uint32_t sr_shift = NumericUtils<T>::sr_shift;
// For stochastic-rounding we add the aligned random value to the
// mantissa and then truncate (RTZ).
mant += seed >> sr_shift;
// Increment exponent when mantissa overflows due to rounding
if(mant >= static_cast<uint32_t>(1) << NumericUtils<float>::mant)
++exp;
mant >>= (NumericUtils<float>::mant - NumericUtils<T>::mant);
mant &= ((1 << NumericUtils<T>::mant) - 1);
auto biased_exp = static_cast<uint32_t>(exp);
if(!subnorm)
biased_exp = static_cast<uint32_t>(exp + NumericUtils<T>::bias);
biased_exp &= ((1 << NumericUtils<T>::exp) - 1);
auto val = sign | biased_exp << NumericUtils<T>::mant | mant;
return val;
}
} // namespace ck::utils
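// Editor's usage sketch (illustration only): stochastic rounding is unbiased,
// so averaging many decoded conversions of one input over different seeds
// approaches the input value.
inline float example_sr_average(float x, int n)
{
    using namespace ck;
    using namespace ck::utils;
    double acc = 0.0;
    for(int i = 0; i < n; ++i)
    {
        // the seed drives the rounding decision inside convert_to_type_sr
        uint32_t seed = static_cast<uint32_t>(i) * 2654435761u;
        f4_t q = convert_to_type_sr<f4_t>(x, seed);
        acc += to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), q);
    }
    return static_cast<float>(acc / n);
}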
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ck/utility/ignore.hpp>
#include "ck/ck.hpp"
#ifdef CK_CODE_GEN_RTC
using uint8_t = unsigned char;
using uint16_t = unsigned short;
using uint32_t = unsigned int;
#endif
namespace ck {
// Pseudo random number generator
// version for fp32
template <typename T, uint32_t seed_t, ck::enable_if_t<ck::is_same<float, T>{}, bool> = false>
template <typename T, uint32_t seed_t, ck::enable_if_t<std::is_same<float, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{
uint32_t x = *(reinterpret_cast<uint32_t*>(&val));
......
@@ -23,7 +30,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
}
// version for fp16
template <typename T, uint32_t seed_t, ck::enable_if_t<ck::is_same<half_t, T>{}, bool> = false>
template <typename T, uint32_t seed_t, ck::enable_if_t<std::is_same<_Float16, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{
uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
......
@@ -40,18 +47,12 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
// return 0 if data is not fp16 or fp32
template <typename T,
uint32_t seed_t,
ck::enable_if_t<!(ck::is_same<float, T>{} || ck::is_same<half_t, T>{}), bool> = false>
ck::enable_if_t<!(std::is_same<float, T>{} || std::is_same<_Float16, T>{}), bool> = false>
__host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t)
{
#ifdef __HIPCC_RTC__
static_cast<void>(id);
static_cast<void>(val);
static_cast<void>(seed);
#else
std::ignore = id;
std::ignore = val;
std::ignore = seed;
#endif
ck::ignore = id;
ck::ignore = val;
ck::ignore = seed;
return 0;
}
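// Editor's usage sketch (hypothetical example_* name): derive a reproducible
// per-element random word for stochastic rounding from the element index and
// the value's bit pattern; identical (id, val, seed) triples give identical
// words, which keeps SR runs reproducible.
__host__ __device__ inline uint32_t example_sr_rng(index_t id, float value)
{
    constexpr uint32_t seed = 1254739; // same constant the fp8 converters use
    return prand_generator<float, seed>(id, value);
}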
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/type_convert.hpp"
#include "ck/utility/mxf8_utils.hpp"
#ifdef CK_USE_NATIVE_MX_SUPPORT
#define CK_USE_NATIVE_MX_SUPPORT 1
#else
#define CK_USE_NATIVE_MX_SUPPORT 0
#endif
namespace ck {
// Declare a template function for scaled conversion
template <typename Y, typename X>
#if CK_USE_OCP_FP8
__host__ __device__ constexpr Y scaled_type_convert(e8m0_bexp_t scale, X x);
#else
__host__ constexpr Y scaled_type_convert(e8m0_bexp_t scale, X x);
#endif
// convert f8_ocp_t to fp32
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float scaled_type_convert<float, f8_ocp_t>(e8m0_bexp_t scale, f8_ocp_t x)
#else
inline __host__ float scaled_type_convert<float, f8_ocp_t>(e8m0_bexp_t scale, f8_ocp_t x)
#endif
{
#if CK_MX_FP8_CVT_FAST_PATH
return fp8_impl::cast_to_f32_from_f8_scaled<f8_ocp_t::default_interpret>(
type_convert<float>(scale), x.data);
#else
return type_convert<float>(scale) * type_convert<float>(x);
#endif
}
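// Editor's note: e8m0_bexp_t stores only a biased exponent, so the decoded
// scale is 2^(stored - 127): a stored byte of 128 scales by 2.0f, 126 by
// 0.5f, and 127 (Binary_1) by 1.0f. A sketch of the scalar path above
// (hypothetical values):
//   e8m0_bexp_t scale{128};                          // decodes to 2.0f
//   float y = scaled_type_convert<float>(scale, x8); // 2.0f * decode(x8)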
// convert bf8_ocp_t to fp32
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float scaled_type_convert<float, bf8_ocp_t>(e8m0_bexp_t scale,
bf8_ocp_t x)
#else
inline __host__ float scaled_type_convert<float, bf8_ocp_t>(e8m0_bexp_t scale, bf8_ocp_t x)
#endif
{
#if CK_MX_FP8_CVT_FAST_PATH
return fp8_impl::cast_to_f32_from_f8_scaled<bf8_ocp_t::default_interpret>(
type_convert<float>(scale), x.data);
#else
return type_convert<float>(scale) * type_convert<float>(x);
#endif
}
// convert 2 x f8_ocp_t to 2 x fp32
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float2_t scaled_type_convert<float2_t, f8x2_ocp_t>(e8m0_bexp_t scale,
f8x2_ocp_t x)
#else
inline __host__ float2_t scaled_type_convert<float2_t, f8x2_ocp_t>(e8m0_bexp_t scale, f8x2_ocp_t x)
#endif
{
#if CK_MX_FP8_CVT_FAST_PATH
return fp8_impl::cast_to_f32x2_from_f8x2_scaled<f8_ocp_t::default_interpret>(
type_convert<float>(scale), x.AsType<fp8_impl::fp8x2_storage_t>()[Number<0>{}]);
#else
return float2_t{scaled_type_convert<float>(scale, x.AsType<f8_ocp_t>()[Number<0>{}]),
scaled_type_convert<float>(scale, x.AsType<f8_ocp_t>()[Number<1>{}])};
#endif
}
// convert 2 x bf8_ocp_t to 2 x fp32
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float2_t scaled_type_convert<float2_t, bf8x2_ocp_t>(e8m0_bexp_t scale,
bf8x2_ocp_t x)
#else
inline __host__ float2_t scaled_type_convert<float2_t, bf8x2_ocp_t>(e8m0_bexp_t scale,
bf8x2_ocp_t x)
#endif
{
#if CK_MX_FP8_CVT_FAST_PATH
return fp8_impl::cast_to_f32x2_from_f8x2_scaled<bf8_ocp_t::default_interpret>(
type_convert<float>(scale), x.AsType<fp8_impl::fp8x2_storage_t>()[Number<0>{}]);
#else
return float2_t{scaled_type_convert<float>(scale, x.AsType<bf8_ocp_t>()[Number<0>{}]),
scaled_type_convert<float>(scale, x.AsType<bf8_ocp_t>()[Number<1>{}])};
#endif
}
// convert 16 x f8_ocp_t to 16 x fp32
// @note The host version gives a compilation error and requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float16_t scaled_type_convert<float16_t, f8x16_ocp_t>(e8m0_bexp_t scale,
f8x16_ocp_t x)
#else
inline __host__ float16_t scaled_type_convert<float16_t, f8x16_ocp_t>(e8m0_bexp_t scale,
f8x16_ocp_t x)
#endif
{
union
{
f8x16_ocp_t f8_1x16;
f8x2_ocp_t f8_2x8[8];
} in{x};
union
{
float16_t float_1x16;
float2_t float_2x8[8];
} out{};
ck::static_for<0, 8, 1>{}([&](auto i) {
out.float_2x8[i] = scaled_type_convert<float2_t, f8x2_ocp_t>(scale, in.f8_2x8[i]);
});
return out.float_1x16;
}
// convert 16 x bf8_ocp_t to 16 x fp32
// @note The host version gives a compilation error and requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float16_t scaled_type_convert<float16_t, bf8x16_ocp_t>(e8m0_bexp_t scale,
bf8x16_ocp_t x)
#else
inline __host__ float16_t scaled_type_convert<float16_t, bf8x16_ocp_t>(e8m0_bexp_t scale,
bf8x16_ocp_t x)
#endif
{
union
{
bf8x16_ocp_t bf8_1x16;
bf8x2_ocp_t bf8_2x8[8];
} in{x};
union
{
float16_t float_1x16;
float2_t float_2x8[8];
} out{};
ck::static_for<0, 8, 1>{}([&](auto i) {
out.float_2x8[i] = scaled_type_convert<float2_t, bf8x2_ocp_t>(scale, in.bf8_2x8[i]);
});
return out.float_1x16;
}
// convert 32 x f8_ocp_t to 32 x fp32
// @note The host version gives a compilation error and requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float32_t scaled_type_convert<float32_t, f8x32_ocp_t>(e8m0_bexp_t scale,
f8x32_ocp_t x)
#else
inline __host__ float32_t scaled_type_convert<float32_t, f8x32_ocp_t>(e8m0_bexp_t scale,
f8x32_ocp_t x)
#endif
{
union
{
f8x32_ocp_t f8_1x32;
f8x16_ocp_t f8_16x2[2];
} in{x};
union
{
float32_t float_1x32;
float16_t float_16x2[2];
} out{};
ck::static_for<0, 2, 1>{}([&](auto i) {
out.float_16x2[i] = scaled_type_convert<float16_t, f8x16_ocp_t>(scale, in.f8_16x2[i]);
});
return out.float_1x32;
}
// convert 32 x bf8_ocp_t to 32 x fp32
// @note The host version gives a compilation error and requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float32_t scaled_type_convert<float32_t, bf8x32_ocp_t>(e8m0_bexp_t scale,
bf8x32_ocp_t x)
#else
inline __host__ float32_t scaled_type_convert<float32_t, bf8x32_ocp_t>(e8m0_bexp_t scale,
bf8x32_ocp_t x)
#endif
{
union
{
bf8x32_ocp_t bf8_1x32;
bf8x16_ocp_t bf8_16x2[2];
} in{x};
union
{
float32_t float_1x32;
float16_t float_16x2[2];
} out{};
ck::static_for<0, 2, 1>{}([&](auto i) {
out.float_16x2[i] = scaled_type_convert<float16_t, bf8x16_ocp_t>(scale, in.bf8_16x2[i]);
});
return out.float_1x32;
}
// convert fp32 to fp8
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ f8_ocp_t scaled_type_convert<f8_ocp_t, float>(e8m0_bexp_t scale, float x)
#else
inline __host__ f8_ocp_t scaled_type_convert<f8_ocp_t, float>(e8m0_bexp_t scale, float x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<f8_ocp_t>(x, type_convert<float>(scale));
#else
return mxf8_convert_rne<f8_ocp_t>(x, type_convert<float>(scale));
#endif
}
// convert fp32 to bf8
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8_ocp_t scaled_type_convert<bf8_ocp_t, float>(e8m0_bexp_t scale,
float x)
#else
inline __host__ bf8_ocp_t scaled_type_convert<bf8_ocp_t, float>(e8m0_bexp_t scale, float x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<bf8_ocp_t>(x, type_convert<float>(scale));
#else
return mxf8_convert_rne<bf8_ocp_t>(x, type_convert<float>(scale));
#endif
}
// convert fp32x2 to fp8x2
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ f8x2_ocp_t scaled_type_convert<f8x2_ocp_t, float2_t>(e8m0_bexp_t scale,
float2_t x)
#else
inline __host__ f8x2_ocp_t scaled_type_convert<f8x2_ocp_t, float2_t>(e8m0_bexp_t scale, float2_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<f8x2_ocp_t>(x, type_convert<float>(scale));
#else
return mxf8_convert_rne<f8x2_ocp_t>(x, type_convert<float>(scale));
#endif
}
// convert fp32x2 to bf8x2
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8x2_ocp_t scaled_type_convert<bf8x2_ocp_t, float2_t>(e8m0_bexp_t scale,
float2_t x)
#else
inline __host__ bf8x2_ocp_t scaled_type_convert<bf8x2_ocp_t, float2_t>(e8m0_bexp_t scale,
float2_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<bf8x2_ocp_t>(x, type_convert<float>(scale));
#else
return mxf8_convert_rne<bf8x2_ocp_t>(x, type_convert<float>(scale));
#endif
}
// convert fp32x16 to fp8x16
// @note The host version gives a compilation error and requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ f8x16_ocp_t
scaled_type_convert<f8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
#else
inline __host__ f8x16_ocp_t scaled_type_convert<f8x16_ocp_t, float16_t>(e8m0_bexp_t scale,
float16_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<f8x16_ocp_t>(x, type_convert<float>(scale));
#else
return mxf8_convert_rne<f8x16_ocp_t>(x, type_convert<float>(scale));
#endif
}
// convert fp32x16 to bf8x16
// @note The host version gives a compilation error and requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8x16_ocp_t
scaled_type_convert<bf8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
#else
inline __host__ bf8x16_ocp_t scaled_type_convert<bf8x16_ocp_t, float16_t>(e8m0_bexp_t scale,
float16_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<bf8x16_ocp_t>(x, type_convert<float>(scale));
#else
return mxf8_convert_rne<bf8x16_ocp_t>(x, type_convert<float>(scale));
#endif
}
// convert fp32x32 to fp8x32
// @note The host version gives a compilation error and requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ f8x32_ocp_t
scaled_type_convert<f8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
#else
inline __host__ f8x32_ocp_t scaled_type_convert<f8x32_ocp_t, float32_t>(e8m0_bexp_t scale,
float32_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<f8x32_ocp_t>(x, type_convert<float>(scale));
#else
return mxf8_convert_rne<f8x32_ocp_t>(x, type_convert<float>(scale));
#endif
}
// convert fp32x32 to bf8x32
// @note The host version gives a compilation error and requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8x32_ocp_t
scaled_type_convert<bf8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
#else
inline __host__ bf8x32_ocp_t scaled_type_convert<bf8x32_ocp_t, float32_t>(e8m0_bexp_t scale,
float32_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<bf8x32_ocp_t>(x, type_convert<float>(scale));
#else
return mxf8_convert_rne<bf8x32_ocp_t>(x, type_convert<float>(scale));
#endif
}
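// Editor's usage sketch (illustration only; assumes an OCP-fp8-capable
// build): quantize a packed float2 under an e8m0 scale, then dequantize.
// CK_USE_SR_F8_CONVERSION picks stochastic rounding at compile time,
// otherwise rounding is RNE.
inline __host__ void example_scaled_fp8_round_trip(e8m0_bexp_t scale, float2_t v)
{
    f8x2_ocp_t q  = scaled_type_convert<f8x2_ocp_t, float2_t>(scale, v);
    float2_t back = scaled_type_convert<float2_t, f8x2_ocp_t>(scale, q);
    (void)back;
}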
// activate for architectures with native MX support
#if CK_USE_NATIVE_MX_SUPPORT
// convert fp4 to fp32
template <>
inline __host__ __device__ float scaled_type_convert<float, f4_t>(e8m0_bexp_t scale, f4_t x)
{
#if defined(__gfx950__)
union
{
float float_array[2];
float2_t float2_array;
} float_values{};
float_values.float2_array =
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, type_convert<float>(scale), 0);
return float_values.float_array[0];
#else
return utils::to_float<f4_t>(scale, x);
#endif
}
// convert vector of 2 fp4 to vector of 2 fp32
template <>
inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_bexp_t scale,
f4x2_t x)
{
#if defined(__gfx950__)
union
{
uint32_t bitwise;
f4x2_t f4x2_array[4];
} value{};
value.f4x2_array[0] = x;
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, type_convert<float>(scale), 0);
#else
float2_t ret{utils::to_float<f4_t>(
scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})),
utils::to_float<f4_t>(
scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}))};
return ret;
#endif
}
// convert vector of 32 fp4 to vector of 32 fp32
template <>
inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m0_bexp_t scale,
f4x32_t x)
{
#if defined(__gfx950__)
union
{
f4x32_t f4x32_array;
f4x2_t fp4x2[16];
} value{x};
union
{
uint32_t bitwise;
f4x2_t f4x2_array[4];
} bitwise_value{};
    union
    {
        float32_t float32_array;
        float float_array[32];
    } ret{};
    // convert two packed fp4 values per iteration with the native builtin
    ck::static_for<0, 16, 1>{}([&](auto i) {
        bitwise_value.f4x2_array[0] = value.fp4x2[i];
        const float2_t op           = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
            bitwise_value.bitwise, type_convert<float>(scale), 0);
        ret.float_array[2 * i]     = op[0];
        ret.float_array[2 * i + 1] = op[1];
    });
    return ret.float32_array;
#else
union
{
float32_t float32_array;
float float_array[32];
} float_values{};
union
{
__uint128_t bitwise;
f4x2_t f4x2_array[16];
f4x32_t f4x32_array;
} f4_values{bit_cast<__uint128_t>(x)};
    // unpack the two fp4 values of each f4x2 element into floats 0..31
    ck::static_for<0, 16, 1>{}([&](auto i) {
        float_values.float_array[2 * i] = utils::to_float<f4_t>(
            scale,
            f4_values.f4x2_array[i].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(
                Number<0>{}));
        float_values.float_array[2 * i + 1] = utils::to_float<f4_t>(
            scale,
            f4_values.f4x2_array[i].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(
                Number<1>{}));
    });
return float_values.float32_array;
#endif
}
// convert fp32 to fp4
template <>
inline __host__ __device__ f4_t scaled_type_convert<f4_t, float>(e8m0_bexp_t scale, float x)
{
#if CK_USE_SR_F4_CONVERSION
return f4_convert_sr(x, type_convert<float>(scale));
#else
return f4_convert_rne(x, type_convert<float>(scale));
#endif
}
// convert vector of 2 fp32 to vector of 2 fp4
template <>
inline __host__ __device__ f4x2_t scaled_type_convert<f4x2_t, float2_t>(e8m0_bexp_t scale,
float2_t x)
{
#if CK_USE_SR_F4_CONVERSION
return f4_convert_sr(x, type_convert<float>(scale));
#else
return f4_convert_rne(x, type_convert<float>(scale));
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4
template <>
inline __host__ __device__ f4x32_t scaled_type_convert<f4x32_t, float32_t>(e8m0_bexp_t scale,
float32_t x)
{
#if CK_USE_SR_F4_CONVERSION
return f4_convert_sr(x, type_convert<float>(scale));
#else
return f4_convert_rne(x, type_convert<float>(scale));
#endif
}
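// Illustrative usage sketch (not part of the library): an e8m0_bexp_t scale is a
// power-of-two factor, so the scaled conversions quantize x / scale and decode
// scale * value. Assuming a scale that decodes to 1.0f:
//
//   e8m0_bexp_t scale = NumericLimits<e8m0_bexp_t>::Binary_1(); // decodes to 1.0f
//   f4_t q  = scaled_type_convert<f4_t>(scale, 1.7f);  // quantize fp32 -> fp4
//   float r = scaled_type_convert<float>(scale, q);    // dequantize, ~1.5f under RNE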
/**
* @brief Converts a 6-bit floating-point value (f6_t) to a 32-bit float,
* applying the specified scaling factor.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
* @param x The f6_t value to be converted.
* @return The converted 32-bit float representation of the input.
*/
template <>
inline __host__ __device__ float scaled_type_convert<float, f6_t>(e8m0_bexp_t scale, f6_t x)
{
#if defined(__gfx950__)
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} in{x};
union
{
float32_t float_vector;
float float_array[32];
} out{};
out.float_vector =
__builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(in.f6_vector, type_convert<float>(scale));
return out.float_array[0];
#else
return utils::to_float<f6_t>(scale, x);
#endif
}
/**
* @brief Converts a vector of 32 6-bit floating-point values (f6x32_t) to a vector of 32 floats,
* applying the specified scaling factor.
*
* @param scale The exponent scale factor (e8m0_bexp_t).
* @param x The f6x32_t vector to be converted.
* @return The converted float vector representation of the input.
*/
template <>
inline __host__ __device__ float32_t scaled_type_convert<float32_t, f6x32_t>(e8m0_bexp_t scale,
f6x32_t x)
{
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(x, type_convert<float>(scale));
#else
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} in{x};
union
{
float32_t float_vector;
float float_array[32];
} out{};
ck::static_for<0, 32, 1>{}(
[&](auto i) { out.float_array[i] = utils::to_float<f6_t>(scale, in.f6_array[i]); });
return out.float_vector;
#endif
}
/**
* @brief Converts a 6-bit floating-point value (bf6_t) to a 32-bit float,
* applying the specified scaling factor.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
* @param x The bf6_t value to be converted.
* @return The converted 32-bit float representation of the input.
*/
template <>
inline __host__ __device__ float scaled_type_convert<float, bf6_t>(e8m0_bexp_t scale, bf6_t x)
{
#if defined(__gfx950__)
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} in{x};
union
{
float32_t float_vector;
float float_array[32];
} out{};
out.float_vector =
__builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(in.bf6_vector, type_convert<float>(scale));
return out.float_array[0];
#else
return utils::to_float<bf6_t>(scale, x);
#endif
}
/**
* @brief Converts a vector of 6-bit floating-point values (bf6x32_t) to a vector of 32 floats,
* applying the specified scaling factor.
*
* @param scale The exponent scale factor (e8m0_bexp_t).
* @param x The bf6x32_t vector to be converted.
 * @return The converted vector of 32 floats representing the input.
*/
template <>
inline __host__ __device__ float32_t scaled_type_convert<float32_t, bf6x32_t>(e8m0_bexp_t scale,
bf6x32_t x)
{
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(x, type_convert<float>(scale));
#else
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} in{x};
union
{
float32_t float_vector;
float float_array[32];
} out{};
ck::static_for<0, 32, 1>{}(
[&](auto i) { out.float_array[i] = utils::to_float<bf6_t>(scale, in.bf6_array[i]); });
return out.float_vector;
#endif
}
/**
* @brief Converts a 32-bit float to a 6-bit floating-point value (f6_t), applying the specified
* scale.
*
* Depending on whether CK_USE_SR_F6_CONVERSION is defined, it uses either stochastic rounding
* (f6_convert_sr) or round-to-nearest-even (f6_convert_rne).
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
* @param x The float value to convert.
* @return The converted 6-bit floating-point value (f6_t).
*/
template <>
inline __host__ __device__ f6_t scaled_type_convert<f6_t, float>(e8m0_bexp_t scale, float x)
{
#if CK_USE_SR_F6_CONVERSION
return f6_convert_sr(x, type_convert<float>(scale));
#else
return f6_convert_rne(x, type_convert<float>(scale));
#endif
}
/**
* @brief Converts a vector of 32 floats to a vector of 32 6-bit floating-point values (f6x32_t),
* applying the specified scale.
*
* Depending on whether CK_USE_SR_F6_CONVERSION is defined, it uses either stochastic rounding
* (f6_convert_sr) or round-to-nearest-even (f6_convert_rne).
*
* @param scale The exponent scale factor (e8m0_bexp_t).
* @param x The float vector to convert.
* @return The converted vector of 6-bit floating-point values (f6x32_t).
*/
template <>
inline __host__ __device__ f6x32_t scaled_type_convert<f6x32_t, float32_t>(e8m0_bexp_t scale,
float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
return f6_convert_sr(x, type_convert<float>(scale));
#else
return f6_convert_rne(x, type_convert<float>(scale));
#endif
}
/**
* @brief Converts a 32-bit float to a 6-bit floating-point value (bf6_t), applying the specified
* scale.
*
* Depending on whether CK_USE_SR_F6_CONVERSION is defined, it uses either stochastic rounding
* (bf6_convert_sr) or round-to-nearest-even (bf6_convert_rne).
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
* @param x The float value to convert.
* @return The converted 6-bit floating-point value (bf6_t).
*/
template <>
inline __host__ __device__ bf6_t scaled_type_convert<bf6_t, float>(e8m0_bexp_t scale, float x)
{
#if CK_USE_SR_F6_CONVERSION
return bf6_convert_sr(x, type_convert<float>(scale));
#else
return bf6_convert_rne(x, type_convert<float>(scale));
#endif
}
/**
* @brief Converts a vector of 32 floats to a vector of 32 6-bit floating-point values (bf6x32_t),
* applying the specified scale.
*
* Depending on whether CK_USE_SR_F6_CONVERSION is defined, it uses either stochastic rounding
* (bf6_convert_sr) or round-to-nearest-even (bf6_convert_rne).
*
* @param scale The exponent scale factor (e8m0_bexp_t).
* @param x The float vector to convert.
* @return The converted 6-bit floating-point vector (bf6x32_t).
*/
template <>
inline __host__ __device__ bf6x32_t scaled_type_convert<bf6x32_t, float32_t>(e8m0_bexp_t scale,
float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
return bf6_convert_sr(x, type_convert<float>(scale));
#else
return bf6_convert_rne(x, type_convert<float>(scale));
#endif
}
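// Note (illustrative): with stochastic rounding the quantizer is unbiased, i.e.
// the expected value of f6_convert_sr(x, scale) equals x / scale for inputs
// inside the representable range, while round-to-nearest-even gives a
// deterministic error of up to half a quantization step. The
// CK_USE_SR_F*_CONVERSION macros thus trade bit-exact reproducibility (RNE) for
// lower accumulated bias (SR).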
#endif // #if CK_USE_NATIVE_MX_SUPPORT
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifdef _HIPCC_RTC_
#define CK_CODE_GEN_RTC
#endif
#ifndef __HIPCC_RTC__
#ifndef CK_CODE_GEN_RTC
#include <ostream>
#endif
#endif
#include "ck/utility/integral_constant.hpp"
#include "ck/utility/type.hpp"
......@@ -903,6 +909,7 @@ using uniform_sequence_gen_t = typename uniform_sequence_gen<NSize, I>::type;
} // namespace ck
#ifndef __HIPCC_RTC__
#ifndef CK_CODE_GEN_RTC
template <ck::index_t... Is>
std::ostream& operator<<(std::ostream& os, const ck::Sequence<Is...>)
{
......@@ -914,3 +921,4 @@ std::ostream& operator<<(std::ostream& os, const ck::Sequence<Is...>)
return os;
}
#endif
#endif
......@@ -116,7 +116,8 @@ struct StaticBufferTupleOfVector
// i is offset of S, not X. i should be aligned to X
template <typename X,
index_t I,
typename enable_if<has_same_scalar_type<S, X>::value, bool>::type = false>
typename enable_if<has_same_scalar_type<S, X>::value || !is_native_type<S>(),
bool>::type = false>
__host__ __device__ constexpr auto GetAsType(Number<I> i) const
{
constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};
......@@ -134,7 +135,8 @@ struct StaticBufferTupleOfVector
// i is offset of S, not X. i should be aligned to X
template <typename X,
index_t I,
typename enable_if<has_same_scalar_type<S, X>::value, bool>::type = false>
typename enable_if<has_same_scalar_type<S, X>::value || !is_native_type<S>(),
bool>::type = false>
__host__ __device__ constexpr void SetAsType(Number<I> i, X x)
{
constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
#define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
......@@ -35,10 +35,9 @@ __host__ __device__ constexpr auto to_multi_index(const T& x)
// is the alias of the latter. This is because compiler cannot infer the NSize if
// using MultiIndex<NSize>
// TODO: how to fix this?
template <
typename... Ys,
template <typename... Ys,
typename X,
enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> = false>
enable_if_t<!ck::is_integral<X>::value && !ck::is_floating_point<X>::value, bool> = false>
__host__ __device__ constexpr auto operator+=(Tuple<Ys...>& y, const X& x)
{
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
......@@ -47,10 +46,9 @@ __host__ __device__ constexpr auto operator+=(Tuple<Ys...>& y, const X& x)
return y;
}
template <
typename... Ys,
template <typename... Ys,
typename X,
enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> = false>
enable_if_t<!ck::is_integral<X>::value && !ck::is_floating_point<X>::value, bool> = false>
__host__ __device__ constexpr auto operator-=(Tuple<Ys...>& y, const X& x)
{
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
......@@ -59,10 +57,9 @@ __host__ __device__ constexpr auto operator-=(Tuple<Ys...>& y, const X& x)
return y;
}
template <
typename... Xs,
template <typename... Xs,
typename Y,
enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> = false>
enable_if_t<!ck::is_integral<Y>::value && !ck::is_floating_point<Y>::value, bool> = false>
__host__ __device__ constexpr auto operator+(const Tuple<Xs...>& x, const Y& y)
{
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
......@@ -73,10 +70,9 @@ __host__ __device__ constexpr auto operator+(const Tuple<Xs...>& x, const Y& y)
return r;
}
template <
typename... Xs,
template <typename... Xs,
typename Y,
enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> = false>
enable_if_t<!ck::is_integral<Y>::value && !ck::is_floating_point<Y>::value, bool> = false>
__host__ __device__ constexpr auto operator-(const Tuple<Xs...>& x, const Y& y)
{
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
......@@ -87,10 +83,9 @@ __host__ __device__ constexpr auto operator-(const Tuple<Xs...>& x, const Y& y)
return r;
}
template <
typename... Xs,
template <typename... Xs,
typename Y,
enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> = false>
enable_if_t<!ck::is_integral<Y>::value && !ck::is_floating_point<Y>::value, bool> = false>
__host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y)
{
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
......@@ -104,7 +99,7 @@ __host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y)
// MultiIndex = scalar * MultiIndex
template <typename... Xs,
typename Y,
enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
enable_if_t<ck::is_integral<Y>::value || ck::is_floating_point<Y>::value, bool> = false>
__host__ __device__ constexpr auto operator*(Y a, const Tuple<Xs...>& x)
{
constexpr index_t NSize = sizeof...(Xs);
......@@ -117,7 +112,7 @@ __host__ __device__ constexpr auto operator*(Y a, const Tuple<Xs...>& x)
// MultiIndex = MultiIndex * scalar
template <typename... Xs,
typename Y,
enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
enable_if_t<ck::is_integral<Y>::value || ck::is_floating_point<Y>::value, bool> = false>
__host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, Y a)
{
return a * x;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifdef _HIPCC_RTC_
#define CK_CODE_GEN_RTC
#endif
#include "functional4.hpp"
#include "tuple.hpp"
#ifndef CK_CODE_GEN_RTC
#include "is_detected.hpp"
#endif
namespace ck {
......@@ -158,13 +164,18 @@ __host__ __device__ constexpr auto TupleReduce(F&& f, const Tuple<Ts...>& tuple)
}
#ifndef __HIPCC_RTC__
#ifndef CK_CODE_GEN_RTC
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
using is_tuple = decltype(ck::declval<T&>().IsTuple());
#endif
#endif
template <typename... Ts>
__host__ __device__ constexpr auto IsNestedTuple(const Tuple<Ts...>&)
{
#ifndef CK_CODE_GEN_RTC
return (is_detected<is_tuple, Ts>::value || ...);
#endif
}
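// Illustrative note: is_detected<is_tuple, T> reports whether T exposes an
// IsTuple() member, e.g. is_detected<is_tuple, Tuple<int, float>>::value is true
// while is_detected<is_tuple, int>::value is false; the fold expression above
// ORs this over all element types to flag tuples that contain tuples.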
#endif
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifdef _HIPCC_RTC_
#define CK_CODE_GEN_RTC
#endif
#include "ck/ck.hpp"
#include "ck/utility/integral_constant.hpp"
#include "ck/utility/enable_if.hpp"
#include "ck/utility/integral_constant.hpp"
namespace ck {
#ifdef __HIPCC_RTC__
template <bool B>
using bool_constant = integral_constant<bool, B>;
using true_type = bool_constant<true>;
using false_type = bool_constant<false>;
#ifdef CK_CODE_GEN_RTC
// NOLINTNEXTLINE
#define CK_BUILTIN_TYPE_TRAIT1(name) \
template <class T> \
......@@ -75,7 +73,6 @@ struct remove_reference<T&&>
{
typedef T type;
};
template <class T>
struct remove_pointer
{
......@@ -107,7 +104,6 @@ constexpr T&& forward(typename remove_reference<T>::type& t_) noexcept
{
return static_cast<T&&>(t_);
}
template <typename T>
constexpr T&& forward(typename remove_reference<T>::type&& t_) noexcept
{
......@@ -115,17 +111,17 @@ constexpr T&& forward(typename remove_reference<T>::type&& t_) noexcept
}
template <class T>
struct is_const : false_type
struct is_const : public integral_constant<bool, false>
{
};
template <class T>
struct is_const<const T> : true_type
struct is_const<const T> : public integral_constant<bool, true>
{
};
template <class T>
inline constexpr bool is_const_v = is_const<T>::value;
template <class T>
template <typename T>
inline constexpr bool is_reference_v = is_reference<T>::value;
template <class T>
......@@ -140,15 +136,13 @@ struct remove_const<const T>
};
template <class T>
using remove_const_t = typename remove_const<T>::type;
template <class T>
inline constexpr bool is_class_v = is_class<T>::value;
template <class T>
inline constexpr bool is_trivially_copyable_v = is_trivially_copyable<T>::value;
template <class...>
using void_t = void;
// template <typename T>
// T&& declval() noexcept;
template <class T, class U = T&&>
U private_declval(int);
......@@ -159,12 +153,12 @@ T private_declval(long);
template <class T>
auto declval() noexcept -> decltype(private_declval<T>(0));
template <class...>
using void_t = void;
#else
#include <utility>
#include <type_traits>
using std::declval;
using std::false_type;
using std::forward;
using std::is_base_of;
using std::is_class;
......@@ -180,9 +174,8 @@ using std::remove_const_t;
using std::remove_cv;
using std::remove_pointer;
using std::remove_reference;
using std::true_type;
using std::void_t;
#endif
#endif
template <typename X, typename Y>
......@@ -195,15 +188,117 @@ struct is_same<X, X> : public integral_constant<bool, true>
{
};
template <typename X>
struct is_floating_point : public integral_constant<bool, false>
{
};
template <>
struct is_floating_point<float> : public integral_constant<bool, true>
{
};
template <>
struct is_floating_point<double> : public integral_constant<bool, true>
{
};
template <>
struct is_floating_point<long double> : public integral_constant<bool, true>
{
};
template <typename X>
struct is_integral : public integral_constant<bool, false>
{
};
template <>
struct is_integral<int> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<unsigned int> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<long> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<unsigned long> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<short> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<unsigned short> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<long long> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<unsigned long long> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<char> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<signed char> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<unsigned char> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<wchar_t> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<char16_t> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<char32_t> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<bool> : public integral_constant<bool, true>
{
};
template <typename X, typename Y>
inline constexpr bool is_same_v = is_same<X, Y>::value;
template <typename X, typename Y>
inline constexpr bool is_base_of_v = is_base_of<X, Y>::value;
template <typename T>
inline constexpr bool is_unsigned_v = is_unsigned<T>::value;
template <typename T>
using remove_reference_t = typename remove_reference<T>::type;
template <typename T>
using remove_cv_t = typename remove_cv<T>::type;
template <typename T>
using remove_cvref_t = remove_cv_t<remove_reference_t<T>>;
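
// Compile-time sanity sketch for the traits above (illustrative only):
//   static_assert(ck::is_integral<int>::value && !ck::is_integral<float>::value, "");
//   static_assert(ck::is_floating_point<double>::value, "");
//   static_assert(ck::is_same_v<ck::remove_cvref_t<const int&>, int>, "");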
......@@ -221,5 +316,4 @@ __host__ __device__ constexpr Y bit_cast(const X& x)
return __builtin_bit_cast(Y, x);
}
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/f8_utils.hpp"
#include "ck/utility/mxf4_utils.hpp"
#include "ck/utility/mxf6_utils.hpp"
#include "ck/utility/random_gen.hpp"
#include "ck/utility/array.hpp"
#include "ck/utility/amd_inline_asm.hpp"
#include "ck/utility/type.hpp"
namespace ck {
// Define the common macro for gfx94x models
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// Define the common macro for MI300 models
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
#define __gfx94__
#endif
namespace {
namespace details {
[[maybe_unused]] __host__ half2_t pk_add_f16(const half2_t& x, const half2_t& y)
{
half2_t vector_res;
vector_res.x = x.x + y.x;
vector_res.y = x.y + y.y;
return vector_res;
}
[[maybe_unused]] __device__ half2_t pk_add_f16(const half2_t& x, const half2_t& y)
{
return amd_assembly_pk_add_f16(x, y);
}
} // namespace details
} // namespace
// Declare a template function for bf16 conversion using RTN
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
// Convert fp32 to bf16 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
{
// Nan check
if(x != x)
{
return uint16_t(0x7FC0);
}
union
{
float fp32;
uint32_t int32;
} u = {x};
    const uint32_t first_bf16_mantissa_bit = ((u.int32 >> 16) & 1);
    constexpr uint32_t rounding_bias = uint32_t((1 << 15) - 1);
    return uint16_t((u.int32 + first_bf16_mantissa_bit + rounding_bias) >> 16);
}
// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
{
float x_fp32 = static_cast<float>(x);
return bf16_convert_rtn<bhalf_t>(x_fp32);
}
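// Worked example of the RTN arithmetic above (illustrative): for x with bit
// pattern 0x3F818000 (1.01171875f) the discarded low 16 bits are exactly 0x8000,
// a tie, and the first bf16 mantissa bit is 1 (odd), so
// 0x3F818000 + 1 + 0x7FFF = 0x3F820000 and the value rounds up to bf16 0x3F82.
// For 0x3F808000 the kept mantissa bit is 0 (even), the sum is 0x3F80FFFF, and
// the value rounds down to bf16 0x3F80, implementing round-half-to-even.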
// Convert X to Y, both X and Y are non-const data types.
template <typename Y,
typename X,
......@@ -51,17 +110,15 @@ inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t
return u.fp32;
}
// convert fp32 to bfp16
// convert fp32 to bfp16, round to nearest even
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
{
union
{
float fp32;
uint32_t int32;
} u = {x};
#if CK_USE_RNE_BF16_CONVERSION
return bf16_convert_rtn<bhalf_t>(x);
#else
return uint16_t(u.int32 >> 16);
#endif
}
// convert bfp16 to fp16 via fp32
......@@ -100,6 +157,18 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return type_convert<bhalf_t>(x_fp32);
}
template <>
inline __host__ __device__ constexpr f8_ocp_t type_convert<f8_ocp_t, int>(int x)
{
return f8_ocp_t{type_convert<f8_ocp_t::data_type>(x)};
}
template <>
inline __host__ __device__ constexpr bf8_ocp_t type_convert<bf8_ocp_t, int>(int x)
{
return bf8_ocp_t{type_convert<bf8_ocp_t::data_type>(x)};
}
// Convert X to Y
template <typename Y, typename X>
__host__ __device__ constexpr Y type_convert_sp(X x)
......@@ -163,10 +232,14 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
// convert fp32 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, float>(float x)
{
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<long_index_t>(&x), x);
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#endif
#if defined(__gfx94__)
union
{
......@@ -189,36 +262,46 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
return utils::
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
rng);
cast_to_f8<float, f8_fnuz_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert fp16 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, half_t>(half_t x)
{
#if defined(__gfx94__)
    // convert to float and use native conversion
return f8_convert_sr<f8_t>(type_convert<float>(x));
return f8_convert_sr<f8_fnuz_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 1254739;
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<long_index_t>(&x), x);
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<size_t>(&x), x);
#endif
return utils::cast_to_f8<half_t,
f8_fnuz_t,
negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}
// convert fp32 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, float>(float x)
{
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<long_index_t>(&x), x);
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#endif
#if defined(__gfx94__)
union
{
......@@ -240,28 +323,36 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
return utils::
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return utils::cast_to_f8<float,
bf8_fnuz_t,
negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}
// convert fp16 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, half_t>(half_t x)
{
#if defined(__gfx94__)
    // convert to float and use native conversion
return f8_convert_sr<bf8_t>(type_convert<float>(x));
return f8_convert_sr<bf8_fnuz_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 1254739;
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<long_index_t>(&x), x);
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<size_t>(&x), x);
#endif
return utils::cast_to_f8<half_t,
bf8_fnuz_t,
negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}
......@@ -271,7 +362,7 @@ __host__ __device__ constexpr Y f8_convert_rne(X x);
// convert fp32 to fp8 with rounding to nearest even
template <>
inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
inline __host__ __device__ f8_fnuz_t f8_convert_rne<f8_fnuz_t, float>(float x)
{
#if defined(__gfx94__)
union
......@@ -296,32 +387,34 @@ inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
rng);
cast_to_f8<float, f8_fnuz_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert fp16 to fp8 with rounding to nearest even
template <>
inline __host__ __device__ f8_t f8_convert_rne<f8_t, half_t>(half_t x)
inline __host__ __device__ f8_fnuz_t f8_convert_rne<f8_fnuz_t, half_t>(half_t x)
{
#if defined(__gfx94__)
    // convert to float and use native conversion
return f8_convert_rne<f8_t>(type_convert<float>(x));
return f8_convert_rne<f8_fnuz_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return utils::cast_to_f8<half_t,
f8_fnuz_t,
negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}
// convert fp32 to bf8 with rounding to nearest even
template <>
inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
inline __host__ __device__ bf8_fnuz_t f8_convert_rne<bf8_fnuz_t, float>(float x)
{
#if defined(__gfx94__)
union
......@@ -345,44 +438,59 @@ inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return utils::cast_to_f8<float,
bf8_fnuz_t,
negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}
// convert fp16 to bf8 with rounding to nearest even
template <>
inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, half_t>(half_t x)
inline __host__ __device__ bf8_fnuz_t f8_convert_rne<bf8_fnuz_t, half_t>(half_t x)
{
#if defined(__gfx94__)
    // convert to float and use native conversion
return f8_convert_rne<bf8_t>(type_convert<float>(x));
return f8_convert_rne<bf8_fnuz_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
return utils::cast_to_f8<half_t,
bf8_fnuz_t,
negative_zero_nan,
clip,
(rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}
// convert fp32 to fp8
template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, float>(float x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<f8_t>(x);
return f8_convert_sr<f8_fnuz_t>(x);
#else
return f8_convert_rne<f8_t>(x);
return f8_convert_rne<f8_fnuz_t>(x);
#endif
}
// convert fp32 to fp8
template <>
inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, float>(float x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<f8_ocp_t>(x);
#else
return f8_convert_rne<f8_ocp_t>(x);
#endif
}
// convert fp8 to fp32
template <>
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
inline __host__ __device__ float type_convert<float, f8_fnuz_t>(f8_fnuz_t x)
{
#if defined(__gfx94__)
float fval;
......@@ -392,30 +500,95 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
return fval;
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<f8_t, float, negative_zero_nan>(x);
return utils::cast_from_f8<f8_fnuz_t, float, negative_zero_nan>(x);
#endif
}
template <>
inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x)
inline __host__ __device__ float2_t type_convert<float2_t, f8x2_fnuz_t>(f8x2_fnuz_t x)
{
#if defined(__gfx94__)
const auto i16val = bit_cast<uint16_t>(x);
return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0);
#else
constexpr bool negative_zero_nan = true;
const auto f8x2_v = vector_type<f8_t, 2>(x);
const auto f8x2_v = vector_type<f8_fnuz_t, 2>(x);
vector_type<float, 2> f32x2_v;
f32x2_v.template AsType<float>()(Number<0>{}) =
utils::cast_from_f8<f8_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_t>()[Number<0>{}]);
utils::cast_from_f8<f8_fnuz_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_fnuz_t>()[Number<0>{}]);
f32x2_v.template AsType<float>()(Number<1>{}) =
utils::cast_from_f8<f8_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_t>()[Number<1>{}]);
utils::cast_from_f8<f8_fnuz_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_fnuz_t>()[Number<1>{}]);
return f32x2_v.template AsType<float2_t>()[Number<0>{}];
#endif
}
template <>
inline __host__ __device__ float2_t type_convert<float2_t, f8x2_ocp_t>(f8x2_ocp_t x)
{
#if CK_OCP_FP8_CVT_FAST_PATH
return fp8_impl::cast_to_f32x2_from_f8x2<f8_ocp_t::default_interpret>(
x.AsType<fp8_impl::fp8x2_storage_t>()[Number<0>{}]);
#else
return float2_t{fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(
x.AsType<fp8_storage_t>()[Number<0>{}]),
fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(
x.AsType<fp8_storage_t>()[Number<1>{}])};
#endif
}
template <>
inline __host__ __device__ float2_t type_convert<float2_t, pk_i4_t>(pk_i4_t x)
{
uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;
#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
float2_t res = {x_h, x_l};
#else
float2_t res = {x_l, x_h};
#endif
return res;
}
template <>
inline __host__ __device__ half2_t type_convert<half2_t, pk_i4_t>(pk_i4_t x)
{
uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
#else
uint32_t i4s = ((x_u8 & 0xf0) << 12) | (x_u8 & 0xf);
#endif
    const int EX  = 0x64006400; // packed fp16 1024.0: places each nibble in the mantissa
    const int SUB = 0xE408E408; // packed fp16 -1032.0: removes the 1024 bias and the +8 offset
int lo = i4s | EX;
return details::pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
}
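// How the bit trick above works (illustrative): each 4-bit nibble n written into
// the low mantissa bits of an fp16 with exponent pattern 0x6400 decodes to
// 1024.0 + n, so the packed add of -1032.0 yields n - 8, exactly the signed
// value the int4 nibble encodes (e.g. nibble 0xF gives 1039.0 - 1032.0 = 7.0).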
template <>
inline __host__ __device__ bhalf2_t type_convert<bhalf2_t, pk_i4_t>(pk_i4_t x)
{
uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;
#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
bhalf2_t res = {type_convert<bhalf_t>(x_h), type_convert<bhalf_t>(x_l)};
#else
bhalf2_t res = {type_convert<bhalf_t>(x_l), type_convert<bhalf_t>(x_h)};
#endif
return res;
}
template <>
inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
{
......@@ -428,42 +601,64 @@ inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
// convert fp16 to fp8
template <>
inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, half_t>(half_t x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<f8_fnuz_t>(x);
#else
return f8_convert_rne<f8_fnuz_t>(x);
#endif
}
// convert fp16 to fp8
template <>
inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, half_t>(half_t x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<f8_t>(x);
return f8_convert_sr<f8_ocp_t>(x);
#else
return f8_convert_rne<f8_t>(x);
return f8_convert_rne<f8_ocp_t>(x);
#endif
}
// convert fp8 to fp16
template <>
inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
inline __host__ __device__ half_t type_convert<half_t, f8_fnuz_t>(f8_fnuz_t x)
{
#if defined(__gfx94__)
// use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
return utils::cast_from_f8<f8_fnuz_t, half_t, negative_zero_nan>(x);
#endif
}
// convert fp32 to bf8
template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, float>(float x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<bf8_t>(x);
return f8_convert_sr<bf8_fnuz_t>(x);
#else
return f8_convert_rne<bf8_t>(x);
return f8_convert_rne<bf8_fnuz_t>(x);
#endif
}
// convert fp32 to bf8
template <>
inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, float>(float x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<bf8_ocp_t>(x);
#else
return f8_convert_rne<bf8_ocp_t>(x);
#endif
}
// convert bf8 to fp32
template <>
inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
inline __host__ __device__ float type_convert<float, bf8_fnuz_t>(bf8_fnuz_t x)
{
#if defined(__gfx94__)
float fval;
......@@ -473,31 +668,42 @@ inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
return fval;
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<bf8_t, float, negative_zero_nan>(x);
return utils::cast_from_f8<bf8_fnuz_t, float, negative_zero_nan>(x);
#endif
}
// convert fp16 to bf8
template <>
inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, half_t>(half_t x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<bf8_fnuz_t>(x);
#else
return f8_convert_rne<bf8_fnuz_t>(x);
#endif
}
// convert fp16 to bf8
template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, half_t>(half_t x)
{
#if CK_USE_SR_F8_CONVERSION
return f8_convert_sr<bf8_t>(x);
return f8_convert_sr<bf8_ocp_t>(x);
#else
return f8_convert_rne<bf8_t>(x);
return f8_convert_rne<bf8_ocp_t>(x);
#endif
}
// convert bf8 to fp16
template <>
inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
inline __host__ __device__ half_t type_convert<half_t, bf8_fnuz_t>(bf8_fnuz_t x)
{
#if defined(__gfx94__)
// use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x);
return utils::cast_from_f8<bf8_fnuz_t, half_t, negative_zero_nan>(x);
#endif
}
......@@ -512,70 +718,1297 @@ inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
}
}
#endif
// convert fp32 to fp4 with rounding to nearest even
inline __host__ __device__ f4_t f4_convert_rne(float x, float scale = 1.0f)
{
#if defined(__gfx950__)
union
{
uint32_t bitwise;
f4_t f4_array[4];
} value{0};
value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x, x, scale, 0);
return value.f4_array[0];
#else
return utils::sat_convert_to_type<f4_t>(x / scale);
#endif
}
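// Illustrative note: assuming the e2m1 fp4 encoding, the representable magnitudes
// are {0, 0.5, 1, 1.5, 2, 3, 4, 6} times the scale, so f4_convert_rne(1.7f) lands
// on 1.5 (the nearer neighbor) while the exact tie f4_convert_rne(5.0f) resolves
// to 4.0 under round-to-nearest-even (even mantissa).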
template <typename Y, typename X, index_t NumElems>
inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array<X, NumElems>& x)
{
    for(std::size_t i = 0; i < NumElems; i++)
    {
        y[i] = type_convert<Y>(x[i]);
    }
}

// convert vector of 2 fp32 to vector of 2 fp4 with rne
inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } value{0};
    value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[0], x[1], scale, 0);
    return value.f4x2_array[0];
#else
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } value{0};
    uint8_t l = utils::sat_convert_to_type<f4_t>(x[1] / scale);
    uint8_t h = utils::sat_convert_to_type<f4_t>(x[0] / scale);
    value.bitwise = (h << 4) | l;
    return value.f4x2_array[0];
#endif
}
// Declare a template function for bf16 conversion using RTN
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
// convert vector of 32 fp32 to vector of 32 fp4 with rne
inline __host__ __device__ f4x32_t f4_convert_rne(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
union
{
__uint128_t bitwise;
f4x2_t f4x2_array[16];
f4x32_t f4x32_array;
} f4_values{}, tmp_values{};
    // convert two floats per iteration with the native builtin
    ck::static_for<0, 16, 1>{}([&](auto i) {
        tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(
            tmp_values.bitwise, x[2 * i], x[2 * i + 1], scale, 0);
        f4_values.f4x2_array[i] = tmp_values.f4x2_array[0];
    });
return f4_values.f4x32_array;
#else
union
{
__uint128_t bitwise;
f4x2_t f4x2_array[16];
f4x32_t f4x32_array;
} f4_values{};
    // shift in one saturated fp4 nibble per iteration; x[0] ends up in the most
    // significant nibble of the 128-bit value
    ck::static_for<0, 32, 1>{}([&](auto i) {
        f4_values.bitwise <<= 4;
        f4_values.bitwise |= utils::sat_convert_to_type<f4_t>(x[i] / scale);
    });
return f4_values.f4x32_array;
#endif
}
// convert fp32 to fp4 with stochastic rounding
inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f)
{
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx950__)
union
{
uint32_t bitwise;
f4_t f4_array[4];
} value{0};
union
{
float float_array[2];
float2_t float2_array;
} float_values{{x}};
value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.float2_array, rng, scale, 0);
return value.f4_array[0];
#else
return utils::sat_convert_to_type_sr<f4_t>(x / scale, rng);
#endif
}
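// Illustrative note: unlike the RNE variant, f4_convert_sr uses the hash-based rng
// value to round x / scale to one of its two fp4 neighbors with probability
// proportional to proximity, so repeated conversions of 1.7f (between 1.5 and 2.0)
// average out near 1.7 instead of always returning 1.5.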
// convert vector of 2 fp32 to vector of 2 fp4 with sr
inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
{
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
#if defined(__gfx950__)
union
{
uint32_t bitwise;
f4x2_t f4x2_array[4];
} value{0};
value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(value.bitwise, x, rng, scale, 0);
return value.f4x2_array[0];
#else
union
{
uint32_t bitwise;
f4x2_t f4x2_array[4];
} value{0};
uint8_t l = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
uint8_t h = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
value.bitwise = (h << 4) | l;
return value.f4x2_array[0];
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4 with sr
inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f)
{
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
#if defined(__gfx950__)
union
{
__uint128_t bitwise;
f4x2_t f4x2_array[16];
f4x32_t f4x32_array;
} f4_values{0}, tmp_values{0};
    union
    {
        float32_t floatx32_array;
        float2_t floatx2_array[16];
    } float_values{x}; // view the input vector as 16 packed float2 values
    // convert two floats per iteration with the native stochastic-rounding builtin
    ck::static_for<0, 16, 1>{}([&](auto i) {
        tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
            tmp_values.bitwise, float_values.floatx2_array[i], rng, scale, 0);
        f4_values.f4x2_array[i] = tmp_values.f4x2_array[0];
    });
return f4_values.f4x32_array;
#else
union
{
__uint128_t bitwise;
f4x2_t f4x2_array[16];
f4x32_t f4x32_array;
} f4_values{0};
    // shift in one stochastically rounded fp4 nibble per iteration; x[0] ends up
    // in the most significant nibble of the 128-bit value
    ck::static_for<0, 32, 1>{}([&](auto i) {
        f4_values.bitwise <<= 4;
        f4_values.bitwise |= utils::sat_convert_to_type_sr<f4_t>(x[i] / scale, rng);
    });
return f4_values.f4x32_array;
#endif
}
// convert fp32 to fp4
template <>
inline __host__ __device__ f4_t type_convert<f4_t, float>(float x)
{
#if CK_USE_SR_F4_CONVERSION
return f4_convert_sr(x);
#else
return f4_convert_rne(x);
#endif
}
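// Example (illustrative, not part of the original header): a scalar fp32 -> fp4
// round trip; fp4 is very coarse, so the result is the nearest representable
// value under the active rounding mode:
//   ck::f4_t q = ck::type_convert<ck::f4_t>(0.75f);
//   float back = ck::type_convert<float>(q);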
// convert vector of 2 fp32 to vector of 2 fp4
template <>
inline __host__ __device__ f4x2_t type_convert<f4x2_t, float2_t>(float2_t x)
{
#if CK_USE_SR_F4_CONVERSION
return f4_convert_sr(x);
#else
return f4_convert_rne(x);
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4
template <>
inline __host__ __device__ f4x32_t type_convert<f4x32_t, float32_t>(float32_t x)
{
#if CK_USE_SR_F4_CONVERSION
return f4_convert_sr(x);
#else
return f4_convert_rne(x);
#endif
}
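// Example (illustrative): packing 32 fp32 lanes into a 16-byte fp4 vector and
// unpacking them again; `v` here is a hypothetical input vector:
//   ck::float32_t v{}; // fill with data
//   ck::f4x32_t packed = ck::type_convert<ck::f4x32_t>(v);
//   ck::float32_t back = ck::type_convert<ck::float32_t>(packed);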
// convert fp4 to fp32
template <>
inline __host__ __device__ float type_convert<float, f4_t>(f4_t x)
{
#if defined(__gfx950__)
    union
    {
        float float_array[2];
        float2_t float2_array;
    } float_values{};
float scale = 1.0f;
float_values.float2_array = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, scale, 0);
return float_values.float_array[0];
#else
return utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif
}
// convert vector of 2 fp4 to vector of 2 fp32
template <>
inline __host__ __device__ float2_t type_convert<float2_t, f4x2_t>(f4x2_t x)
{
#if defined(__gfx950__)
union
{
uint32_t bitwise;
f4x2_t f4x2_array[4];
} value{};
value.f4x2_array[0] = x;
float scale = 1.0f;
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, scale, 0);
#else
float2_t ret{
utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})),
utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}))};
return ret;
#endif
}
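// Example (illustrative): one f4x2_t byte carries two fp4 lanes, so a pair of
// floats converts to a single packed byte and back:
//   ck::f4x2_t pair   = ck::type_convert<ck::f4x2_t>(ck::float2_t{1.0f, -2.0f});
//   ck::float2_t back = ck::type_convert<ck::float2_t>(pair);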
// convert vector of 32 fp4 to vector of 32 fp32
template <>
inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
{
#if defined(__gfx950__)
union
{
f4x32_t f4x32_array;
f4x2_t fp4x2[16];
} value{x};
union
{
uint32_t bitwise;
f4x2_t f4x2_array[4];
} bitwise_value{};
float2_t op;
float32_t ret;
float scale = 1.0f;
    // Unpack two fp4 lanes per builtin call; 16 calls cover all 32 lanes.
    for(int i = 0; i < 16; ++i)
    {
        bitwise_value.f4x2_array[0] = value.fp4x2[i];
        op                          = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
            bitwise_value.bitwise, type_convert<float>(scale), 0);
        ret[2 * i]     = op[0];
        ret[2 * i + 1] = op[1];
    }
return ret;
#else
union
{
float32_t float32_array;
float float_array[32];
} float_values{};
union
{
__uint128_t bitwise;
f4x2_t f4x2_array[16];
f4x32_t f4x32_array;
} f4_values{bit_cast<__uint128_t>(x)};
    // Unpack each packed byte into two floats: lane 2*i comes from the low
    // nibble and lane 2*i+1 from the high nibble of byte i.
    ck::static_for<0, 16, 1>{}([&](auto i) {
        float_values.float_array[2 * i] = utils::to_float<f4_t>(
            NumericLimits<e8m0_bexp_t>::Binary_1(),
            f4_values.f4x2_array[i].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(
                Number<0>{}));
        float_values.float_array[2 * i + 1] = utils::to_float<f4_t>(
            NumericLimits<e8m0_bexp_t>::Binary_1(),
            f4_values.f4x2_array[i].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(
                Number<1>{}));
    });
return float_values.float32_array;
#endif
}
/**
* @brief Converts a float to a 6-bit float type (f6_t) using round-to-nearest-even.
*
* Divides the input by the specified scale, then saturates and converts it
* to the 6-bit floating-point format (f6_t).
*
* @param x The input float value.
* @param scale A scaling factor applied to `x` before conversion.
* @return The converted f6_t value.
*/
inline __host__ __device__ f6_t f6_convert_rne(float x, float scale = 1.0f)
{
#if defined(__gfx950__)
float16_t in1{x};
float16_t in2{};
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} out{};
out.f6_vector = __builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(in1, in2, scale);
return out.f6_array[0];
#else
return utils::sat_convert_to_type<f6_t>(x / scale);
#endif
}
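// Example (illustrative): the scale divides the input before encoding, so the
// stored f6 value represents x / scale:
//   ck::f6_t q = ck::f6_convert_rne(3.0f, 2.0f); // encodes 1.5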
/**
* @brief Converts a 32-element single-precision float array into a packed 6-bit representation.
*
 * This function divides each input float by the provided scale value, then converts each element
 * with round-to-nearest-even, packing it into 6 bits of precision.
*
* @param x A vector of 32 floats stored in float32_t.
* @param scale A scaling factor for each float before conversion.
* @return An f6x32_t object storing the compressed 6-bit representation.
*/
inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
    // The 32 input floats form two contiguous 16-float halves; step by one
    // float16_t (16 floats), not by 16 float32_t vectors.
    float16_t* in1 = reinterpret_cast<float16_t*>(&x);
    float16_t* in2 = in1 + 1;
return __builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(*in1, *in2, scale);
#else
union
{
float32_t float_vector;
float float_array[32];
} in{x};
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} out{};
ck::static_for<0, 32, 1>{}([&](auto i) {
out.f6_array[i] = utils::sat_convert_to_type<f6_t>(in.float_array[i] / scale);
});
return out.f6_vector;
#endif
}
/**
* @brief Converts a float to the 6-bit floating-point type (f6_t) using stochastic rounding.
*
* Divides the input by the specified scale, then performs saturation and conversion
* to f6_t based on a pseudo-randomly generated seed.
*
* @param x The input float value.
* @param scale A scaling factor applied to `x` before conversion.
* @return The converted f6_t value.
*/
inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
{
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx950__)
union
{
float32_t float_vector;
float float_array[32];
} in{x};
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} out{};
out.f6_vector = __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(in.float_vector, rng, scale);
return out.f6_array[0];
#else
return utils::sat_convert_to_type_sr<f6_t>(x / scale, rng);
#endif
}
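// Example (illustrative): the stochastic-rounding bits are derived from the
// input value and its address (see prand_generator above), so values between
// two representable f6 numbers round up or down probabilistically rather than
// always to the nearest:
//   ck::f6_t q = ck::f6_convert_sr(0.3f);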
/**
* @brief Converts a 32-element single-precision float array into a packed 6-bit representation.
*
 * This function divides each input float by the provided scale value, then converts each element
 * with stochastic rounding, packing it into 6 bits of precision.
*
* @param x A vector of 32 floats stored in float32_t.
* @param scale A scaling factor for each float before conversion.
* @return An f6x32_t object storing the compressed 6-bit representation.
*/
inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f)
{
constexpr int seed = 1254739;
union
{
float32_t float_vector;
float float_array[32];
} float_values{x};
uint32_t rng =
prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), float_values.float_array[0]);
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale);
#else
union
{
float32_t float_vector;
float float_array[32];
} in{x};
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} out{};
ck::static_for<0, 32, 1>{}([&](auto i) {
out.f6_array[i] = utils::sat_convert_to_type_sr<f6_t>(in.float_array[i] / scale, rng);
});
return out.f6_vector;
#endif
}
/**
* @brief Specializes the type conversion template for converting a float into the 6-bit float type
* (f6_t).
*
* Depending on the CK_USE_SR_F6_CONVERSION flag,
* the conversion uses stochastic rounding
* or round-to-nearest-even.
*
* @param x Input float value to be converted.
* @return The converted f6_t value.
*/
template <>
inline __host__ __device__ f6_t type_convert<f6_t, float>(float x)
{
#if CK_USE_SR_F6_CONVERSION
return f6_convert_sr(x);
#else
return f6_convert_rne(x);
#endif
}
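// Example (illustrative): the rounding mode is fixed at compile time; building
// with CK_USE_SR_F6_CONVERSION enabled routes this call through f6_convert_sr:
//   ck::f6_t q = ck::type_convert<ck::f6_t>(0.3f);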
/**
 * @brief Specializes the type conversion template for converting a vector of 32 floats into a
 * vector of 32 6-bit float values (f6x32_t).
 *
 * Depending on the CK_USE_SR_F6_CONVERSION flag,
 * the conversion uses stochastic rounding
 * or round-to-nearest-even.
 *
 * @param x Input float vector to be converted.
* @return The converted f6x32_t vector.
*/
template <>
inline __host__ __device__ f6x32_t type_convert<f6x32_t, float32_t>(float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
return f6_convert_sr(x);
#else
return f6_convert_rne(x);
#endif
}
/**
* @brief Specializes the type conversion template for converting the 6-bit float type (f6_t) to
* float.
*
* Interprets an f6_t value as a float using the default scale factor of 1.
*
* @param x The 6-bit float (f6_t) value to be converted.
* @return The corresponding float representation.
*/
template <>
inline __host__ __device__ float type_convert<float, f6_t>(f6_t x)
{
#if defined(__gfx950__)
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} in{x};
union
{
float32_t float_vector;
float float_array[32];
} out{};
out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
in.f6_vector, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
return out.float_array[0];
#else
return utils::to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif
}
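// Example (illustrative): a scalar f6 round trip with the implicit unit scale:
//   float back = ck::type_convert<float>(ck::type_convert<ck::f6_t>(0.5f));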
/**
 * @brief Specializes the type conversion template for converting a vector of 32 6-bit float values
 * (f6x32_t) to a vector of 32 floats.
 *
 * Interprets the f6_t values as floats using the default scale factor of 1.
*
* @param x The vector of 32 6-bit float (f6x32_t) values to be converted.
* @return The corresponding float representation.
*/
template <>
inline __host__ __device__ float32_t type_convert<float32_t, f6x32_t>(f6x32_t x)
{
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
x, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
#else
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} in{x};
union
{
float32_t float_vector;
float float_array[32];
} out{};
ck::static_for<0, 32, 1>{}([&](auto i) {
out.float_array[i] =
utils::to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), in.f6_array[i]);
});
return out.float_vector;
#endif
}
/**
* @brief Converts a float to the 6-bit BF6 type using round-to-nearest-even.
*
* Divides the input by the specified scale, then saturates and converts
* it to a 6-bit BF6 floating-point format.
*
* @param x The float value to be converted.
* @param scale The scaling factor applied to the input before conversion.
* @return The converted bf6_t value.
*/
inline __host__ __device__ bf6_t bf6_convert_rne(float x, float scale = 1.0f)
{
#if defined(__gfx950__)
float16_t in1{x};
float16_t in2{};
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} out{};
out.bf6_vector = __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(in1, in2, scale);
return out.bf6_array[0];
#else
return utils::sat_convert_to_type<bf6_t>(x / scale);
#endif
}
/**
* @brief Converts a vector of 32 floats to the vector of 32 6-bit BF6 types using
* round-to-nearest-even.
*
* Divides the input by the specified scale, then saturates and converts
* it to a 6-bit BF6 floating-point format.
*
* @param x The float vector to be converted.
* @param scale The scaling factor applied to the input before conversion.
* @return The converted bf6x32_t vector.
*/
inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
    // The 32 input floats form two contiguous 16-float halves; step by one
    // float16_t (16 floats), not by 16 float32_t vectors.
    float16_t* in1 = reinterpret_cast<float16_t*>(&x);
    float16_t* in2 = in1 + 1;
return __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(*in1, *in2, scale);
#else
union
{
float32_t float_vector;
float float_array[32];
} in{x};
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} out{};
ck::static_for<0, 32, 1>{}([&](auto i) {
out.bf6_array[i] = utils::sat_convert_to_type<bf6_t>(in.float_array[i] / scale);
});
return out.bf6_vector;
#endif
}
/**
* @brief Converts a float to the 6-bit BF6 type using stochastic rounding.
*
* Divides the input by the specified scale,
* and converts the result to a 6-bit BF6 floating-point
* format with stochastic rounding.
*
* @param x The float value to be converted.
* @param scale The scaling factor applied to the input before conversion.
* @return The converted bf6_t value.
*/
inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
{
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx950__)
union
{
float32_t float_vector;
float float_array[32];
} in{x};
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} out{};
out.bf6_vector = __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(in.float_vector, rng, scale);
return out.bf6_array[0];
#else
return utils::sat_convert_to_type_sr<bf6_t>(x / scale, rng);
#endif
}
/**
* @brief Converts a vector of 32 floats to the vector of 32 6-bit BF6 types using stochastic
* rounding.
*
* Divides the input by the specified scale,
* and converts the result to a 6-bit BF6 floating-point
* format with stochastic rounding.
*
* @param x The float vector to be converted.
* @param scale The scaling factor applied to the input before conversion.
* @return The converted bf6x32_t vector.
*/
inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1.0f)
{
constexpr int seed = 1254739;
union
{
float32_t float_vector;
float float_array[32];
} float_values{x};
uint32_t rng =
prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), float_values.float_array[0]);
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale);
#else
union
{
float32_t float_vector;
float float_array[32];
} in{x};
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} out{};
ck::static_for<0, 32, 1>{}([&](auto i) {
out.bf6_array[i] = utils::sat_convert_to_type_sr<bf6_t>(in.float_array[i] / scale, rng);
});
return out.bf6_vector;
#endif
}
/**
* @brief Specializes float-to-bf6_t conversion.
*
* Uses stochastic rounding if CK_USE_SR_F6_CONVERSION is defined,
* otherwise uses round-to-nearest-even.
*
* @param x Input float value to convert.
* @return Converted bf6_t value.
*/
template <>
inline __host__ __device__ bf6_t type_convert<bf6_t, float>(float x)
{
#if CK_USE_SR_F6_CONVERSION
return bf6_convert_sr(x);
#else
return bf6_convert_rne(x);
#endif
}
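// Example (illustrative): bf6 mirrors the f6 interface and honors the same
// CK_USE_SR_F6_CONVERSION switch:
//   ck::bf6_t q = ck::type_convert<ck::bf6_t>(0.3f);
//   float back  = ck::type_convert<float>(q);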
/**
* @brief Specializes vector of 32 float-to-bf6_t conversion.
*
* Uses stochastic rounding if CK_USE_SR_F6_CONVERSION is defined,
* otherwise uses round-to-nearest-even.
*
* @param x Input float vector to convert.
* @return Converted bf6x32_t vector.
*/
template <>
inline __host__ __device__ bf6x32_t type_convert<bf6x32_t, float32_t>(float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
return bf6_convert_sr(x);
#else
return bf6_convert_rne(x);
#endif
}
/**
* @brief Specializes the type conversion template for converting a bf6_t value to float.
*
* Interprets the bf6_t value using the default scale factor of 1 and returns
* its floating-point representation.
*
* @param x The bf6_t value to convert.
* @return The float representation of the given bf6_t value.
*/
template <>
inline __host__ __device__ float type_convert<float, bf6_t>(bf6_t x)
{
#if defined(__gfx950__)
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} in{x};
union
{
float32_t float_vector;
float float_array[32];
} out{};
out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
in.bf6_vector, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
return out.float_array[0];
#else
return utils::to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif
}
/**
* @brief Specializes the type conversion template for converting a vector of 32 bf6_t values to
 * a vector of 32 floats.
*
* Interprets the bf6x32_t value using the default scale factor of 1 and returns
* its floating-point representation.
*
* @param x The bf6x32_t value to convert.
* @return The float representation of the given vector.
*/
template <>
inline __host__ __device__ float32_t type_convert<float32_t, bf6x32_t>(bf6x32_t x)
{
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
x, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
#else
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} in{x};
union
{
float32_t float_vector;
float float_array[32];
} out{};
ck::static_for<0, 32, 1>{}([&](auto i) {
out.float_array[i] =
utils::to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), in.bf6_array[i]);
});
return out.float_vector;
#endif
}
#ifndef CK_CODE_GEN_RTC
template <typename Y, typename X, size_t NumElems>
inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
const std::array<X, NumElems>& x)
{
for(size_t i = 0; i < NumElems; i++)
{
y[i] = type_convert<Y>(x[i]);
}
}
#endif
template <typename Y, typename X, index_t NumElems>
inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array<X, NumElems>& x)
{
    for(index_t i = 0; i < NumElems; i++)
{
y[i] = type_convert<Y>(x[i]);
}
}
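// Example (illustrative, assuming ck::Array supports brace initialization):
// element-wise conversion between equally sized arrays; the element types only
// need a matching type_convert specialization:
//   ck::Array<float, 4> src{0.5f, 1.0f, 1.5f, 2.0f};
//   ck::Array<ck::half_t, 4> dst;
//   ck::array_convert(dst, src);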
} // namespace ck