Commit f23a2e2a authored by Jakub Piasecki's avatar Jakub Piasecki

resolved conflicts

parents f3eb5a18 c0adab48
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "functional4.hpp"
#include "tuple.hpp"
#ifndef CK_CODE_GEN_RTC
#include "is_detected.hpp"
#endif
namespace ck {
@@ -29,7 +31,7 @@ __host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>&
const Tuple<Y&...>& ty)
{
return unpack2(
[&](auto&&... zs) { return Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
tx,
ty);
}
@@ -38,7 +40,7 @@ template <typename... X, typename... Y>
__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuple<Y...>& ty)
{
return unpack2(
[&](auto... zs) { return Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
tx,
ty);
}
@@ -157,13 +159,17 @@ __host__ __device__ constexpr auto TupleReduce(F&& f, const Tuple<Ts...>& tuple)
}
}
#ifndef CK_CODE_GEN_RTC
template <typename T>
using is_tuple = decltype(ck::declval<T&>().IsTuple());
#endif
template <typename... Ts>
__host__ __device__ constexpr auto IsNestedTuple(const Tuple<Ts...>&)
{
#ifndef CK_CODE_GEN_RTC
return (is_detected<is_tuple, Ts>::value || ...);
#endif
}
template <index_t depth = 0, typename T>
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/utility/enable_if.hpp"
#include "ck/utility/integral_constant.hpp"
namespace ck {
#ifdef CK_CODE_GEN_RTC
// NOLINTNEXTLINE
#define CK_BUILTIN_TYPE_TRAIT1(name) \
template <class T> \
struct name : bool_constant<__##name(T)> \
{ \
}
// NOLINTNEXTLINE
#define CK_BUILTIN_TYPE_TRAIT2(name) \
template <class T, class U> \
struct name : bool_constant<__##name(T, U)> \
{ \
}
// NOLINTNEXTLINE
#define CK_BUILTIN_TYPE_TRAITN(name) \
template <class... Ts> \
struct name : bool_constant<__##name(Ts...)> \
{ \
}
CK_BUILTIN_TYPE_TRAIT1(is_class);
CK_BUILTIN_TYPE_TRAIT1(is_pointer);
CK_BUILTIN_TYPE_TRAIT1(is_reference);
CK_BUILTIN_TYPE_TRAIT1(is_trivially_copyable);
CK_BUILTIN_TYPE_TRAIT1(is_unsigned);
CK_BUILTIN_TYPE_TRAIT2(is_base_of);
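// For reference (illustrative expansion, not part of the diff):
// CK_BUILTIN_TYPE_TRAIT1(is_class); becomes
//   template <class T> struct is_class : bool_constant<__is_class(T)> {};
// i.e. each trait is backed directly by the matching compiler builtin.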
template <class T>
struct remove_cv
{
using type = T;
};
template <class T>
struct remove_cv<const T> : remove_cv<T>
{
};
template <class T>
struct remove_cv<volatile T> : remove_cv<T>
{
};
template <class T>
struct remove_reference
{
typedef T type;
};
template <class T>
struct remove_reference<T&>
{
typedef T type;
};
template <class T>
struct remove_reference<T&&>
{
typedef T type;
};
template <class T>
struct remove_pointer
{
typedef T type;
};
template <class T>
struct remove_pointer<T*>
{
typedef T type;
};
template <class T>
struct remove_pointer<T* const>
{
typedef T type;
};
template <class T>
struct remove_pointer<T* volatile>
{
typedef T type;
};
template <class T>
struct remove_pointer<T* const volatile>
{
typedef T type;
};
template <typename T>
constexpr T&& forward(typename remove_reference<T>::type& t_) noexcept
{
return static_cast<T&&>(t_);
}
template <typename T>
constexpr T&& forward(typename remove_reference<T>::type&& t_) noexcept
{
return static_cast<T&&>(t_);
}
template <class T>
struct is_const : public integral_constant<bool, false>
{
};
template <class T>
struct is_const<const T> : public integral_constant<bool, true>
{
};
template <class T>
inline constexpr bool is_const_v = is_const<T>::value;
template <typename T>
inline constexpr bool is_reference_v = is_reference<T>::value;
template <class T>
struct remove_const
{
typedef T type;
};
template <class T>
struct remove_const<const T>
{
typedef T type;
};
template <class T>
using remove_const_t = typename remove_const<T>::type;
template <class T>
inline constexpr bool is_class_v = is_class<T>::value;
template <class T>
inline constexpr bool is_trivially_copyable_v = is_trivially_copyable<T>::value;
// template <typename T>
// T&& declval() noexcept;
template <class T, class U = T&&>
U private_declval(int);
template <class T>
T private_declval(long);
template <class T>
auto declval() noexcept -> decltype(private_declval<T>(0));
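// Note: private_declval prefers the T&& overload (exact match on the int argument)
// and falls back to plain T when T&& would be ill-formed (e.g. T = void), mirroring
// std::declval's add_rvalue_reference behaviour without pulling in <type_traits>.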
template <class...>
using void_t = void;
#else
#include <utility>
#include <type_traits>
using std::declval;
using std::forward;
using std::is_base_of;
using std::is_class;
using std::is_class_v;
using std::is_const_v;
using std::is_pointer;
using std::is_reference;
using std::is_reference_v;
using std::is_trivially_copyable;
using std::is_trivially_copyable_v;
using std::is_unsigned;
using std::remove_const_t;
using std::remove_cv;
using std::remove_pointer;
using std::remove_reference;
using std::void_t;
#endif
template <typename X, typename Y>
struct is_same : public integral_constant<bool, false>
{
};
template <typename X>
struct is_same<X, X> : public integral_constant<bool, true>
{
};
template <typename X>
struct is_floating_point : public integral_constant<bool, false>
{
};
template <>
struct is_floating_point<float> : public integral_constant<bool, true>
{
};
template <>
struct is_floating_point<double> : public integral_constant<bool, true>
{
};
template <>
struct is_floating_point<long double> : public integral_constant<bool, true>
{
};
template <typename X>
struct is_integral : public integral_constant<bool, false>
{
};
template <>
struct is_integral<int> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<unsigned int> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<long> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<unsigned long> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<short> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<unsigned short> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<long long> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<unsigned long long> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<char> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<signed char> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<unsigned char> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<wchar_t> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<char16_t> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<char32_t> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<bool> : public integral_constant<bool, true>
{
};
template <typename X, typename Y>
inline constexpr bool is_same_v = is_same<X, Y>::value;
template <typename X, typename Y>
inline constexpr bool is_base_of_v = is_base_of<X, Y>::value;
template <typename T>
inline constexpr bool is_unsigned_v = is_unsigned<T>::value;
template <typename T>
using remove_reference_t = typename remove_reference<T>::type;
template <typename T>
using remove_cv_t = typename remove_cv<T>::type;
template <typename T>
using remove_cvref_t = remove_cv_t<remove_reference_t<T>>;
template <typename T>
using remove_pointer_t = typename remove_pointer<T>::type;
template <typename T>
inline constexpr bool is_pointer_v = is_pointer<T>::value;
template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
__host__ __device__ constexpr Y bit_cast(const X& x)
{
static_assert(__has_builtin(__builtin_bit_cast), "");
static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type");
return __builtin_bit_cast(Y, x);
}
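// For example, ck::bit_cast<uint32_t>(1.0f) yields 0x3F800000: the bytes are
// reinterpreted without any value conversion.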
} // namespace ck
...@@ -5,15 +5,39 @@ ...@@ -5,15 +5,39 @@
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
#include "ck/utility/f8_utils.hpp" #include "ck/utility/f8_utils.hpp"
#include "ck/utility/mxf4_utils.hpp"
#include "ck/utility/mxf6_utils.hpp"
#include "ck/utility/random_gen.hpp" #include "ck/utility/random_gen.hpp"
#include "ck/utility/array.hpp" #include "ck/utility/array.hpp"
#include "ck/utility/amd_inline_asm.hpp"
#include "ck/utility/type.hpp"
namespace ck { namespace ck {
// Define the common macro for MI300 models // Define the common macro for MI300 models
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
#define __gfx94__ #define __gfx94__
#endif #endif
namespace {
namespace details {
[[maybe_unused]] __host__ half2_t pk_add_f16(const half2_t& x, const half2_t& y)
{
half2_t vector_res;
vector_res.x = x.x + y.x;
vector_res.y = x.y + y.y;
return vector_res;
}
[[maybe_unused]] __device__ half2_t pk_add_f16(const half2_t& x, const half2_t& y)
{
return amd_assembly_pk_add_f16(x, y);
}
} // namespace details
} // namespace
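// Note: the __device__ overload above maps to the packed half2 addition provided by
// amd_assembly_pk_add_f16 (from amd_inline_asm.hpp), while the __host__ overload
// performs the two half additions in plain C++.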
// Declare a template function for bf16 conversion using RTN
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
@@ -52,10 +76,10 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h
// Convert X to Y, both X and Y are non-const data types.
template <typename Y,
typename X,
ck::enable_if_t<!(ck::is_const_v<Y> || ck::is_const_v<X>), bool> = false>
__host__ __device__ constexpr Y type_convert(X x)
{
static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>);
return static_cast<Y>(x);
}
@@ -63,13 +87,13 @@ __host__ __device__ constexpr Y type_convert(X x)
// Convert X to Y, either X or Y is a const data type.
template <typename Y,
typename X,
ck::enable_if_t<ck::is_const_v<Y> || ck::is_const_v<X>, bool> = false>
__host__ __device__ constexpr Y type_convert(X x)
{
static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>);
using NonConstY = ck::remove_const_t<Y>;
using NonConstX = ck::remove_const_t<X>;
return static_cast<Y>(type_convert<NonConstY, NonConstX>(x));
}
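// Usage sketch (illustrative, not part of this diff): for ordinary arithmetic types
// the generic overloads reduce to static_cast, e.g.
//   float f     = ck::type_convert<float>(3);        // 3.0f
//   const int i = ck::type_convert<const int>(2.5f); // routed through the const overload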
@@ -149,7 +173,7 @@ inline __host__ __device__ constexpr bf8_ocp_t type_convert<bf8_ocp_t, int>(int
template <typename Y, typename X>
__host__ __device__ constexpr Y type_convert_sp(X x)
{
static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>);
return static_cast<Y>(x);
}
@@ -211,7 +235,11 @@ template <>
inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, float>(float x)
{
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#endif
#if defined(__gfx94__)
union
{
@@ -251,7 +279,12 @@ inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, half_t>(half_t x)
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<size_t>(&x), x);
#endif
return utils::cast_to_f8<half_t,
f8_fnuz_t,
negative_zero_nan,
@@ -265,7 +298,11 @@ template <>
inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, float>(float x)
{
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#endif
#if defined(__gfx94__)
union
{
@@ -307,7 +344,12 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, half_t>(half_t x
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<size_t>(&x), x);
#endif
return utils::cast_to_f8<half_t,
bf8_fnuz_t,
negative_zero_nan,
@@ -502,13 +544,51 @@ template <>
inline __host__ __device__ float2_t type_convert<float2_t, pk_i4_t>(pk_i4_t x)
{
uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;
#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
float2_t res = {x_h, x_l};
#else
float2_t res = {x_l, x_h};
#endif
return res;
}
template <>
inline __host__ __device__ half2_t type_convert<half2_t, pk_i4_t>(pk_i4_t x)
{
uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
#else
uint32_t i4s = ((x_u8 & 0xf0) << 12) | (x_u8 & 0xf);
#endif
const int EX = 0x64006400;
const int SUB = 0xE408E408; //-8
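// How the trick works: 0x6400 is fp16 1024.0, so OR-ing a 4-bit nibble n into the
// mantissa yields 1024 + n; 0xE408 is fp16 -1032.0, so the packed add below maps
// each nibble to (1024 + n) - 1032 = n - 8, i.e. the signed int4 value.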
int lo = i4s | EX;
return details::pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
}
template <>
inline __host__ __device__ bhalf2_t type_convert<bhalf2_t, pk_i4_t>(pk_i4_t x)
{
uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;
#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
bhalf2_t res = {type_convert<bhalf_t>(x_h), type_convert<bhalf_t>(x_l)};
#else
bhalf2_t res = {type_convert<bhalf_t>(x_l), type_convert<bhalf_t>(x_h)};
#endif
return res;
}
template <>
@@ -629,20 +709,1294 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_fnuz_t>(bf8_fnuz_t x)
#endif
}
// convert fp32 to fp4 with rounding to nearest even
inline __host__ __device__ f4_t f4_convert_rne(float x, float scale = 1.0f)
{
#if defined(__gfx950__)
union
{
uint32_t bitwise;
f4_t f4_array[4];
} value{0};
value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x, x, scale, 0);
return value.f4_array[0];
#else
return utils::sat_convert_to_type<f4_t>(x / scale);
#endif
}
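// Illustrative behaviour (assuming f4_t is the OCP e2m1 format, whose magnitudes are
// {0, 0.5, 1, 1.5, 2, 3, 4, 6}): f4_convert_rne(3.0f) is exact, f4_convert_rne(100.0f)
// saturates to 6.0, and f4_convert_rne(8.0f, /*scale=*/4.0f) encodes 8/4 = 2.0.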
// convert vector of 2 fp32 to vector of 2 fp4 with rne
inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
union
{
uint32_t bitwise;
f4x2_t f4x2_array[4];
} value{0};
value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[0], x[1], scale, 0);
return value.f4x2_array[0];
#else
union
{
uint32_t bitwise;
f4x2_t f4x2_array[4];
} value{0};
uint8_t l = utils::sat_convert_to_type<f4_t>(x[1] / scale);
uint8_t h = utils::sat_convert_to_type<f4_t>(x[0] / scale);
value.bitwise = (h << 4) | l;
return value.f4x2_array[0];
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4 with rne
inline __host__ __device__ f4x32_t f4_convert_rne(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
union
{
__uint128_t bitwise;
f4x2_t f4x2_array[16];
f4x32_t f4x32_array;
} f4_values{}, tmp_values{};
// TODO: pack in a loop
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[0], x[1], scale, 0);
f4_values.f4x2_array[0] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[2], x[3], scale, 0);
f4_values.f4x2_array[1] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[4], x[5], scale, 0);
f4_values.f4x2_array[2] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[6], x[7], scale, 0);
f4_values.f4x2_array[3] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[8], x[9], scale, 0);
f4_values.f4x2_array[4] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[10], x[11], scale, 0);
f4_values.f4x2_array[5] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[12], x[13], scale, 0);
f4_values.f4x2_array[6] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[14], x[15], scale, 0);
f4_values.f4x2_array[7] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[16], x[17], scale, 0);
f4_values.f4x2_array[8] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[18], x[19], scale, 0);
f4_values.f4x2_array[9] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[20], x[21], scale, 0);
f4_values.f4x2_array[10] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[22], x[23], scale, 0);
f4_values.f4x2_array[11] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[24], x[25], scale, 0);
f4_values.f4x2_array[12] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[26], x[27], scale, 0);
f4_values.f4x2_array[13] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[28], x[29], scale, 0);
f4_values.f4x2_array[14] = tmp_values.f4x2_array[0];
tmp_values.bitwise =
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[30], x[31], scale, 0);
f4_values.f4x2_array[15] = tmp_values.f4x2_array[0];
return f4_values.f4x32_array;
#else
union
{
__uint128_t bitwise;
f4x2_t f4x2_array[16];
f4x32_t f4x32_array;
} f4_values{};
// TODO: pack in a loop
auto tmp = utils::sat_convert_to_type<f4_t>(x[0] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[1] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[2] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[3] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[4] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[5] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[6] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[7] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[8] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[9] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[10] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[11] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[12] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[13] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[14] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[15] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[16] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[17] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[18] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[19] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[20] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[21] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[22] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[23] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[24] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[25] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[26] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[27] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[28] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[29] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[30] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type<f4_t>(x[31] / scale);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
return f4_values.f4x32_array;
#endif
}
// convert fp32 to fp4 with stochastic rounding
inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f)
{
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx950__)
union
{
uint32_t bitwise;
f4_t f4_array[4];
} value{0};
union
{
float float_array[2];
float2_t float2_array;
} float_values{{x}};
value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float_values.float2_array, rng, scale, 0);
return value.f4_array[0];
#else
return utils::sat_convert_to_type_sr<f4_t>(x / scale, rng);
#endif
}
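// With stochastic rounding the rng value decides whether the scaled input is rounded
// up or down, so repeated conversions are unbiased in expectation (up to saturation),
// unlike the deterministic round-to-nearest-even variant above.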
// convert vector of 2 fp32 to vector of 2 fp4 with sr
inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
{
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
#if defined(__gfx950__)
union
{
uint32_t bitwise;
f4x2_t f4x2_array[4];
} value{0};
value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(value.bitwise, x, rng, scale, 0);
return value.f4x2_array[0];
#else
union
{
uint32_t bitwise;
f4x2_t f4x2_array[4];
} value{0};
uint8_t l = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
uint8_t h = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
value.bitwise = (h << 4) | l;
return value.f4x2_array[0];
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4 with sr
inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f)
{
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
#if defined(__gfx950__)
union
{
__uint128_t bitwise;
f4x2_t f4x2_array[16];
f4x32_t f4x32_array;
} f4_values{0}, tmp_values{0};
union
{
float2_t floatx2_array[16];
float32_t floatx32_array;
} float_values{{0}};
// TODO: pack in a loop
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[0], rng, scale, 0);
f4_values.f4x2_array[0] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[1], rng, scale, 0);
f4_values.f4x2_array[1] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[2], rng, scale, 0);
f4_values.f4x2_array[2] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[3], rng, scale, 0);
f4_values.f4x2_array[3] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[4], rng, scale, 0);
f4_values.f4x2_array[4] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[5], rng, scale, 0);
f4_values.f4x2_array[5] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[6], rng, scale, 0);
f4_values.f4x2_array[6] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[7], rng, scale, 0);
f4_values.f4x2_array[7] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[8], rng, scale, 0);
f4_values.f4x2_array[8] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[9], rng, scale, 0);
f4_values.f4x2_array[9] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[10], rng, scale, 0);
f4_values.f4x2_array[10] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[11], rng, scale, 0);
f4_values.f4x2_array[11] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[12], rng, scale, 0);
f4_values.f4x2_array[12] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[13], rng, scale, 0);
f4_values.f4x2_array[13] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[14], rng, scale, 0);
f4_values.f4x2_array[14] = tmp_values.f4x2_array[0];
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[15], rng, scale, 0);
f4_values.f4x2_array[15] = tmp_values.f4x2_array[0];
return f4_values.f4x32_array;
#else
union
{
__uint128_t bitwise;
f4x2_t f4x2_array[16];
f4x32_t f4x32_array;
} f4_values{0};
// TODO: pack in a loop
auto tmp = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[2] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[3] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[4] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[5] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[6] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[7] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[8] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[9] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[10] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[11] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[12] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[13] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[14] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[15] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[16] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[17] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[18] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[19] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[20] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[21] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[22] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[23] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[24] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[25] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[26] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[27] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[28] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[29] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[30] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
tmp = utils::sat_convert_to_type_sr<f4_t>(x[31] / scale, rng);
f4_values.bitwise <<= 4;
f4_values.bitwise |= tmp;
return f4_values.f4x32_array;
#endif
}
// convert fp32 to fp4
template <>
inline __host__ __device__ f4_t type_convert<f4_t, float>(float x)
{
#if CK_USE_SR_F4_CONVERSION
return f4_convert_sr(x);
#else
return f4_convert_rne(x);
#endif
}
// convert vector of 2 fp32 to vector of 2 fp4
template <>
inline __host__ __device__ f4x2_t type_convert<f4x2_t, float2_t>(float2_t x)
{
#if CK_USE_SR_F4_CONVERSION
return f4_convert_sr(x);
#else
return f4_convert_rne(x);
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4
template <>
inline __host__ __device__ f4x32_t type_convert<f4x32_t, float32_t>(float32_t x)
{
#if CK_USE_SR_F4_CONVERSION
return f4_convert_sr(x);
#else
return f4_convert_rne(x);
#endif
}
// convert fp4 to fp32
template <>
inline __host__ __device__ float type_convert<float, f4_t>(f4_t x)
{
#if defined(__gfx950__)
union
{
float float_array[2];
float2_t float2_array;
} float_values{};
float scale = 1.0f;
float_values.float2_array = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, scale, 0);
return float_values.float_array[0];
#else
return utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif
}
// convert vector of 2 fp4 to vector of 2 fp32
template <>
inline __host__ __device__ float2_t type_convert<float2_t, f4x2_t>(f4x2_t x)
{
#if defined(__gfx950__)
union
{
uint32_t bitwise;
f4x2_t f4x2_array[4];
} value{};
value.f4x2_array[0] = x;
float scale = 1.0f;
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, scale, 0);
#else
float2_t ret{
utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})),
utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}))};
return ret;
#endif
}
// convert vector of 32 fp4 to vector of 32 fp32
template <>
inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
{
#if defined(__gfx950__)
union
{
f4x32_t f4x32_array;
f4x2_t fp4x2[16];
} value{x};
union
{
uint32_t bitwise;
f4x2_t f4x2_array[4];
} bitwise_value{};
float2_t op;
float32_t ret;
float scale = 1.0f;
// TODO: pack in a loop
bitwise_value.f4x2_array[0] = value.fp4x2[0];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[0] = op[0];
ret[1] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[2] = op[0];
ret[3] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[2];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[4] = op[0];
ret[5] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[3];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[6] = op[0];
ret[7] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[4];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[8] = op[0];
ret[9] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[5];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[10] = op[0];
ret[11] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[6];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[12] = op[0];
ret[13] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[7];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[14] = op[0];
ret[15] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[8];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[16] = op[0];
ret[17] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[9];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[18] = op[0];
ret[19] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[10];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[20] = op[0];
ret[21] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[11];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[22] = op[0];
ret[23] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[12];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[24] = op[0];
ret[25] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[13];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[26] = op[0];
ret[27] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[14];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[28] = op[0];
ret[29] = op[1];
bitwise_value.f4x2_array[0] = value.fp4x2[15];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[30] = op[0];
ret[31] = op[1];
return ret;
#else
union
{
float32_t float32_array;
float float_array[32];
} float_values{};
union
{
__uint128_t bitwise;
f4x2_t f4x2_array[16];
f4x32_t f4x32_array;
} f4_values{bit_cast<__uint128_t>(x)};
// TODO: pack in a loop
float_values.float_array[0] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[1] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[2] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[3] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[4] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[5] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[6] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[7] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[8] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[9] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[10] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[11] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[12] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[13] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[14] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[15] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[16] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[17] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[18] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[19] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[20] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[21] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[22] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[23] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[24] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[25] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[26] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[27] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[28] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[29] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[30] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[31] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
return float_values.float32_array;
#endif
}
/**
* @brief Converts a float to a 6-bit float type (f6_t) using round-to-nearest-even.
*
* Divides the input by the specified scale, then saturates and converts it
* to the 6-bit floating-point format (f6_t).
*
* @param x The input float value.
* @param scale A scaling factor applied to `x` before conversion.
* @return The converted f6_t value.
*/
inline __host__ __device__ f6_t f6_convert_rne(float x, float scale = 1.0f)
{
#if defined(__gfx950__)
float16_t in1{x};
float16_t in2{};
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} out{};
out.f6_vector = __builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(in1, in2, scale);
return out.f6_array[0];
#else
return utils::sat_convert_to_type<f6_t>(x / scale);
#endif
}
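// Worked example (assuming f6_t is the 6-bit e2m3 format with maximum magnitude 7.5):
// f6_convert_rne(3.0f) is exact, f6_convert_rne(100.0f) saturates to 7.5, and
// f6_convert_rne(30.0f, /*scale=*/4.0f) encodes 30/4 = 7.5 exactly.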
/**
* @brief Converts a 32-element single-precision float array into a packed 6-bit representation.
*
* This function divides each input float by the provided scale value, then performs conversion with
* rounding to nearest / even to pack each element into 6 bits of precision.
*
* @param x A vector of 32 floats stored in float32_t.
* @param scale A scaling factor for each float before conversion.
* @return An f6x32_t object storing the compressed 6-bit representation.
*/
inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
float16_t* in1 = reinterpret_cast<float16_t*>(&x);
float16_t* in2 = reinterpret_cast<float16_t*>(&x) + 1;
return __builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(*in1, *in2, scale);
#else
union
{
float32_t float_vector;
float float_array[32];
} in{x};
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} out{};
ck::static_for<0, 32, 1>{}([&](auto i) {
out.f6_array[i] = utils::sat_convert_to_type<f6_t>(in.float_array[i] / scale);
});
return out.f6_vector;
#endif
}
/**
* @brief Converts a float to the 6-bit floating-point type (f6_t) using stochastic rounding.
*
* Divides the input by the specified scale, then performs saturation and conversion
* to f6_t based on a pseudo-randomly generated seed.
*
* @param x The input float value.
* @param scale A scaling factor applied to `x` before conversion.
* @return The converted f6_t value.
*/
inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
{
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx950__)
union
{
float32_t float_vector;
float float_array[32];
} in{x};
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} out{};
out.f6_vector = __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(in.float_vector, rng, scale);
return out.f6_array[0];
#else
return utils::sat_convert_to_type_sr<f6_t>(x / scale, rng);
#endif
}
/**
* @brief Converts a 32-element single-precision float array into a packed 6-bit representation.
*
* This function divides each input float by the provided scale value, then performs conversion with
* stochastic rounding to pack each element into 6 bits of precision.
*
* @param x A vector of 32 floats stored in float32_t.
* @param scale A scaling factor for each float before conversion.
* @return An f6x32_t object storing the compressed 6-bit representation.
*/
inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f)
{
constexpr int seed = 1254739;
union
{
float32_t float_vector;
float float_array[32];
} float_values{x};
uint32_t rng =
prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), float_values.float_array[0]);
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale);
#else
union
{
float32_t float_vector;
float float_array[32];
} in{x};
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} out{};
ck::static_for<0, 32, 1>{}([&](auto i) {
out.f6_array[i] = utils::sat_convert_to_type_sr<f6_t>(in.float_array[i] / scale, rng);
});
return out.f6_vector;
#endif
}
/**
* @brief Specializes the type conversion template for converting a float into the 6-bit float type
* (f6_t).
*
* Depending on the CK_USE_SR_F6_CONVERSION flag,
* the conversion uses stochastic rounding
* or round-to-nearest-even.
*
* @param x Input float value to be converted.
* @return The converted f6_t value.
*/
template <>
inline __host__ __device__ f6_t type_convert<f6_t, float>(float x)
{
#if CK_USE_SR_F6_CONVERSION
return f6_convert_sr(x);
#else
return f6_convert_rne(x);
#endif
}
/**
* @brief Specializes the type conversion template for converting a vector of 32 floats into the
* vector of 32 6-bit float types (f6x32_t).
*
* Depending on the CK_USE_SR_F6_CONVERSION flag,
* the conversion uses stochastic rounding
* or round-to-nearest-even.
*
* @param x Input float value to be converted.
* @return The converted f6x32_t vector.
*/
template <>
inline __host__ __device__ f6x32_t type_convert<f6x32_t, float32_t>(float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
return f6_convert_sr(x);
#else
return f6_convert_rne(x);
#endif
}
/**
* @brief Specializes the type conversion template for converting the 6-bit float type (f6_t) to
* float.
*
* Interprets an f6_t value as a float using the default scale factor of 1.
*
* @param x The 6-bit float (f6_t) value to be converted.
* @return The corresponding float representation.
*/
template <>
inline __host__ __device__ float type_convert<float, f6_t>(f6_t x)
{
#if defined(__gfx950__)
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} in{x};
union
{
float32_t float_vector;
float float_array[32];
} out{};
out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
in.f6_vector, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
return out.float_array[0];
#else
return utils::to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif
}
/**
* @brief Specializes the type conversion template for converting the vector of 32 6-bit float types
* (f6x32_t) to vector of 32 floats.
*
* Interprets an f6_t values as floats using the default scale factor of 1.
*
* @param x The vector of 32 6-bit float (f6x32_t) values to be converted.
* @return The corresponding float representation.
*/
template <>
inline __host__ __device__ float32_t type_convert<float32_t, f6x32_t>(f6x32_t x)
{
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
x, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
#else
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} in{x};
union
{
float32_t float_vector;
float float_array[32];
} out{};
ck::static_for<0, 32, 1>{}([&](auto i) {
out.float_array[i] =
utils::to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), in.f6_array[i]);
});
return out.float_vector;
#endif
}
/**
* @brief Converts a float to the 6-bit BF6 type using round-to-nearest-even.
*
* Divides the input by the specified scale, then saturates and converts
* it to a 6-bit BF6 floating-point format.
*
* @param x The float value to be converted.
* @param scale The scaling factor applied to the input before conversion.
* @return The converted bf6_t value.
*/
inline __host__ __device__ bf6_t bf6_convert_rne(float x, float scale = 1.0f)
{
#if defined(__gfx950__)
float16_t in1{x};
float16_t in2{};
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} out{};
out.bf6_vector = __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(in1, in2, scale);
return out.bf6_array[0];
#else
return utils::sat_convert_to_type<bf6_t>(x / scale);
#endif
}
/**
* @brief Converts a vector of 32 floats to the vector of 32 6-bit BF6 types using
* round-to-nearest-even.
*
* Divides the input by the specified scale, then saturates and converts
* it to a 6-bit BF6 floating-point format.
*
* @param x The float vector to be converted.
* @param scale The scaling factor applied to the input before conversion.
* @return The converted bf6x32_t vector.
*/
inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
float16_t* in1 = reinterpret_cast<float16_t*>(&x);
float16_t* in2 = reinterpret_cast<float16_t*>(&x) + 1;
return __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(*in1, *in2, scale);
#else
union
{
float32_t float_vector;
float float_array[32];
} in{x};
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} out{};
ck::static_for<0, 32, 1>{}([&](auto i) {
out.bf6_array[i] = utils::sat_convert_to_type<bf6_t>(in.float_array[i] / scale);
});
return out.bf6_vector;
#endif
}
/**
* @brief Converts a float to the 6-bit BF6 type using stochastic rounding.
*
* Divides the input by the specified scale,
* and converts the result to a 6-bit BF6 floating-point
* format with stochastic rounding.
*
* @param x The float value to be converted.
* @param scale The scaling factor applied to the input before conversion.
* @return The converted bf6_t value.
*/
inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
{
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx950__)
union
{
float32_t float_vector;
float float_array[32];
} in{x};
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} out{};
out.bf6_vector = __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(in.float_vector, rng, scale);
return out.bf6_array[0];
#else
return utils::sat_convert_to_type_sr<bf6_t>(x / scale, rng);
#endif
}
/**
* @brief Converts a vector of 32 floats to the vector of 32 6-bit BF6 types using stochastic
* rounding.
*
* Divides the input by the specified scale,
* and converts the result to a 6-bit BF6 floating-point
* format with stochastic rounding.
*
* @param x The float vector to be converted.
* @param scale The scaling factor applied to the input before conversion.
* @return The converted bf6x32_t vector.
*/
inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1.0f)
{
constexpr int seed = 1254739;
union
{
float32_t float_vector;
float float_array[32];
} float_values{x};
uint32_t rng =
prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), float_values.float_array[0]);
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale);
#else
union
{
float32_t float_vector;
float float_array[32];
} in{x};
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} out{};
ck::static_for<0, 32, 1>{}([&](auto i) {
out.bf6_array[i] = utils::sat_convert_to_type_sr<bf6_t>(in.float_array[i] / scale, rng);
});
return out.bf6_vector;
#endif
}
/**
* @brief Specializes float-to-bf6_t conversion.
*
* Uses stochastic rounding if CK_USE_SR_F6_CONVERSION is defined,
* otherwise uses round-to-nearest-even.
*
* @param x Input float value to convert.
* @return Converted bf6_t value.
*/
template <>
inline __host__ __device__ bf6_t type_convert<bf6_t, float>(float x)
{
#if CK_USE_SR_F6_CONVERSION
return bf6_convert_sr(x);
#else
return bf6_convert_rne(x);
#endif
}
/**
* @brief Specializes vector of 32 float-to-bf6_t conversion.
*
* Uses stochastic rounding if CK_USE_SR_F6_CONVERSION is defined,
* otherwise uses round-to-nearest-even.
*
* @param x Input float vector to convert.
* @return Converted bf6x32_t vector.
*/
template <>
inline __host__ __device__ bf6x32_t type_convert<bf6x32_t, float32_t>(float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
return bf6_convert_sr(x);
#else
return bf6_convert_rne(x);
#endif
}
/**
* @brief Specializes the type conversion template for converting a bf6_t value to float.
*
* Interprets the bf6_t value using the default scale factor of 1 and returns
* its floating-point representation.
*
* @param x The bf6_t value to convert.
* @return The float representation of the given bf6_t value.
*/
template <>
inline __host__ __device__ float type_convert<float, bf6_t>(bf6_t x)
{
#if defined(__gfx950__)
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} in{x};
union
{
float32_t float_vector;
float float_array[32];
} out{};
out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
in.bf6_vector, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
return out.float_array[0];
#else
return utils::to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif
}
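// Usage sketch (illustrative): round-trip through the scalar specializations above.
//   bf6_t q = type_convert<bf6_t>(0.75f); // RNE or SR depending on CK_USE_SR_F6_CONVERSION
//   float r = type_convert<float>(q);     // decode with the default scale of 1
// Values exactly representable in BF6 round-trip exactly; other inputs are rounded by the
// forward conversion.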
/**
 * @brief Specializes the type conversion template for converting a vector of 32 bf6_t values to
 * a vector of 32 floats.
*
* Interprets the bf6x32_t value using the default scale factor of 1 and returns
* its floating-point representation.
*
* @param x The bf6x32_t value to convert.
* @return The float representation of the given vector.
*/
template <>
inline __host__ __device__ float32_t type_convert<float32_t, bf6x32_t>(bf6x32_t x)
{
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
x, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
#else
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} in{x};
union
{
float32_t float_vector;
float float_array[32];
} out{};
ck::static_for<0, 32, 1>{}([&](auto i) {
out.float_array[i] =
utils::to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), in.bf6_array[i]);
});
return out.float_vector;
#endif
}
#ifndef CK_CODE_GEN_RTC
template <typename Y, typename X, size_t NumElems>
inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y, inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
const std::array<X, NumElems>& x) const std::array<X, NumElems>& x)
{ {
for(std::size_t i = 0; i < NumElems; i++) for(size_t i = 0; i < NumElems; i++)
{ {
y[i] = type_convert<Y>(x[i]); y[i] = type_convert<Y>(x[i]);
} }
} }
#endif
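// Usage sketch (illustrative, host side):
//   std::array<float, 4> src{0.5f, 1.0f, 2.0f, 4.0f};
//   std::array<bf6_t, 4> dst{};
//   array_convert(dst, src); // dst[i] = type_convert<bf6_t>(src[i])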
template <typename Y, typename X, index_t NumElems> template <typename Y, typename X, index_t NumElems>
inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array<X, NumElems>& x) inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array<X, NumElems>& x)
{ {
for(std::size_t i = 0; i < NumElems; i++) for(size_t i = 0; i < NumElems; i++)
{ {
y[i] = type_convert<Y>(x[i]); y[i] = type_convert<Y>(x[i]);
} }
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/null_type.hpp" #include "ck_tile/core/numeric/null_type.hpp"
#include "ck_tile/core/numeric/numeric.hpp" #include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/numeric/type_convert.hpp" #include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/numeric/vector_type.hpp" #include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/tensor/buffer_view.hpp" #include "ck_tile/core/tensor/buffer_view.hpp"
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck_tile/core/numeric/vector_type.hpp" #include "ck_tile/core/numeric/vector_type.hpp"
...@@ -8,16 +8,75 @@ ...@@ -8,16 +8,75 @@
namespace ck_tile { namespace ck_tile {
CK_TILE_HOST_DEVICE bf16_t add_bf16_t(const bf16_t& a, const bf16_t& b) template <typename T, typename ComputeType>
CK_TILE_HOST_DEVICE T add(const T& a, const T& b)
{ {
return type_convert<bf16_t>(type_convert<float>(a) + type_convert<float>(b)); return type_convert<T>(type_convert<ComputeType>(a) + type_convert<ComputeType>(b));
} }
CK_TILE_HOST_DEVICE bf16x2_t add_bf16x2_t(const bf16x2_t& a, const bf16x2_t& b) CK_TILE_HOST_DEVICE bf16x2_t add_bf16x2_t(const bf16x2_t& a, const bf16x2_t& b)
{ {
bf16x2_t rtn; bf16x2_t rtn;
rtn[0] = add_bf16_t(a[0], b[0]); rtn[0] = add<bf16_t, float>(a[0], b[0]);
rtn[1] = add_bf16_t(a[1], b[1]); rtn[1] = add<bf16_t, float>(a[1], b[1]);
return rtn;
}
CK_TILE_HOST_DEVICE bf16x4_t add_bf16x4_t(const bf16x4_t& a, const bf16x4_t& b)
{
bf16x4_t rtn;
rtn[0] = add<bf16_t, float>(a[0], b[0]);
rtn[1] = add<bf16_t, float>(a[1], b[1]);
rtn[2] = add<bf16_t, float>(a[2], b[2]);
rtn[3] = add<bf16_t, float>(a[3], b[3]);
return rtn;
}
CK_TILE_HOST_DEVICE fp8x4_t add_fp8x4_t(const fp8x4_t& a, const fp8x4_t& b)
{
fp8x4_t rtn;
rtn[0] = add<fp8_t, float>(a[0], b[0]);
rtn[1] = add<fp8_t, float>(a[1], b[1]);
rtn[2] = add<fp8_t, float>(a[2], b[2]);
rtn[3] = add<fp8_t, float>(a[3], b[3]);
return rtn;
}
CK_TILE_HOST_DEVICE fp8x8_t add_fp8x8_t(const fp8x8_t& a, const fp8x8_t& b)
{
fp8x8_t rtn;
rtn[0] = add<fp8_t, float>(a[0], b[0]);
rtn[1] = add<fp8_t, float>(a[1], b[1]);
rtn[2] = add<fp8_t, float>(a[2], b[2]);
rtn[3] = add<fp8_t, float>(a[3], b[3]);
rtn[4] = add<fp8_t, float>(a[4], b[4]);
rtn[5] = add<fp8_t, float>(a[5], b[5]);
rtn[6] = add<fp8_t, float>(a[6], b[6]);
rtn[7] = add<fp8_t, float>(a[7], b[7]);
return rtn;
}
CK_TILE_HOST_DEVICE bf8x4_t add_bf8x4_t(const bf8x4_t& a, const bf8x4_t& b)
{
bf8x4_t rtn;
rtn[0] = add<bf8_t, float>(a[0], b[0]);
rtn[1] = add<bf8_t, float>(a[1], b[1]);
rtn[2] = add<bf8_t, float>(a[2], b[2]);
rtn[3] = add<bf8_t, float>(a[3], b[3]);
return rtn;
}
CK_TILE_HOST_DEVICE bf8x8_t add_bf8x8_t(const bf8x8_t& a, const bf8x8_t& b)
{
bf8x8_t rtn;
rtn[0] = add<bf8_t, float>(a[0], b[0]);
rtn[1] = add<bf8_t, float>(a[1], b[1]);
rtn[2] = add<bf8_t, float>(a[2], b[2]);
rtn[3] = add<bf8_t, float>(a[3], b[3]);
rtn[4] = add<bf8_t, float>(a[4], b[4]);
rtn[5] = add<bf8_t, float>(a[5], b[5]);
rtn[6] = add<bf8_t, float>(a[6], b[6]);
rtn[7] = add<bf8_t, float>(a[7], b[7]);
return rtn; return rtn;
} }
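// Usage sketch (illustrative): the generic helper widens to ComputeType, adds, then narrows back.
//   bf16_t c = add<bf16_t, float>(a, b); // equivalent to bf16(float(a) + float(b))
// The packed helpers above simply apply it lane by lane.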
...@@ -59,6 +118,192 @@ CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x) ...@@ -59,6 +118,192 @@ CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x)
} while(cur_v.u32 != old_v); } while(cur_v.u32 != old_v);
} }
template <>
CK_TILE_DEVICE void atomic_add<bf16x4_t>(bf16x4_t* p_dst, bf16x4_t const& x)
{
// Union to treat the pointer as either bf16x4_t* or uint64_t*:
union U64BF164_ADDR
{
uint64_t* u64_a;
bf16x4_t* bf164_a;
};
// Union to treat the data as either bf16x4_t or 64-bit integer
union U64BF164
{
uint64_t u64;
bf16x4_t bf164;
};
U64BF164_ADDR addr;
addr.bf164_a = p_dst; // interpret p_dst as a 64-bit location
// First read (non-atomic) of the old value
U64BF164 cur_v;
cur_v.u64 = *addr.u64_a;
U64BF164 new_v_union;
uint64_t old_v, new_v;
do
{
// old 64 bits
old_v = cur_v.u64;
// Add elementwise in bf16
new_v_union.bf164 = add_bf16x4_t(cur_v.bf164, x);
new_v = new_v_union.u64;
// Attempt the 64-bit CAS
cur_v.u64 = atomicCAS(addr.u64_a, old_v, new_v);
} while(cur_v.u64 != old_v);
}
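// Usage sketch (illustrative, device code; the kernel name and launch setup are assumptions):
//   __global__ void accumulate(bf16x4_t* dst, const bf16x4_t* src)
//   {
//       // every thread folds its packed value into *dst via the 64-bit CAS loop above
//       atomic_add<bf16x4_t>(dst, src[threadIdx.x]);
//   }
// Under heavy contention the CAS loop retries, so prefer privatized partial sums where possible.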
template <>
CK_TILE_DEVICE void atomic_add<fp8x4_t>(fp8x4_t* p_dst, const fp8x4_t& x)
{
union U32FP84_ADDR
{
uint32_t* u32_a;
fp8x4_t* fp84_a;
};
union U32FP84
{
uint32_t u32;
fp8x4_t fp84;
};
U32FP84_ADDR dword_addr;
U32FP84 cur_v;
U32FP84 new_;
uint32_t old_v, new_v;
dword_addr.fp84_a = p_dst;
cur_v.u32 = *dword_addr.u32_a;
do
{
old_v = cur_v.u32;
new_.fp84 = add_fp8x4_t(cur_v.fp84, x);
new_v = new_.u32;
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
} while(cur_v.u32 != old_v);
}
template <>
CK_TILE_DEVICE void atomic_add<bf8x4_t>(bf8x4_t* p_dst, const bf8x4_t& x)
{
union U32BF84_ADDR
{
uint32_t* u32_a;
bf8x4_t* bf84_a;
};
union U32BF84
{
uint32_t u32;
bf8x4_t bf84;
};
U32BF84_ADDR dword_addr;
U32BF84 cur_v;
U32BF84 new_;
uint32_t old_v, new_v;
dword_addr.bf84_a = p_dst;
cur_v.u32 = *dword_addr.u32_a;
do
{
old_v = cur_v.u32;
new_.bf84 = add_bf8x4_t(cur_v.bf84, x);
new_v = new_.u32;
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
} while(cur_v.u32 != old_v);
}
//
// Atomic add for fp8x8_t
//
template <>
CK_TILE_DEVICE void atomic_add<fp8x8_t>(fp8x8_t* p_dst, fp8x8_t const& x)
{
// Union for addressing 64 bits as either "fp8x8_t" or a 64-bit integer.
union U64FP88_ADDR
{
uint64_t* u64_a; // pointer to 64-bit integer
fp8x8_t* fp88_a; // pointer to fp8x8_t
};
union U64FP88
{
uint64_t u64;
fp8x8_t fp88;
};
U64FP88_ADDR dword_addr;
U64FP88 cur_v;
U64FP88 new_v_union;
uint64_t old_v, new_v;
// Point to the destination as both fp8x8_t* and uint64_t*.
dword_addr.fp88_a = p_dst;
// Initial read of 64 bits from memory
cur_v.u64 = *dword_addr.u64_a;
do
{
old_v = cur_v.u64;
// Add each fp8 element using the add_fp8x8_t(...) helper above
new_v_union.fp88 = add_fp8x8_t(cur_v.fp88, x);
new_v = new_v_union.u64;
// Attempt 64-bit CAS
cur_v.u64 = atomicCAS(dword_addr.u64_a, old_v, new_v);
} while(cur_v.u64 != old_v);
}
//
// Atomic add for bf8x8_t
//
template <>
CK_TILE_DEVICE void atomic_add<bf8x8_t>(bf8x8_t* p_dst, bf8x8_t const& x)
{
union U64BF88_ADDR
{
uint64_t* u64_a;
bf8x8_t* bf88_a;
};
union U64BF88
{
uint64_t u64;
bf8x8_t bf88;
};
U64BF88_ADDR dword_addr;
U64BF88 cur_v;
U64BF88 new_v_union;
uint64_t old_v, new_v;
dword_addr.bf88_a = p_dst;
// Read the original 64 bits
cur_v.u64 = *dword_addr.u64_a;
do
{
old_v = cur_v.u64;
// Add each bf8 element using the add_bf8x8_t(...) helper above
new_v_union.bf88 = add_bf8x8_t(cur_v.bf88, x);
new_v = new_v_union.u64;
// 64-bit CAS loop
cur_v.u64 = atomicCAS(dword_addr.u64_a, old_v, new_v);
} while(cur_v.u64 != old_v);
}
template <typename T, index_t N> template <typename T, index_t N>
CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x) CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
{ {
...@@ -66,8 +311,10 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x) ...@@ -66,8 +311,10 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
(std::is_same<T, uint32_t>::value && (N == 1)) || (std::is_same<T, uint32_t>::value && (N == 1)) ||
(std::is_same<T, float>::value && (N == 1 || N == 2)) || (std::is_same<T, float>::value && (N == 1 || N == 2)) ||
(std::is_same<T, double>::value && (N == 1 || N == 2)) || (std::is_same<T, double>::value && (N == 1 || N == 2)) ||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4)), (std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
"wrong! not implemented"); (std::is_same<T, fp8_t>::value && (N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf8_t>::value && (N == 4 || N == 8 || N == 16)),
"The granularity of the thread buffer is unsupported on the hardware!");
constexpr auto I0 = number<0>{}; constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{}; constexpr auto I1 = number<1>{};
...@@ -118,9 +365,45 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x) ...@@ -118,9 +365,45 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
} }
else if constexpr(N == 4) else if constexpr(N == 4)
{ {
atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst), x.template get_as<bf16x2_t>()[I0]); atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst), x.template get_as<bf16x4_t>()[I0]);
atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst) + 1, }
x.template get_as<bf16x2_t>()[I1]); else if constexpr(N == 8)
{
atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst), x.template get_as<bf16x4_t>()[I0]);
atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst) + 1,
x.template get_as<bf16x4_t>()[I1]);
}
}
else if constexpr(std::is_same<T, fp8_t>::value)
{
if constexpr(N == 4)
{
atomic_add(c_style_pointer_cast<fp8x4_t*>(p_dst), x.template get_as<fp8x4_t>()[I0]);
}
if constexpr(N == 8)
{
atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst), x.template get_as<fp8x8_t>()[I0]);
}
if constexpr(N == 16)
{
atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst), x.template get_as<fp8x8_t>()[I0]);
atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst) + 1, x.template get_as<fp8x8_t>()[I1]);
}
}
else if constexpr(std::is_same<T, bf8_t>::value)
{
if constexpr(N == 4)
{
atomic_add(c_style_pointer_cast<bf8x4_t*>(p_dst), x.template get_as<bf8x4_t>()[I0]);
}
if constexpr(N == 8)
{
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst), x.template get_as<bf8x8_t>()[I0]);
}
if constexpr(N == 16)
{
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst), x.template get_as<bf8x8_t>()[I0]);
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst) + 1, x.template get_as<bf8x8_t>()[I1]);
} }
} }
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__) defined(__gfx942__) || defined(__gfx950__)
#define __gfx9__ #define __gfx9__
#endif #endif
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
#define __gfx94__ #define __gfx94__
#endif #endif
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \ #if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
...@@ -144,6 +144,10 @@ ...@@ -144,6 +144,10 @@
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER 1 #define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER 1
#endif #endif
#ifndef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
#define CK_TILE_USE_PK4_LAYOUT_SHUFFLE 1
#endif
// buffer atomic add: floating point // buffer atomic add: floating point
#ifndef __HIP_DEVICE_COMPILE__ // for host code #ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 #define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
...@@ -230,3 +234,15 @@ ...@@ -230,3 +234,15 @@
#ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID #ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1 #define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1
#endif #endif
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#ifdef CK_TILE_USE_OCP_FP8
#define CK_TILE_USE_OCP_FP8 1
#else
#define CK_TILE_USE_OCP_FP8 0
#endif
#elif defined(__gfx950__) || defined(__gfx12__) // for GPU code
#define CK_TILE_USE_OCP_FP8 1
#else // for GPU code
#define CK_TILE_USE_OCP_FP8 0
#endif
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp" #include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/bit_cast.hpp"
...@@ -14,6 +14,12 @@ ...@@ -14,6 +14,12 @@
#pragma once #pragma once
#if(defined(__gfx94__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__
#define CK_TILE_FP8_CVT_DEVICE 1
#else
#define CK_TILE_FP8_CVT_DEVICE 0
#endif
namespace ck_tile { namespace ck_tile {
// fp8 rounding modes // fp8 rounding modes
...@@ -25,15 +31,26 @@ enum class fp8_rounding_mode ...@@ -25,15 +31,26 @@ enum class fp8_rounding_mode
stochastic stochastic
}; };
/**
* \brief FP8 interpretation used in conversion algorithms
*/
enum class fp8_interpretation
{
E4M3_OCP = 0, // OCP FP8 E4M3
E5M2_OCP = 1, // OCP BF8 E5M2
E4M3_FNUZ = 2, // FNUZ FP8 E4M3
E5M2_FNUZ = 3, // FNUZ BF8 E5M2
};
/* /*
* ______________NANOO_________________ | ______________IEEE________________ * ______________FNUZ_________________ | ______________OCP________________
* e4m3 e5m2 | e4m3 e5m2 * e4m3 e5m2 | e4m3 e5m2
* bias : 8 16 | 7 15 * bias : 8 16 | 7 15
* inf : 1.0000.000 1.00000.00 | N/A s.11111.00 * inf : 1.0000.000 1.00000.00 | N/A s.11111.00
* Nan : 1.0000.000 1.00000.00 | s.1111.111 s.11111.{01, 10, 11} * Nan : 1.0000.000 1.00000.00 | s.1111.111 s.11111.{01, 10, 11}
* zero : 0.0000.000 0.00000.00 | s.0000.000 s.00000.00 * zero : 0.0000.000 0.00000.00 | s.0000.000 s.00000.00
* Max(norm) : s.1111.111 (240) s.11111.11(57344) | s.1111.110(448) s.11110.11(57344) * Max(norm) : s.1111.111 (240) s.11111.11(57344) | s.1111.110(448) s.11110.11(57344)
* Max(snorm): s.0000.111 s.00000.11 | s.0000.111(448) s.00000.11(57344) * Max(snorm): s.0000.111 s.00000.11 | s.0000.111 s.00000.11
* 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05 * 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05
* Min(norm) : s.0001.000 s.00001.00 | s.0001.000 s.00001.00 * Min(norm) : s.0001.000 s.00001.00 | s.0001.000 s.00001.00
* 2^-7(0.00078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05) * 2^-7(0.00078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05)
...@@ -55,10 +72,10 @@ struct alignas(1) float8_e4m3_t ...@@ -55,10 +72,10 @@ struct alignas(1) float8_e4m3_t
{ {
static constexpr int exponent = 4; static constexpr int exponent = 4;
static constexpr int mantissa = 3; static constexpr int mantissa = 3;
#if defined(__gfx94__) #if CK_TILE_USE_OCP_FP8
static constexpr int bias = 1 << (exponent - 1); // NANOO static constexpr int bias = 7; // OCP
#else #else
static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE static constexpr int bias = 8; // FNUZ
#endif #endif
using raw_type = uint8_t; using raw_type = uint8_t;
raw_type data; raw_type data;
...@@ -113,10 +130,10 @@ struct alignas(1) float8_e5m2_t ...@@ -113,10 +130,10 @@ struct alignas(1) float8_e5m2_t
{ {
static constexpr int exponent = 5; static constexpr int exponent = 5;
static constexpr int mantissa = 2; static constexpr int mantissa = 2;
#if defined(__gfx94__) #if CK_TILE_USE_OCP_FP8
static constexpr int bias = 1 << (exponent - 1); // NANOO static constexpr int bias = 15; // OCP
#else #else
static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE static constexpr int bias = 16; // FNUZ
#endif #endif
using raw_type = uint8_t; using raw_type = uint8_t;
raw_type data; raw_type data;
...@@ -183,501 +200,727 @@ struct native_t<bf8_t> ...@@ -183,501 +200,727 @@ struct native_t<bf8_t>
}; };
#else #else
using fp8_t = _BitInt(8); using fp8_t = _BitInt(8);
using fp8_raw_t = uint8_t; using fp8_raw_t = uint8_t;
using bf8_t = unsigned _BitInt(8); using bf8_t = unsigned _BitInt(8);
using bf8_raw_t = uint8_t; using bf8_raw_t = uint8_t;
#endif #endif
// below is sw fp8 conversion, not utilizing hw instruction template <typename T>
namespace impl { struct numeric_traits;
template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch> template <>
CK_TILE_HOST_DEVICE Y run_cast_to_f8(X x, uint32_t rng) struct numeric_traits<fp8_t>
{ {
// fp8/bf8 exponent/mantissa layout using bitwise_type = fp8_raw_t;
constexpr int out_exp = numeric_traits<Y>::exp;
constexpr int out_mant = numeric_traits<Y>::mant; static constexpr int exp = 4;
static constexpr int mant = 3;
#if CK_TILE_USE_OCP_FP8
static constexpr int bias = 7;
static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E4M3_OCP;
#else
static constexpr int bias = 8;
static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E4M3_FNUZ;
#endif
static constexpr uint8_t abs_mask = 0x7F;
};
// original type exponent/mantissa layout template <>
constexpr int in_exp = numeric_traits<X>::exp; struct numeric_traits<bf8_t>
constexpr int in_mant = numeric_traits<X>::mant; {
using bitwise_type = bf8_raw_t;
int exponent, bias; static constexpr int exp = 5;
uint32_t head, mantissa, sign; static constexpr int mant = 2;
// nan code is same for float and half #if CK_TILE_USE_OCP_FP8
#if CK_TILE_USE_CUSTOM_DATA_TYPE static constexpr int bias = 15;
constexpr Y nan_code = static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E5M2_OCP;
numeric<Y>::quiet_NaN(); // __builtin_bit_cast(Y, static_cast<uint8_t>(0x80));
#else #else
constexpr Y nan_code = 0x80; static constexpr int bias = 16;
static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E5M2_FNUZ;
#endif #endif
static constexpr uint8_t abs_mask = 0x7F;
};
// below is sw fp8 conversion, not utilizing hw instruction
namespace impl {
template <typename SrcT, typename DstT, bool clip = true, bool stoch = false>
CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
{
static_assert(std::is_same<DstT, fp8_t>::value || std::is_same<DstT, bf8_t>::value,
"DstT type must be fp8 or bf8.");
constexpr uint32_t nan_mask = numeric_traits<X>::nan_mask; constexpr bool is_half = std::is_same<SrcT, half_t>::value;
constexpr bool is_float = std::is_same<SrcT, float>::value;
static_assert(is_half || is_float, "Only half and float can be cast to f8");
// convert to bitwise // fp8/bf8 type exponent/mantissa layout
using T_bitwise = typename numeric_traits<X>::bitwise_type; constexpr int DstT_exp = numeric_traits<DstT>::exp; // exponent width of the destination type
T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x)); constexpr int DstT_mant = numeric_traits<DstT>::mant; // mantissa width of the destination type
constexpr bool is_fnuz =
(numeric_traits<DstT>::f8_interpret == fp8_interpretation::E4M3_FNUZ) ||
(numeric_traits<DstT>::f8_interpret == fp8_interpretation::E5M2_FNUZ);
// unpack the input, depends on datatype constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
head = x_bitwise & numeric_traits<X>::head_mask; constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
mantissa = x_bitwise & numeric_traits<X>::mant_mask;
exponent = (head >> in_mant) & numeric_traits<X>::exp_mask;
sign = head >> (in_exp + in_mant);
bias = numeric_traits<X>::bias;
uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant); using SrcT_bitwise = typename numeric_traits<SrcT>::bitwise_type;
uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1; SrcT_bitwise src_bitwise = bit_cast<SrcT_bitwise>(src);
constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2);
if constexpr(negative_zero_nan) unsigned long long head, mantissa;
int exponent, bias;
unsigned int sign;
unsigned long long fInf, abs_mask;
head = src_bitwise & numeric_traits<SrcT>::head_mask;
mantissa = src_bitwise & numeric_traits<SrcT>::mant_mask;
exponent = (head >> SrcT_mant) & numeric_traits<SrcT>::exp_mask;
sign = head >> (SrcT_exp + SrcT_mant);
bias = numeric_traits<SrcT>::bias;
fInf = numeric_traits<SrcT>::Inf;
abs_mask = numeric_traits<SrcT>::abs_mask;
unsigned int signed_inf = 0;
unsigned int nan = 0;
if constexpr(is_fnuz)
{ {
if((x_bitwise & nan_mask) == nan_mask) signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
return nan_code; nan = 0x80;
} }
else else
{ {
if((x_bitwise & nan_mask) == nan_mask) if constexpr(DstT_exp == 4)
return signed_inf + (mantissa != 0 ? 1 : 0); { // e4m3
signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
}
else
{ // e5m2
signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
}
nan = (sign << 7) + 0x7f;
}
// Max values
unsigned long long ifmax = 0;
if constexpr(is_float)
{
if constexpr(DstT_exp == 5)
{
ifmax = 0x47600000;
}
else
{
if constexpr(is_fnuz)
{
ifmax = 0x43700000;
}
else
{
ifmax = 0x43E00000;
}
}
}
else if constexpr(is_half)
{
if constexpr(DstT_exp == 5)
{
ifmax = 0x7B00;
}
else
{
if constexpr(is_fnuz)
{
ifmax = 0x5B80;
}
else
{
ifmax = 0x5F00;
}
}
} }
// check if x is 0.0 // Deal with inf and NaNs
if(x_bitwise == 0) if((src_bitwise & fInf) == fInf)
return __builtin_bit_cast(Y, static_cast<uint8_t>(0)); {
if constexpr(is_fnuz)
return signed_inf;
return mantissa != 0 ? nan : signed_inf;
}
if((src_bitwise & abs_mask) > ifmax)
{
return signed_inf;
}
if(src_bitwise == 0)
{
return 0;
}
// First need to check if it is normal or denorm as there is a difference of implict 1 // First need to check if it is normal or denorm as there is a difference of
// Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift // implicit 1 Then need to adjust the exponent to align with the F8 exponent,
// The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
// RNE, no need to add rng. Then probably need to check whether there is carry and adjust // to mantissa and truncate. And for RNE, no need to add rng. Then probably
// exponent and mantissa again3 // need to check whether there is carry and adjust exponent and mantissa again
// For IEEE bias mode, the bias is 2^(k-1)-1 where k is the width of exponent bits // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
const int out_bias = (1 << (out_exp - 1)) - 1 + (negative_zero_nan ? 1 : 0); // bits
const int out_denormal_act_exponent = 1 - out_bias; // actual exponent of f8 denormal const int f8_bias = (1 << (DstT_exp - 1)) - 1 + (is_fnuz ? 1 : 0);
const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
// out_exponent is the converted f8 exponent with bias encoding // f8_exponent is the converted f8 exponent with bias encoding
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
// the difference needs to be adjusted and mantissa shifted // the difference needs to be adjusted and mantissa shifted
int act_exponent, out_exponent, exponent_diff; int act_exponent, f8_exponent, exponent_diff;
if(exponent == 0) if(exponent == 0)
{ // fp32/fp16 is in denormal. { // fp32/fp16 is in denormal.
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16 /* fp32 denormal is below 2^-127 so it is usually not a concern here, we
here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has mostly concern fp16 here. In this case, f8 is usually in denormal. But there
exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers exponent bias 16. It means that there are some numbers in fp16 denormal but they
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal. are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
In this case, the fp16 mantissa should be shift left by 1 */ where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
act_exponent = exponent - bias + 1; act_exponent = exponent - bias + 1;
exponent_diff = out_denormal_act_exponent - exponent_diff = f8_denormal_act_exponent -
act_exponent; // actual exponent is exponent-bias+1 as it is denormal act_exponent; // actual exponent is exponent-bias+1 as it is denormal
} }
else else
{ // fp32/fp16 is normal with implicit 1 { // fp32/fp16 is normal with implicit 1
act_exponent = exponent - bias; act_exponent = exponent - bias;
if(act_exponent <= out_denormal_act_exponent) if(act_exponent <= f8_denormal_act_exponent)
{ {
/* This is the case where fp32/fp16 is normal but it is in f8 denormal range. /* This is the case where fp32/fp16 is normal but it is in f8 denormal
For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16 range. For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
actual exponent is -7, it is actually larger due to the implict 1, actual exponent is -7, it is actually larger due to the implicit 1,
Therefore it needs to be adjusted to -6 and the mantissa shifted right by 1. Therefore it needs to be adjusted to -6 and the mantissa shifted right by 1.
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
exponent_diff = out_denormal_act_exponent - act_exponent; exponent_diff = f8_denormal_act_exponent - act_exponent;
} }
else else
{ // both fp32/fp16 and f8 are in normal range { // both fp32/fp16 and f8 are in normal range
exponent_diff = exponent_diff = 0; // exponent_diff=0 does not mean there is no difference
0; // exponent_diff=0 does not mean there is no difference for this case, // for this case, act_exponent could be larger. Just
// act_exponent could be larger. Just that it does not need shift mantissa // that it does not need shift mantissa
} }
mantissa += (1 << in_mant); // Add the implicit 1 into mantissa mantissa += (1ull << SrcT_mant); // Add the implicit 1 into mantissa
} }
bool midpoint = (mantissa & ((1 << (in_mant - out_mant + exponent_diff)) - 1)) == bool midpoint = (mantissa & ((1ull << (SrcT_mant - DstT_mant + exponent_diff)) - 1)) ==
(1 << (in_mant - out_mant + exponent_diff - 1)); (1ull << (SrcT_mant - DstT_mant + exponent_diff - 1));
/* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we /* This part is a bit tricky. The judgment of whether it is a tie needs to be
shift right as shift right could rip off some residual part and make something not midpoint look done before we shift right as shift right could rip off some residual part and
like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than make something not midpoint look like midpoint. For example, the fp16 number
midpoint, but after shift right by 4 bits, it would look like midpoint. */ 0x1002 (0 00100 0000000010), it is larger than midpoint, but after shift right
by 4 bits, it would look like midpoint.
*/
if(exponent_diff > 0) if(exponent_diff > 0)
mantissa >>= exponent_diff; mantissa >>= exponent_diff;
else if(exponent_diff == -1) else if(exponent_diff == -1)
mantissa <<= -exponent_diff; mantissa <<= -exponent_diff;
bool implicit_one = mantissa & (1 << in_mant); bool implicit_one = mantissa & (1ull << SrcT_mant);
// if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent // if there is no implicit 1, it means the f8 is denormal and need to adjust
out_exponent = // to denorm exponent
(act_exponent + exponent_diff) /*actual f8 exponent*/ + out_bias - (implicit_one ? 0 : 1); f8_exponent =
(act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
// Now we have the exponent and mantissa adjusted // Now we have the exponent and mantissa adjusted
unsigned long long drop_mask = (1ull << (SrcT_mant - DstT_mant)) - 1;
bool odd = bool odd =
mantissa & mantissa & (1ull << (SrcT_mant -
(1 << (in_mant - out_mant)); // if the least significant bit that is not truncated is 1 DstT_mant)); // if the least significant bit that is not truncated is 1
mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask; mantissa +=
(stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;
// Now we deal with overflow // Now we deal with overflow
if(out_exponent == 0) if(f8_exponent == 0)
{ {
if((1 << in_mant) & mantissa) if((1ull << SrcT_mant) & mantissa)
{ {
out_exponent = 1; // denormal overflow to become normal, promote exponent f8_exponent = 1; // denormal overflow to become normal, promote exponent
// No need to make 1 implicit now as it will be addressed later
} }
} }
else else
{ {
if((1 << (in_mant + 1)) & mantissa) if((1ull << (SrcT_mant + 1)) & mantissa)
{ {
mantissa >>= 1; mantissa >>= 1;
out_exponent++; f8_exponent++;
// No need to make 1 implicit now as it will be addressed later
} }
} }
mantissa >>= (in_mant - out_mant); mantissa >>= (SrcT_mant - DstT_mant);
if(out_exponent > max_exp) // above range: quantize to maximum possible float of the same sign
const int max_exp = (1 << DstT_exp) - 1;
if(f8_exponent > max_exp)
{ {
if(clip) if constexpr(clip)
{ {
mantissa = (1 << out_mant) - 1; mantissa = (1 << DstT_mant) - 1;
out_exponent = max_exp; f8_exponent = max_exp;
} }
else else
{ {
return __builtin_bit_cast(Y, static_cast<uint8_t>(signed_inf)); return signed_inf;
} }
} }
// check if x is 0.0 or -0.0 if(f8_exponent == 0 && mantissa == 0)
if(out_exponent == 0 && mantissa == 0) return is_fnuz ? 0 : (sign << 7);
return __builtin_bit_cast( mantissa &= (1 << DstT_mant) - 1;
Y, static_cast<uint8_t>(negative_zero_nan ? 0 : (sign << (out_exp + out_mant)))); return (sign << 7) | (f8_exponent << DstT_mant) | mantissa;
mantissa &= (1 << out_mant) - 1;
return __builtin_bit_cast(Y,
static_cast<uint8_t>((sign << (out_exp + out_mant)) |
(out_exponent << out_mant) | mantissa));
} }
template <typename X, typename Y, bool negative_zero_nan> template <typename SrcT, typename DstT, bool clip = true>
CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x) CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
{ {
// fp8/bf8 exponent/mantissa layout static_assert(std::is_same<SrcT, fp8_t>::value || std::is_same<SrcT, bf8_t>::value,
constexpr int in_exp = numeric_traits<X>::exp; "SrcT type must be fp8 or bf8.");
constexpr int in_mant = numeric_traits<X>::mant; constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
// resulting type exponent/mantissa layout constexpr bool is_fnuz =
constexpr int out_exp = numeric_traits<Y>::exp; (numeric_traits<SrcT>::f8_interpret == fp8_interpretation::E4M3_FNUZ) ||
constexpr int out_mant = numeric_traits<Y>::mant; (numeric_traits<SrcT>::f8_interpret == fp8_interpretation::E5M2_FNUZ);
uint8_t x_raw = __builtin_bit_cast(uint8_t, x);
constexpr bool is_half = std::is_same<DstT, half_t>::value;
// prepare the codes constexpr bool is_float = std::is_same<DstT, float>::value;
constexpr uint8_t nan_code = 0x80; static_assert(is_half || is_float, "DstT type must be half_t or float.");
Y Inf, NegInf, NaN, Neg0;
using T_bitwise = typename numeric_traits<Y>::bitwise_type; // destination type exponent/mantissa layout
constexpr int DstT_exp = numeric_traits<DstT>::exp; // exponent width of the destination type
constexpr T_bitwise Inf_bitwise = numeric_traits<Y>::Inf; constexpr int DstT_mant = numeric_traits<DstT>::mant; // mantissa width of the destination type
constexpr T_bitwise NegInf_bitwise = numeric_traits<Y>::NegInf;
constexpr T_bitwise NaN_bitwise = numeric_traits<Y>::NaN; constexpr DstT fInf = bit_cast<DstT>(numeric_traits<DstT>::Inf);
constexpr T_bitwise Neg0_bitwise = numeric_traits<Y>::Neg0; constexpr DstT fNegInf = bit_cast<DstT>(numeric_traits<DstT>::NegInf);
constexpr DstT fNaN = bit_cast<DstT>(numeric_traits<DstT>::NaN);
Inf = *(reinterpret_cast<const Y*>(&Inf_bitwise)); constexpr DstT fNeg0 = bit_cast<DstT>(numeric_traits<DstT>::Neg0);
NegInf = *(reinterpret_cast<const Y*>(&NegInf_bitwise));
NaN = *(reinterpret_cast<const Y*>(&NaN_bitwise)); DstT fmax{0}, fmin{0};
Neg0 = *(reinterpret_cast<const Y*>(&Neg0_bitwise)); // Max number in e5m2 57344
if constexpr(is_half)
// check if x is 0.0 {
if(x_raw == 0) fmax = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0x7B00));
return static_cast<Y>(0); fmin = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0xFB00));
}
// unpack the input else if constexpr(is_float)
uint32_t sign = x_raw >> (in_exp + in_mant); {
uint32_t mantissa = x_raw & ((1 << in_mant) - 1); fmax = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0x47600000));
int exponent = (x_raw & 0x7F) >> in_mant; fmin = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0xC7600000));
}
constexpr int exp_low_cutoff = if(x == 0)
(1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); {
T_bitwise retval; return 0;
}
if constexpr(negative_zero_nan) unsigned long long sign = x >> 7;
unsigned long long mantissa = x & ((1 << SrcT_mant) - 1);
int exponent = (x & 0x7F) >> SrcT_mant;
if constexpr(is_fnuz)
{ {
if(x_raw == nan_code) if(x == 0x80)
return NaN; {
return fNaN;
}
} }
else else
{ {
if(x_raw == nan_code) if(x == 0x80)
return Neg0; {
if(exponent == ((1 << in_exp) - 1)) return fNeg0;
return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN; }
if constexpr(SrcT_exp == 4)
{ // e4m3
if((x & 0x7F) == 0x7F)
{
return fNaN;
}
}
else if((x & 0x7C) == 0x7C)
{ // e5m2
if((x & 0x3) == 0)
{
if constexpr(clip)
{
return sign ? fmin : fmax;
}
return sign ? fNegInf : fInf;
}
return fNaN;
}
} }
if((numeric_traits<Y>::mant == 10) && (numeric_traits<X>::mant == 2) && !negative_zero_nan) typename numeric_traits<DstT>::bitwise_type retval;
if constexpr(SrcT_exp == 5 && is_half && !is_fnuz)
{ {
retval = x_raw; retval = x << 8;
retval <<= 8; return bit_cast<DstT>(retval);
return *(reinterpret_cast<const Y*>(&retval));
} }
const int exp_low_cutoff =
(1 << (DstT_exp - 1)) - (1 << (SrcT_exp - 1)) + 1 - (is_fnuz ? 1 : 0);
// subnormal input // subnormal input
if(exponent == 0) if(exponent == 0)
{ {
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above int sh = 1 + clz(mantissa) - (32 - SrcT_mant);
int sh = 1 + clz(mantissa) - (32 - in_mant);
mantissa <<= sh; mantissa <<= sh;
exponent += 1 - sh; exponent += 1 - sh;
mantissa &= ((1 << in_mant) - 1); mantissa &= ((1ull << SrcT_mant) - 1);
} }
exponent += exp_low_cutoff - 1; exponent += exp_low_cutoff - 1;
mantissa <<= out_mant - in_mant; mantissa <<= DstT_mant - SrcT_mant;
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true) // subnormal output (occurs when DstT is half_t, we=5, is_fnuz=true)
if(exponent <= 0) if(exponent <= 0)
{ {
mantissa |= 1 << out_mant; mantissa |= 1 << DstT_mant;
mantissa >>= 1 - exponent; mantissa >>= 1 - exponent;
exponent = 0; exponent = 0;
} }
retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa; retval = (sign << (DstT_exp + DstT_mant)) | (exponent << DstT_mant) | mantissa;
return *(reinterpret_cast<const Y*>(&retval));
}
template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
CK_TILE_HOST_DEVICE Y cast_to_f8(X x, uint32_t rng)
{
// check datatypes
constexpr bool is_half = std::is_same<X, half_t>::value;
constexpr bool is_float = std::is_same<X, float>::value;
static_assert(is_half || is_float, "Only half and float can be casted.");
return run_cast_to_f8<X, Y, negative_zero_nan, clip, stoch>(x, rng); return bit_cast<DstT>(retval);
} }
template <typename X, typename Y, bool negative_zero_nan> template <typename X, typename Y, bool clip, bool stoch>
CK_TILE_HOST_DEVICE Y cast_from_f8(X x) CK_TILE_HOST_DEVICE Y cast_to_f8(X x, uint32_t rng)
{ {
// check datatype return bit_cast<Y>(run_cast_to_f8<X, Y, clip, stoch>(x, rng));
constexpr bool is_half = std::is_same<Y, half_t>::value;
constexpr bool is_float = std::is_same<Y, float>::value;
static_assert(is_half || is_float, "only half and float are supported.");
return run_cast_from_f8<X, Y, negative_zero_nan>(x);
} }
} // namespace impl
CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_sr_raw(float x) #if CK_TILE_FP8_CVT_DEVICE
/**
* @brief Cast float to fp8/bf8 using device conversion instructions
*/
template <fp8_interpretation interpret, bool saturate, bool stochastic_rounding = false>
CK_TILE_DEVICE uint8_t cast_to_f8_from_f32(float v, unsigned int rng = 0)
{ {
constexpr int seed = 42; uint8_t i8data;
uint32_t rng = prand_generator_t<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx94__)
float max_fp8 = 240.0f;
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
union union
{ {
float fval; float fval;
uint32_t i32val; unsigned int i32val;
uint8_t i8val[4]; // not endian independent unsigned char i8val[4]; // NOTE: not endian independent
} val; } val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
return val.i8val[0]; // little endian
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr fp8_rounding_mode rm = fp8_rounding_mode::stochastic;
return bit_cast<fp8_raw_t>(impl::cast_to_f8<float,
fp8_t,
negative_zero_nan,
clip,
(rm == fp8_rounding_mode::stochastic)>(x, rng));
#endif
}
CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_sr_raw(float x) unsigned int ival = 0;
{ val.fval = v;
constexpr int seed = 42;
uint32_t rng = prand_generator_t<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x); if constexpr(saturate)
#if defined(__gfx94__)
union
{ {
float fval; if constexpr(interpret == fp8_interpretation::E4M3_FNUZ)
uint32_t i32val; {
uint8_t i8val[4]; // not endian independent if((val.i32val & 0x7F800000) != 0x7F800000)
} val; { /// propagate NAN/INF, no clipping
val.fval = x; val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
uint32_t ival = 0; }
ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos }
val.i32val = ival; else if constexpr(interpret == fp8_interpretation::E4M3_OCP)
return val.i8val[0]; // little endian { // OCP type
#else if((val.i32val & 0x7F800000) != 0x7F800000)
constexpr bool negative_zero_nan = true; { /// propagate NAN/INF, no clipping
constexpr bool clip = true; val.fval = __builtin_amdgcn_fmed3f(val.fval, 448.0, -448.0);
constexpr fp8_rounding_mode rm = fp8_rounding_mode::stochastic; }
return bit_cast<bf8_raw_t>(impl::cast_to_f8<float, }
bf8_t, else
negative_zero_nan, {
clip, if((val.i32val & 0x7F800000) != 0x7F800000)
(rm == fp8_rounding_mode::stochastic)>(x, rng)); { /// propagate NAN/INF, no clipping
#endif val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
}
}
}
if constexpr(stochastic_rounding)
{
ival = (interpret == fp8_interpretation::E4M3_FNUZ) ||
(interpret == fp8_interpretation::E4M3_OCP)
? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
: __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
i8data = val.i8val[0]; // little endian
}
else
{ // RNE CVT
ival = (interpret == fp8_interpretation::E4M3_FNUZ) ||
(interpret == fp8_interpretation::E4M3_OCP)
? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false)
: __builtin_amdgcn_cvt_pk_bf8_f32(val.fval,
val.fval,
ival,
false); // false -> WORD0
val.i32val = ival;
i8data = val.i8val[0];
}
return i8data;
} }
#endif // CK_TILE_FP8_CVT_DEVICE
CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_rtn_raw(float x) } // namespace impl
/**
* @brief Converts a floating-point value to an 8-bit floating-point representation with stochastic
* rounding.
*
* This function converts a floating-point value (float or half_t) to an 8-bit floating-point
* representation of type fp8_t or bf8_t. The conversion process may
* involve clipping and uses a pseudo-random number generator for the stochastic rounding.
*
* @tparam DstT The destination type (fp8_t or bf8_t).
* @tparam SrcT The source type (float or half_t) to be converted.
* @param x The floating-point value to be converted.
* @return The 8-bit floating-point representation of the input value.
*/
template <typename SrcT, typename DstT>
CK_TILE_HOST_DEVICE typename numeric_traits<DstT>::bitwise_type float_to_fp8_sr_raw(SrcT x)
{ {
#if defined(__gfx94__) constexpr bool clip = true;
float max_fp8 = 240.0f; constexpr int seed = 42;
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); uint32_t rng = prand_generator_t<SrcT, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
union #if CK_TILE_FP8_CVT_DEVICE
{ return impl::cast_to_f8_from_f32<numeric_traits<DstT>::f8_interpret, clip, true>(x, rng);
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0
val.i32val = ival;
return val.i8val[0];
#else #else
constexpr bool negative_zero_nan = true; return bit_cast<typename numeric_traits<DstT>::bitwise_type>(
constexpr bool clip = true; impl::cast_to_f8<SrcT, DstT, clip, true>(x, rng));
constexpr fp8_rounding_mode rm = fp8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return bit_cast<fp8_raw_t>(impl::cast_to_f8<float,
fp8_t,
negative_zero_nan,
clip,
(rm == fp8_rounding_mode::stochastic)>(x, rng));
#endif #endif
} }
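// Usage sketch (illustrative):
//   fp8_raw_t raw = float_to_fp8_sr_raw<float, fp8_t>(0.1f); // stochastic rounding
//   fp8_t     val = bit_cast<fp8_t>(raw);
// The rounding direction is chosen pseudo-randomly, seeded from the input value and its address,
// which removes systematic rounding bias across many distinct inputs while staying deterministic
// for a given input.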
CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_rtn_raw(float x)
/**
* @brief Converts a floating-point value to an 8-bit floating-point representation with rounding to
* nearest even.
*
* This function converts a floating-point value (float or half_t) to an 8-bit floating-point
* representation of type fp8_t or bf8_t. The conversion process may involve clipping.
*
* @tparam DstT The destination type (fp8_t or bf8_t).
* @tparam SrcT The source type (float or half_t) to be converted.
* @param x The floating-point value to be converted.
* @return The 8-bit floating-point representation of the input value.
*/
template <typename SrcT, typename DstT>
CK_TILE_HOST_DEVICE typename numeric_traits<DstT>::bitwise_type float_to_fp8_rtn_raw(SrcT x)
{ {
#if defined(__gfx94__) constexpr bool clip = true;
union #if CK_TILE_FP8_CVT_DEVICE
{ return impl::cast_to_f8_from_f32<numeric_traits<DstT>::f8_interpret, clip, false>(x, 0);
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
val.i32val = ival;
return val.i8val[0];
#else #else
constexpr bool negative_zero_nan = true; return bit_cast<typename numeric_traits<DstT>::bitwise_type>(
constexpr bool clip = true; impl::cast_to_f8<SrcT, DstT, clip, false>(x, 0));
constexpr fp8_rounding_mode rm = fp8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return bit_cast<bf8_raw_t>(impl::cast_to_f8<float,
bf8_t,
negative_zero_nan,
clip,
(rm == fp8_rounding_mode::stochastic)>(x, rng));
#endif #endif
} }
// clang-format off template <fp8_rounding_mode rounding>
template<fp8_rounding_mode rounding>
CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_raw(float x, constant<rounding>) CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_raw(float x, constant<rounding>)
{ {
if constexpr (rounding == fp8_rounding_mode::standard) return float_to_fp8_rtn_raw(x); if constexpr(rounding == fp8_rounding_mode::standard)
else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_fp8_sr_raw(x); {
else return fp8_raw_t{0}; return float_to_fp8_rtn_raw<float, fp8_t>(x);
}
else if constexpr(rounding == fp8_rounding_mode::stochastic)
{
return float_to_fp8_sr_raw<float, fp8_t>(x);
}
else
{
return fp8_raw_t{0};
}
} }
template<fp8_rounding_mode rounding> template <fp8_rounding_mode rounding>
CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_raw(float x, constant<rounding>) CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_raw(float x, constant<rounding>)
{ {
if constexpr (rounding == fp8_rounding_mode::standard) return float_to_bf8_rtn_raw(x); if constexpr(rounding == fp8_rounding_mode::standard)
else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_bf8_sr_raw(x); {
else return bf8_raw_t{0}; return float_to_fp8_rtn_raw<float, bf8_t>(x);
}
else if constexpr(rounding == fp8_rounding_mode::stochastic)
{
return float_to_fp8_sr_raw<float, bf8_t>(x);
}
else
{
return bf8_raw_t{0};
}
} }
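// Usage sketch (illustrative): the rounding mode is selected at compile time via a constant tag.
//   fp8_raw_t a = float_to_fp8_raw(x, constant<fp8_rounding_mode::standard>{});   // RNE
//   fp8_raw_t b = float_to_fp8_raw(x, constant<fp8_rounding_mode::stochastic>{}); // SR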
CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x) CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x)
{ {
#if defined(__gfx94__) #if CK_TILE_FP8_CVT_DEVICE
float fval; float fval;
uint32_t i32val = static_cast<uint32_t>(x); uint32_t i32val = static_cast<uint32_t>(x);
fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0); fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
// asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); // asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
return fval; return fval;
#else #else
constexpr bool negative_zero_nan = true; return impl::run_cast_from_f8<fp8_t, float>(bit_cast<fp8_t>(x));
return impl::cast_from_f8<fp8_t, float, negative_zero_nan>(bit_cast<fp8_t>(x));
#endif #endif
} }
CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x) CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x)
{ {
#if defined(__gfx94__) #if CK_TILE_FP8_CVT_DEVICE
float fval; float fval;
uint32_t i32val = static_cast<uint32_t>(x); uint32_t i32val = static_cast<uint32_t>(x);
fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0); fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
// asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); // asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
return fval; return fval;
#else #else
constexpr bool negative_zero_nan = true; return impl::run_cast_from_f8<bf8_t, float>(bit_cast<bf8_t>(x));
return impl::cast_from_f8<bf8_t, float, negative_zero_nan>(bit_cast<bf8_t>(x));
#endif #endif
} }
template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)> template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE fp8_t float_to_fp8(float x, constant<rounding> = {}) CK_TILE_HOST_DEVICE fp8_t float_to_fp8(float x, constant<rounding> = {})
{ {
return bit_cast<fp8_t>(float_to_fp8_raw(x, constant<rounding>{})); return bit_cast<fp8_t>(float_to_fp8_raw(x, constant<rounding>{}));
} }
template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)> template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE bf8_t float_to_bf8(float x, constant<rounding> = {}) CK_TILE_HOST_DEVICE bf8_t float_to_bf8(float x, constant<rounding> = {})
{ {
return bit_cast<bf8_t>(float_to_bf8_raw(x, constant<rounding>{})); return bit_cast<bf8_t>(float_to_bf8_raw(x, constant<rounding>{}));
} }
CK_TILE_HOST_DEVICE float fp8_to_float(fp8_t x) CK_TILE_HOST_DEVICE float fp8_to_float(fp8_t x) { return fp8_to_float_raw(bit_cast<fp8_raw_t>(x)); }
{
return fp8_to_float_raw(bit_cast<fp8_raw_t>(x));
}
CK_TILE_HOST_DEVICE float bf8_to_float(bf8_t x) CK_TILE_HOST_DEVICE float bf8_to_float(bf8_t x) { return bf8_to_float_raw(bit_cast<bf8_raw_t>(x)); }
{
return bf8_to_float_raw(bit_cast<bf8_raw_t>(x));
}
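// Round-trip sketch (illustrative):
//   fp8_t q = float_to_fp8(1.0f); // rounding mode defaults to CK_TILE_FLOAT_TO_FP8_DEFAULT
//   float r = fp8_to_float(q);    // exactly 1.0f, since 1.0 is representable in e4m3
// Inputs beyond the format maximum (448 for OCP e4m3, 240 for FNUZ e4m3) saturate because the
// conversions above are invoked with clipping enabled.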
// clang-format on template <class T>
struct numeric;
template <typename T>
struct numeric_traits;
#if CK_TILE_USE_OCP_FP8
template <> template <>
struct numeric_traits<fp8_t> struct numeric<fp8_t>
{ {
static constexpr int exp = 4; // minimum finite value, or minimum positive normal value
static constexpr int mant = 3; CK_TILE_HOST_DEVICE static constexpr fp8_t min()
#if defined(__gfx94__) {
static constexpr int bias = 8; return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x08)); // 0b00001000 = 2^-6
#else }
static constexpr int bias = 7;
#endif // minimum finite value
CK_TILE_HOST_DEVICE static constexpr fp8_t lowest()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0xfe)); // 0b11111110 = -448
}
// maximum finite value
CK_TILE_HOST_DEVICE static constexpr fp8_t max()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x7e)); // 0b01111110 = 448
}
// difference between 1.0 and next representable f8 value (1.125)
// returns fp8_t(0.125)
CK_TILE_HOST_DEVICE static constexpr fp8_t epsilon()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x20)); // 0.125
}
// rounding error (0.0625)
// half of epsilon
CK_TILE_HOST_DEVICE static constexpr fp8_t round_error()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x18)); // 0.0625
}
// quiet NaN
CK_TILE_HOST_DEVICE static constexpr fp8_t quiet_NaN()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x7F)); // 0b01111111
}
// signaling NaN
CK_TILE_HOST_DEVICE static constexpr fp8_t signaling_NaN()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0xFF)); // 0b11111111
}
// smallest positive subnormal value
CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x01));
}
CK_TILE_HOST_DEVICE static constexpr fp8_t zero()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0));
}
}; };
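// Decoding check (informal): max() == 0x7e == 0b0'1111'110 in e4m3 with bias 7,
// i.e. (1 + 6/8) * 2^(15 - 7) = 1.75 * 256 = 448, matching the OCP e4m3 maximum noted above.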
template <> template <>
struct numeric_traits<bf8_t> struct numeric<bf8_t>
{ {
static constexpr int exp = 5; // minimum finite value, or minimum positive normalized value for float
static constexpr int mant = 2; CK_TILE_HOST_DEVICE static constexpr bf8_t min()
#if defined(__gfx94__) {
static constexpr int bias = 16; return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x04)); // 0b00000100 = 2^-14
#else }
static constexpr int bias = 15; // IEEE
#endif
};
template <class T> // minimum finite value
struct numeric; CK_TILE_HOST_DEVICE static constexpr bf8_t lowest()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0xfb)); // 0b11111011 = -57344
}
// maximum finite value
CK_TILE_HOST_DEVICE static constexpr bf8_t max()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7b)); // 0b01111011 = 57344
}
// difference between 1.0 and the next representable bf8 value (1.25), i.e. 0.25
CK_TILE_HOST_DEVICE static constexpr bf8_t epsilon()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x34)); // 0.25
}
// rounding error (0.125)
// half of epsilon
CK_TILE_HOST_DEVICE static constexpr bf8_t round_error()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x30)); // 0.125
}
// positive infinity value
CK_TILE_HOST_DEVICE static constexpr bf8_t infinity()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7c)); // 0b01111100
}
// quiet NaN
CK_TILE_HOST_DEVICE static constexpr bf8_t quiet_NaN()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7F)); // 0b01111111
}
// signaling NaN
CK_TILE_HOST_DEVICE static constexpr bf8_t signaling_NaN()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0xFF));
}
// smallest positive subnormal value
CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x01));
}
CK_TILE_HOST_DEVICE static constexpr bf8_t zero()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0));
}
};
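// Decoding check (informal): max() == 0x7b == 0b0'11110'11 in e5m2 with bias 15,
// i.e. (1 + 3/4) * 2^(30 - 15) = 1.75 * 32768 = 57344, matching the OCP e5m2 maximum noted above.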
#else
template <> template <>
struct numeric<fp8_t> struct numeric<fp8_t>
{ {
...@@ -811,6 +1054,7 @@ struct numeric<bf8_t> ...@@ -811,6 +1054,7 @@ struct numeric<bf8_t>
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0)); return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0));
} }
}; };
#endif
#if CK_TILE_USE_CUSTOM_DATA_TYPE #if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, fp8_t) CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, fp8_t)
...@@ -818,19 +1062,26 @@ CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bf8_t) ...@@ -818,19 +1062,26 @@ CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bf8_t)
#endif #endif
// math // math
CK_TILE_HOST_DEVICE template <typename T>
fp8_t abs(const fp8_t& x) CK_TILE_HOST_DEVICE T abs(const T& x)
{ {
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(bit_cast<fp8_raw_t>(x) & 0x7f)); static_assert(std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>,
"Only fp8_t and bf8_t are supported");
return bit_cast<T>(static_cast<uint8_t>(bit_cast<uint8_t>(x) & numeric_traits<T>::abs_mask));
} }
CK_TILE_HOST_DEVICE CK_TILE_HOST_DEVICE
bool isnan(const fp8_t& x) bool isnan(const fp8_t& x)
{ {
uint8_t xx = bit_cast<fp8_raw_t>(x); uint8_t xx = bit_cast<fp8_raw_t>(x);
return xx == 0x80; // TODO: NANOO
}
#if CK_TILE_USE_OCP_FP8
return (xx & 0x7f) == 0x7f;
#else
return xx == 0x80;
#endif
}
#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_DEVICE CK_TILE_DEVICE
fp8_t sqrt(fp8_t x) { return static_cast<fp8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); }; fp8_t sqrt(fp8_t x) { return static_cast<fp8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };
...@@ -842,20 +1093,21 @@ fp8_t exp2(fp8_t x) { return static_cast<fp8_t>(exp2f(static_cast<float>(x))); } ...@@ -842,20 +1093,21 @@ fp8_t exp2(fp8_t x) { return static_cast<fp8_t>(exp2f(static_cast<float>(x))); }
CK_TILE_DEVICE CK_TILE_DEVICE
fp8_t log(fp8_t x) { return static_cast<fp8_t>(__logf(static_cast<float>(x))); }; fp8_t log(fp8_t x) { return static_cast<fp8_t>(__logf(static_cast<float>(x))); };
#endif
CK_TILE_HOST_DEVICE
bf8_t abs(const bf8_t& x)
{
return bit_cast<bf8_t>(static_cast<fp8_raw_t>(bit_cast<bf8_raw_t>(x) & 0x7f));
}
CK_TILE_HOST_DEVICE CK_TILE_HOST_DEVICE
bool isnan(const bf8_t& x) bool isnan(const bf8_t& x)
{ {
uint8_t xx = bit_cast<bf8_raw_t>(x); uint8_t xx = bit_cast<bf8_raw_t>(x);
return xx == 0x80; // TODO: NANOO
#if CK_TILE_USE_OCP_FP8
return (xx & 0x7f) > 0x7c;
#else
return xx == 0x80;
#endif
} }
#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_DEVICE CK_TILE_DEVICE
bf8_t sqrt(bf8_t x) { return static_cast<bf8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); }; bf8_t sqrt(bf8_t x) { return static_cast<bf8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };
...@@ -867,5 +1119,6 @@ bf8_t exp2(bf8_t x) { return static_cast<bf8_t>(exp2f(static_cast<float>(x))); } ...@@ -867,5 +1119,6 @@ bf8_t exp2(bf8_t x) { return static_cast<bf8_t>(exp2f(static_cast<float>(x))); }
CK_TILE_DEVICE CK_TILE_DEVICE
bf8_t log(bf8_t x) { return static_cast<bf8_t>(__logf(static_cast<float>(x))); }; bf8_t log(bf8_t x) { return static_cast<bf8_t>(__logf(static_cast<float>(x))); };
#endif
} // namespace ck_tile } // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
...@@ -236,10 +236,11 @@ struct numeric_traits<half_t>
static constexpr uint16_t head_mask = 0xFC00;
static constexpr uint16_t mant_mask = 0x3FF;
static constexpr uint16_t exp_mask = 0x1F;
static constexpr uint16_t abs_mask = 0x7FFF;
static constexpr uint16_t Inf = 0x7C00;
static constexpr uint16_t NegInf = 0xFC00;
static constexpr uint16_t NaN = 0x7C01;
static constexpr uint16_t Neg0 = 0x8000;
using bitwise_type = uint16_t;
};
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...@@ -89,6 +89,7 @@ struct numeric_traits<float>
static constexpr uint32_t head_mask = 0xFF800000;
static constexpr uint32_t mant_mask = 0x7FFFFF;
static constexpr uint32_t exp_mask = 0xFF;
static constexpr uint32_t abs_mask = 0x7FFFFFFF;
static constexpr uint32_t Inf = 0x7F800000;
static constexpr uint32_t NegInf = 0xFF800000;
static constexpr uint32_t NaN = 0x7F800001;
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/random.hpp"
#include <stdint.h>
#include <type_traits>
#include "ck_tile/core/numeric/int8.hpp"
#pragma once
namespace ck_tile {
// Packed 2xint4
struct pk_int4_t
{
using type = int8_t;
type data;
__host__ __device__ constexpr pk_int4_t() : data{type{}} {}
__host__ __device__ constexpr pk_int4_t(type init) : data{init} {}
};
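// Two 4-bit values share one byte: one element lives in the low nibble and the other in the
// high nibble. The conversion helpers below decode each nibble as an unsigned value biased by 8
// (stored nibble minus 8 gives the signed result); with CK_TILE_USE_PK4_LAYOUT_SHUFFLE defined
// the high nibble comes first in the unpacked pair, otherwise the low nibble does.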
// limits
template <class T>
struct numeric;
template <>
struct numeric<pk_int4_t>
{
// minimum finite value, or minimum positive normalized value for float
CK_TILE_HOST_DEVICE static constexpr pk_int4_t min()
{
constexpr uint8_t val = 0b10001000;
return pk_int4_t(bit_cast<int8_t>(val));
}
// minimum finite value
CK_TILE_HOST_DEVICE static constexpr pk_int4_t lowest()
{
constexpr uint8_t val = 0b10001000;
return pk_int4_t(bit_cast<int8_t>(val));
}
// maximum finite value
CK_TILE_HOST_DEVICE static constexpr pk_int4_t max()
{
constexpr uint8_t val = 0b01110111;
return pk_int4_t(bit_cast<int8_t>(val));
}
// difference between 1.0 and next value representable by float
CK_TILE_HOST_DEVICE static constexpr pk_int4_t epsilon()
{
return 1; // not used
}
CK_TILE_HOST_DEVICE static constexpr pk_int4_t round_error()
{
return 1; // not used
}
// positive infinity value
CK_TILE_HOST_DEVICE static constexpr pk_int4_t infinity()
{
return 1; // not used
}
// quiet NaN
CK_TILE_HOST_DEVICE static constexpr pk_int4_t quiet_NaN()
{
return 1; // not used
}
// signaling NaN
CK_TILE_HOST_DEVICE static constexpr pk_int4_t signaling_NaN()
{
return 1; // not used
}
// smallest positive subnormal value
CK_TILE_HOST_DEVICE static constexpr pk_int4_t denorm_min()
{
return 1; // not used
}
CK_TILE_HOST_DEVICE static constexpr pk_int4_t zero() { return 0; }
};
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)
{
uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
fp32x2_t res = {x_h, x_l};
#else
fp32x2_t res = {x_l, x_h};
#endif
return res;
}
CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x)
{
uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
#else
uint32_t i4s = ((x_u8 & 0xf0) << 12) | (x_u8 & 0xf);
#endif
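// Bit trick used below: each nibble is OR-ed into the mantissa of the fp16 constant 0x6400
// (1024.0), which yields exactly 1024 + n per 16-bit lane. Adding 0xE408 (-1032.0) per lane
// then produces n - 8 without any integer-to-float conversion instructions, and pk_add_f16
// handles both lanes at once.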
const int EX = 0x64006400;
const int SUB = 0xE408E408; //-8
int lo = i4s | EX;
return pk_add_f16(bit_cast<fp16x2_t>(lo), bit_cast<fp16x2_t>(SUB));
}
CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x)
{
uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
bf16x2_t res = {type_convert<bf16_t>(x_h), type_convert<bf16_t>(x_l)};
#else
bf16x2_t res = {type_convert<bf16_t>(x_l), type_convert<bf16_t>(x_h)};
#endif
return res;
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...@@ -200,4 +200,21 @@ using bf8x32_t = bf8_t __attribute((ext_vector_type(32)));
using bf8x64_t = bf8_t __attribute((ext_vector_type(64)));
#endif
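// pk_add_f16 adds the two fp16 lanes of a fp16x2_t element-wise. The host overload below uses
// two scalar additions; the device overload maps directly onto the packed v_pk_add_f16
// instruction.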
CK_TILE_HOST fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
{
fp16x2_t vector_res;
vector_res.x = x.x + y.x;
vector_res.y = x.y + y.y;
return vector_res;
}
CK_TILE_DEVICE fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
{
fp16x2_t c;
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(x), "v"(y));
return c;
}
} // namespace ck_tile } // namespace ck_tile
...@@ -5,6 +5,7 @@
#include "ck_tile/host/arg_parser.hpp"
#include "ck_tile/host/check_err.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/host/convolution_host_tensor_descriptor_helper.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/host/device_memory.hpp"
...@@ -20,6 +21,7 @@
#include "ck_tile/host/reference/reference_batched_masking.hpp"
#include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
#include "ck_tile/host/reference/reference_batched_softmax.hpp"
#include "ck_tile/host/reference/reference_batched_transpose.hpp"
#include "ck_tile/host/reference/reference_elementwise.hpp"
#include "ck_tile/host/reference/reference_fused_moe.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
...
...@@ -22,13 +22,14 @@ template <typename ComputeDataType, typename OutDataType, typename AccDataType =
double get_relative_threshold(const int number_of_accumulations = 1)
{
using F8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using F16 = ck_tile::half_t;
using BF16 = ck_tile::bf16_t;
using F32 = float;
using I8 = int8_t;
using I32 = int32_t;
static_assert(is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled ComputeDataType for setting up the relative threshold!");
double compute_error = 0;
...@@ -41,7 +42,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
compute_error = std::pow(2, -numeric_traits<ComputeDataType>::mant) * 0.5;
}
static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled OutDataType for setting up the relative threshold!");
double output_error = 0;
...@@ -55,7 +56,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
}
double midway_error = std::max(compute_error, output_error);
static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled AccDataType for setting up the relative threshold!");
double acc_error = 0;
...@@ -74,13 +75,14 @@ template <typename ComputeDataType, typename OutDataType, typename AccDataType =
double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1)
{
using F8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using F16 = ck_tile::half_t;
using BF16 = ck_tile::bf16_t;
using F32 = float;
using I8 = int8_t;
using I32 = int32_t;
static_assert(is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
auto expo = std::log2(std::abs(max_possible_num));
...@@ -94,7 +96,7 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
compute_error = std::pow(2, expo - numeric_traits<ComputeDataType>::mant) * 0.5;
}
static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled OutDataType for setting up the absolute threshold!");
double output_error = 0;
...@@ -108,7 +110,7 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
}
double midway_error = std::max(compute_error, output_error);
static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled AccDataType for setting up the absolute threshold!");
double acc_error = 0;
...@@ -501,7 +503,11 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
}
if(!res)
{
const float error_percent =
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
std::cerr << "max err: " << max_err;
std::cerr << ", number of errors: " << err_count;
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
}
return res;
}
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace ck_tile {
template <typename T>
struct IsCharArray : std::false_type
{
};
template <std::size_t N>
struct IsCharArray<char[N]> : std::true_type
{
};
template <std::size_t N>
struct IsCharArray<const char[N]> : std::true_type
{
};
template <std::size_t N>
struct IsCharArray<char (&)[N]> : std::true_type
{
};
template <std::size_t N>
struct IsCharArray<const char (&)[N]> : std::true_type
{
};
template <typename... Ts>
inline constexpr bool AllConvertibleToStringView = ((std::is_convertible_v<Ts, std::string_view> ||
IsCharArray<Ts>::value ||
std::is_same_v<Ts, char>)&&...);
template <typename... Ts>
[[nodiscard]] auto concat(const Ts&... xs)
-> std::enable_if_t<!AllConvertibleToStringView<Ts...>, std::string>
{
using ::operator<<;
thread_local std::ostringstream oss;
oss.str("");
(oss << ... << xs);
return oss.str();
}
template <std::size_t N>
[[nodiscard]] constexpr inline std::size_t getSize(char (&)[N]) noexcept
{
return N;
}
template <std::size_t N>
[[nodiscard]] constexpr inline std::size_t getSize(const char (&)[N]) noexcept
{
return N;
}
[[nodiscard]] constexpr inline std::size_t getSize(const char* s) noexcept
{
const char* end = s;
while(*end++ != 0) {}
return end - s - 1;
}
[[nodiscard]] constexpr inline std::size_t getSize(const char&) noexcept { return 1; }
[[nodiscard]] inline std::size_t getSize(const std::string& s) noexcept { return s.size(); }
[[nodiscard]] constexpr inline std::size_t getSize(const std::string_view& s) noexcept
{
return s.size();
}
template <typename... Ts>
auto concatInto(std::string& result, const Ts&... xs)
-> std::enable_if_t<AllConvertibleToStringView<Ts...>, void>
{
const std::size_t space = (1 + ... + getSize(xs));
result.reserve(result.size() + space);
((result += xs), ...);
}
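// Illustrative use of concatInto (hypothetical call site, not part of this header):
//   std::string name;
//   concatInto(name, "fmha_", "fwd", "_hdim", "128"); // name == "fmha_fwd_hdim128"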
template <typename... Ts>
[[nodiscard]] auto concat(const Ts&... xs)
-> std::enable_if_t<AllConvertibleToStringView<Ts...>, std::string>
{
std::string result;
concatInto(result, xs...);
return result;
}
// Function for types convertible to std::string_view
template <typename Sep, typename First, typename... Rest>
[[nodiscard]] auto concat(Sep sep, const First& first, const Rest&... rest)
-> std::enable_if_t<AllConvertibleToStringView<First, Rest...>, std::string>
{
std::string result;
result += first;
((result += sep, result += rest), ...);
return result;
}
// Function for other types
template <typename Sep, typename First, typename... Rest>
[[nodiscard]] auto concat(Sep sep, const First& first, const Rest&... rest)
-> std::enable_if_t<!AllConvertibleToStringView<First, Rest...>, std::string>
{
using ::operator<<;
thread_local std::ostringstream oss;
oss.str("");
oss << first;
((oss << sep << rest), ...);
return oss.str();
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
namespace ck_tile {
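// Host-side reference transpose: interprets the 4-D tensor as NCHW or NHWC, permutes the channel
// and spatial dimensions accordingly, and processes one batch per worker thread. Layout pairs
// other than NCHW<->NHWC are silently ignored.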
template <typename Type>
CK_TILE_HOST void reference_batched_transpose(const HostTensor<Type>& x,
HostTensor<Type>& y,
std::string layout_in = "NCHW",
std::string layout_out = "NHWC")
{
const int N = x.mDesc.get_lengths()[0];
auto f = [&](auto batch) {
if(layout_in == "NCHW" && layout_out == "NHWC")
{
const int C = x.mDesc.get_lengths()[1];
const int H = x.mDesc.get_lengths()[2];
const int W = x.mDesc.get_lengths()[3];
for(int c = 0; c < C; ++c)
{
for(int h = 0; h < H; ++h)
{
for(int w = 0; w < W; ++w)
{
Type v_x = x(batch, c, h, w);
y(batch, h, w, c) = v_x;
}
}
}
}
else if(layout_in == "NHWC" && layout_out == "NCHW")
{
const int H = x.mDesc.get_lengths()[1];
const int W = x.mDesc.get_lengths()[2];
const int C = x.mDesc.get_lengths()[3];
for(int h = 0; h < H; ++h)
{
for(int w = 0; w < W; ++w)
{
for(int c = 0; c < C; ++c)
{
Type v_x = x(batch, h, w, c);
y(batch, c, h, w) = v_x;
}
}
}
}
};
make_ParallelTensorFunctor(f, N)(std::thread::hardware_concurrency());
}
} // namespace ck_tile
...@@ -80,13 +80,14 @@ __global__ void naive_gemm_kernel(ADataType* A,
int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? col * strideB + k
: k * strideB + col;
acc += ck_tile::type_convert<AccDataType>(A[a_index]) *
ck_tile::type_convert<AccDataType>(B[b_index]);
}
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
? row * strideC + col
: col * strideC + row;
C[c_index] = ck_tile::type_convert<CDataType>(acc);
}
}
...
...@@ -14,12 +14,15 @@ namespace ck_tile {
template <typename WeightType, typename IndexType = index_t>
CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
const HostTensor<WeightType>& weights,
const HostTensor<IndexType>& local_expert_mask,
HostTensor<IndexType>& p_sorted_token_ids,
HostTensor<WeightType>& sorted_weight,
HostTensor<IndexType>& sorted_expert_ids,
index_t& unit_cnt,
const index_t experts,
const index_t unit_size,
bool local_expert_masking,
bool skip_experts_with_zero_token = true)
{
const index_t num_token = topk_ids.mDesc.get_lengths()[0];
const index_t topk = topk_ids.mDesc.get_lengths()[1];
...@@ -33,8 +36,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
#endif
std::vector<std::vector<WeightType>> expert_token_weights(
experts, std::vector<WeightType>(unit_size, 0));
// number of unit-size slices claimed by each expert
std::vector<IndexType> expert_slices(experts, 1);
// number of tokens assigned to each expert
std::vector<IndexType> expert_slice_idxs(experts, 0);
// TODO: the two buffers above appear to duplicate each other
for(index_t t = 0; t < num_token; t++)
{
...@@ -72,8 +78,23 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
IndexType* out_tokens = p_sorted_token_ids.data();
WeightType* out_weights = sorted_weight.data();
IndexType* out_expert_id = sorted_expert_ids.data();
int curr_expert_id = 0;
for(index_t e = 0; e < experts; e++)
{
if(local_expert_masking)
{
if(local_expert_mask(e) == 0)
continue;
}
if(skip_experts_with_zero_token)
{
if(expert_slice_idxs[e] == 0)
{
curr_expert_id++;
continue;
}
}
memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size);
out_tokens += expert_slices[e] * unit_size;
memcpy(out_weights,
...@@ -83,10 +104,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
for(index_t s = 0; s < expert_slices[e]; s++)
{
out_expert_id[s] = curr_expert_id;
unit_cnt++;
}
out_expert_id += expert_slices[e];
curr_expert_id++;
}
unit_cnt *= unit_size;
return;
...
...@@ -10,3 +10,4 @@
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
struct BatchedTransposeHostArgs
{
const void* p_input;
void* p_output;
index_t batch;
index_t height;
index_t width;
// index_t dim_blocks;
index_t dim_stride;
index_t dim_block_h;
index_t dim_block_w;
};
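// dim_stride is the per-batch element stride between consecutive 2-D slices; dim_block_h and
// dim_block_w only feed GridSize and are not forwarded to the device kernel arguments.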
template <typename Pipeline_>
struct BatchedTransposeKernel
{
using Pipeline = remove_cvref_t<Pipeline_>;
using Problem = remove_cvref_t<typename Pipeline::Problem>;
using Type = typename Problem::InputType;
struct BatchedTransposeKargs
{
const void* p_input;
void* p_output;
index_t batch;
index_t height;
index_t width;
index_t dim_stride;
};
using Kargs = BatchedTransposeKargs;
using Hargs = BatchedTransposeHostArgs;
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
{
size_t grid_size_x = (h.width + h.dim_block_w - 1) / h.dim_block_w;
size_t grid_size_y = (h.height + h.dim_block_h - 1) / h.dim_block_h;
size_t grid_size_z = h.batch;
return dim3(grid_size_x, grid_size_y, grid_size_z);
}
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
{
Kargs k;
k.p_input = h.p_input;
k.p_output = h.p_output;
k.batch = h.batch;
k.height = h.height;
k.width = h.width;
k.dim_stride = h.dim_stride;
return k;
}
CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::kBlockSize; }
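// The kernel builds padded global-memory views of the input (height x width) and the transposed
// output (width x height), positions one tile window per block at offsets derived from blockIdx,
// and hands both windows to the Pipeline, which performs the per-thread copy.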
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock;
static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr ck_tile::index_t kMPerThread = Problem::kMPerThread;
static constexpr ck_tile::index_t kNPerThread = Problem::kNPerThread;
static_assert(kMPerThread == 1 && kNPerThread == 1);
const auto iDim = blockIdx.z;
const auto x_m_n = [&]() {
const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<const Type*>(kargs.p_input) + iDim * kargs.dim_stride,
make_tuple(kargs.height, kargs.width),
make_tuple(kargs.width, 1),
number<kNPerThread>{}, // TODO thread load value
number<1>{});
return pad_tensor_view(x_dram_naive,
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
sequence<kPadM, kPadN>{});
}();
const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock);
const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kNPerBlock);
const auto y_n_m = [&]() {
const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<Type*>(kargs.p_output) + iDim * kargs.dim_stride,
make_tuple(kargs.width, kargs.height),
make_tuple(kargs.height, 1),
number<kMPerThread>{},
number<1>{});
return pad_tensor_view(y_dram_naive,
make_tuple(number<kNPerBlock>{}, number<kMPerBlock>{}),
sequence<kPadN, kPadM>{});
}();
// iM and iN already carry the kMPerBlock/kNPerBlock scaling applied above
auto x_block_window =
make_tile_window(x_m_n,
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
{static_cast<ck_tile::index_t>(iM), static_cast<ck_tile::index_t>(iN)});
auto y_block_window =
make_tile_window(y_n_m,
make_tuple(number<kNPerBlock>{}, number<kMPerBlock>{}),
{static_cast<ck_tile::index_t>(iN), static_cast<ck_tile::index_t>(iM)});
Pipeline{}(x_block_window, y_block_window);
}
};
} // namespace ck_tile