Unverified Commit 24f6f4ab authored by Illia Silin, committed by GitHub

Merge pull request #284 from ROCm/lwpck-2619

Add FP6/BF6 vector type support
parents 5eb4c3d1 df1bad99
...@@ -210,6 +210,10 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx
    add_definitions(-DCK_USE_FNUZ_FP8)
    set(CK_USE_FNUZ_FP8 "ON")
endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx950")
add_definitions(-DCK_USE_NATIVE_MX_SUPPORT)
set(CK_USE_NATIVE_MX_SUPPORT "ON")
endif()
option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF)
if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908"))
...
...@@ -131,6 +131,10 @@
#cmakedefine CK_USE_FP8_ON_UNSUPPORTED_ARCH @CK_USE_FP8_ON_UNSUPPORTED_ARCH@
#endif
#ifndef CK_USE_NATIVE_MX_SUPPORT
#cmakedefine CK_USE_NATIVE_MX_SUPPORT @CK_USE_NATIVE_MX_SUPPORT@
#endif
// clang-format on
#endif // CK_CONFIG_H_IN
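For orientation, a minimal sketch (mine, not part of this patch) of how code built against the generated config header could branch on the new macro; the include path and constant name are assumptions:

#include "ck/ck.hpp" // assumed to pull in the generated config header

#if defined(CK_USE_NATIVE_MX_SUPPORT)
inline constexpr bool ck_has_native_mx = true;  // gfx950 builds: native MX conversion instructions
#else
inline constexpr bool ck_has_native_mx = false; // other targets: emulated conversions
#endif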
...@@ -24,8 +24,9 @@ struct f4x2_pk_t
    f4x2_pk_t(type init) : data{init} {}
    template <index_t I>
    __host__ __device__ inline type unpack(Number<I>) const
    {
        static_assert(I < 2, "Index is out of range.");
        if constexpr(I == 0)
            return data & 0b00001111;
        else
...@@ -38,6 +39,270 @@ struct f4x2_pk_t
    }
};
struct f6x16_pk_t
{
// store 16 elements of f6_t in an array of 3 uint32_t
using element_type = uint32_t;
using type = StaticallyIndexedArray_v2<element_type, 3>;
type data;
typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
f6x16_pk_t() : data{type{}} {}
f6x16_pk_t(type init) : data{init} {}
template <index_t I>
__host__ __device__ inline f6_t unpack(Number<I>)
{
static_assert(I < 16, "Index out of range for 16 f6_t elements.");
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 3;
constexpr int bit_pos = I * num_bits_elem;
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
{
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
<< (num_bits_elem - overhang);
}
return static_cast<f6_t>(bits & 0x3F);
}
__host__ __device__ inline type pack(const test_vec_t& x)
{
type packed{};
// for each of the 16 f6_t values, place its 6 bits in the correct position
ck::static_for<0, 16, 1>{}([&](auto i) {
uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 3;
constexpr int bit_pos = i * num_bits_elem;
constexpr int arr_index = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
uint32_t old_value = packed.At(Number<arr_index>{});
// insert bits into the current 32-bit block
old_value |= (bits << bit_offset);
packed.At(Number<arr_index>{}) = old_value;
// if it crosses into the next block, shift the remainder
if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
{
uint32_t next_value = packed.At(Number<arr_index + 1>{});
next_value |= (bits >> (num_bits_elem - overhang));
packed.At(Number<arr_index + 1>{}) = next_value;
}
});
return packed;
}
};
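To make the 6-bit packing concrete, here is a small usage sketch of f6x16_pk_t (illustrative only, not from this patch; assumes the ck namespace and that the inputs already fit in 6 bits):

    f6x16_pk_t::test_vec_t raw{}; // 16 int8 lanes, only the low 6 bits are kept by pack()
    for(int i = 0; i < 16; ++i)
        raw[i] = static_cast<int8_t>(i & 0x3F);

    f6x16_pk_t pk;
    pk.data = pk.pack(raw); // 16 x 6 = 96 bits spread over 3 uint32_t words

    // Element 5 starts at bit 30, so its 6 bits straddle word 0 (2 bits) and word 1 (4 bits).
    f6_t e5 = pk.unpack(Number<5>{});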
struct f6x32_pk_t
{
// store 32 elements of f6_t in an array of 6 uint32_t
using element_type = uint32_t;
using type = StaticallyIndexedArray_v2<element_type, 6>;
type data;
typedef int8_t test_vec_t __attribute__((ext_vector_type(32)));
f6x32_pk_t() : data{type{}} {}
f6x32_pk_t(type init) : data{init} {}
template <index_t I>
__host__ __device__ inline f6_t unpack(Number<I>)
{
static_assert(I < 32, "Index out of range for 32 f6_t elements.");
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 6;
constexpr int bit_pos = I * num_bits_elem;
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
{
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
<< (num_bits_elem - overhang);
}
return static_cast<f6_t>(bits & 0x3F);
}
__host__ __device__ inline type pack(const test_vec_t& x)
{
type packed{};
// for each of the 32 f6_t values, place its 6 bits in the correct position
ck::static_for<0, 32, 1>{}([&](auto i) {
uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 6;
constexpr int bit_pos = i * num_bits_elem;
constexpr int arr_index = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
uint32_t old_value = packed.At(Number<arr_index>{});
// insert bits into the current 32-bit block
old_value |= (bits << bit_offset);
packed.At(Number<arr_index>{}) = old_value;
// if it crosses into the next block, shift the remainder
if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
{
uint32_t next_value = packed.At(Number<arr_index + 1>{});
next_value |= (bits >> (num_bits_elem - overhang));
packed.At(Number<arr_index + 1>{}) = next_value;
}
});
return packed;
}
};
struct bf6x16_pk_t
{
// store 16 elements of bf6_t in an array of 3 uint32_t
using element_type = uint32_t;
using type = StaticallyIndexedArray_v2<element_type, 3>;
type data;
typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
bf6x16_pk_t() : data{type{}} {}
bf6x16_pk_t(type init) : data{init} {}
template <index_t I>
__host__ __device__ inline bf6_t unpack(Number<I>)
{
        static_assert(I < 16, "Index out of range for 16 bf6_t elements.");
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 3;
constexpr int bit_pos = I * num_bits_elem;
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
{
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
<< (num_bits_elem - overhang);
}
return static_cast<bf6_t>(bits & 0x3F);
}
__host__ __device__ inline type pack(const test_vec_t& x)
{
type packed{};
// for each of the 16 bf6_t values, place its 6 bits in the correct position
ck::static_for<0, 16, 1>{}([&](auto i) {
uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 3;
constexpr int bit_pos = i * num_bits_elem;
constexpr int arr_index = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
uint32_t old_value = packed.At(Number<arr_index>{});
// insert bits into the current 32-bit block
old_value |= (bits << bit_offset);
packed.At(Number<arr_index>{}) = old_value;
// if it crosses into the next block, shift the remainder
if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
{
uint32_t next_value = packed.At(Number<arr_index + 1>{});
next_value |= (bits >> (num_bits_elem - overhang));
packed.At(Number<arr_index + 1>{}) = next_value;
}
});
return packed;
}
};
struct bf6x32_pk_t
{
// store 32 elements of bf6_t in an array of 6 uint32_t
using element_type = uint32_t;
using type = StaticallyIndexedArray_v2<element_type, 6>;
type data;
typedef int8_t test_vec_t __attribute__((ext_vector_type(32)));
bf6x32_pk_t() : data{type{}} {}
bf6x32_pk_t(type init) : data{init} {}
template <index_t I>
__host__ __device__ inline bf6_t unpack(Number<I>)
{
        static_assert(I < 32, "Index out of range for 32 bf6_t elements.");
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 6;
constexpr int bit_pos = I * num_bits_elem;
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
{
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
<< (num_bits_elem - overhang);
}
return static_cast<bf6_t>(bits & 0x3F);
}
__host__ __device__ inline type pack(const test_vec_t& x)
{
type packed{};
// for each of the 32 bf6_t values, place its 6 bits in the correct position
ck::static_for<0, 32, 1>{}([&](auto i) {
uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 6;
constexpr int bit_pos = i * num_bits_elem;
constexpr int arr_index = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
uint32_t old_value = packed.At(Number<arr_index>{});
// insert bits into the current 32-bit block
old_value |= (bits << bit_offset);
packed.At(Number<arr_index>{}) = old_value;
// if it crosses into the next block, shift the remainder
if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
{
uint32_t next_value = packed.At(Number<arr_index + 1>{});
next_value |= (bits >> (num_bits_elem - overhang));
packed.At(Number<arr_index + 1>{}) = next_value;
}
});
return packed;
}
};
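A companion sketch (mine; the helper name unpack_all is invented) showing how the packed bf6 lanes can be walked with static_for so every bit offset stays a compile-time constant:

__host__ __device__ inline ck::bf6x32_pk_t::test_vec_t unpack_all(ck::bf6x32_pk_t pk)
{
    ck::bf6x32_pk_t::test_vec_t out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        // i is a Number<I>, so unpack(i) resolves arr_idx and bit_offset at compile time
        out[static_cast<int>(i)] = static_cast<int8_t>(pk.unpack(i));
    });
    return out;
}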
// custom data type - pack int4 data
struct pk_i4_t
{
...@@ -56,7 +321,7 @@ inline constexpr auto next_pow2(uint32_t x)
}
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t,
// native types: bool, f4_t, f6_t, bf6_t
template <typename T>
inline constexpr bool is_native_type()
{
...@@ -1387,12 +1652,37 @@ struct nnvb_data_t_selector<f8_ocp_t>
{
    using type = f8_ocp_t::data_type;
};
template <>
struct nnvb_data_t_selector<bf8_ocp_t>
{
    using type = bf8_ocp_t::data_type;
};
template <>
struct nnvb_data_t_selector<f6x16_pk_t>
{
using type = f6x16_pk_t::type;
};
template <>
struct nnvb_data_t_selector<f6x32_pk_t>
{
using type = f6x32_pk_t::type;
};
template <>
struct nnvb_data_t_selector<bf6x16_pk_t>
{
using type = bf6x16_pk_t::type;
};
template <>
struct nnvb_data_t_selector<bf6x32_pk_t>
{
using type = bf6x32_pk_t::type;
};
template <>
struct nnvb_data_t_selector<pk_i4_t>
{
...@@ -1499,6 +1789,63 @@ struct non_native_vector_base<
}
};
// implementation for f6x16 and f6x32
template <typename T, index_t N>
struct non_native_vector_base<T, N, std::enable_if_t<sizeof(T) == 12 || sizeof(T) == 24>>
{
using data_t =
typename nnvb_data_t_selector<T>::type; // select data_t based on declared base type
using element_t = typename T::element_type; // select element_t based on declared element type
static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch");
static constexpr size_t size_factor =
sizeof(data_t) / sizeof(element_t); // f6x16: 12/4 = 3, f6x32: 24/4 = 6
using data_v = element_t __attribute__((ext_vector_type(N * size_factor)));
using type = non_native_vector_base<T, N>;
union alignas(next_pow2(N * sizeof(T)))
{
data_v dN; // storage vector;
StaticallyIndexedArray<data_t, N> dxN;
StaticallyIndexedArray<T, N> dTxN;
StaticallyIndexedArray<data_v, 1> dNx1;
} data_;
__host__ __device__ constexpr non_native_vector_base(data_t a)
: data_{data_v(a.At(Number<0>{}))}
{
}
__host__ __device__ constexpr non_native_vector_base(T f)
: non_native_vector_base(bit_cast<data_t>(f))
{
}
__host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){};
__host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {}
__host__ __device__ constexpr operator data_v() const { return data_.dN; }
__host__ __device__ constexpr operator data_t() const
{
if constexpr(N == 1)
{
return data_.dxN[Number<0>{}];
}
else
{
return data_.dxN; // XXX this should cause an error
}
}
__host__ __device__ constexpr operator T() const
{
if constexpr(N == 1)
{
return data_.dTxN[Number<0>{}];
}
else
{
return data_.dTxN; // XXX this should cause an error
}
}
};
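A quick sanity check of the sizes this specialization keys on (illustrative static_asserts of my own, relying only on the packed structs defined above):

// f6x16_pk_t holds 16 x 6 = 96 bits in 3 uint32_t words (12 bytes) and
// f6x32_pk_t holds 32 x 6 = 192 bits in 6 uint32_t words (24 bytes),
// which is exactly what the sizeof(T) == 12 || sizeof(T) == 24 constraint selects.
static_assert(sizeof(ck::f6x16_pk_t) == 12, "expected 3 uint32_t words");
static_assert(sizeof(ck::f6x32_pk_t) == 24, "expected 6 uint32_t words");
// size_factor is then 3 or 6, so non_native_vector_base<f6x16_pk_t, N> wraps an
// ext_vector of 3 * N uint32_t lanes (3 lanes for N == 1).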
template <typename T, index_t N>
struct scalar_type<non_native_vector_base<T, N>>;
...@@ -2242,6 +2589,14 @@ using f4x16_t = typename vector_type<f4x2_pk_t, 8>::type;
using f4x32_t = typename vector_type<f4x2_pk_t, 16>::type;
using f4x64_t = typename vector_type<f4x2_pk_t, 32>::type;
// f6
using f6x16_t = typename vector_type<f6x16_pk_t, 1>::type;
using f6x32_t = typename vector_type<f6x32_pk_t, 1>::type;
// bf6
using bf6x16_t = typename vector_type<bf6x16_pk_t, 1>::type;
using bf6x32_t = typename vector_type<bf6x32_pk_t, 1>::type;
// pack int4
using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type;
using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type;
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...@@ -6,11 +6,21 @@
#include "ck/utility/type_convert.hpp"
#include "ck/utility/mxf8_utils.hpp"
#ifdef CK_USE_NATIVE_MX_SUPPORT
#define CK_USE_NATIVE_MX_SUPPORT 1
#else
#define CK_USE_NATIVE_MX_SUPPORT 0
#endif
namespace ck {
// Declare a template function for scaled conversion
template <typename Y, typename X>
#if CK_USE_OCP_FP8
__host__ __device__ constexpr Y scaled_type_convert(e8m0_bexp_t scale, X x);
#else
__host__ constexpr Y scaled_type_convert(e8m0_bexp_t scale, X x);
#endif
// convert f8_ocp_t to fp32
template <>
...@@ -200,27 +210,13 @@ inline __host__ float32_t scaled_type_convert<float32_t, bf8x32_ocp_t>(e8m0_bexp
    return out.float_1x32;
}
// convert fp4 to fp32
template <>
inline __host__ __device__ float scaled_type_convert<float, f4_t>(e8m0_bexp_t scale, f4_t x)
{
#if defined(__gfx950__)
union
{
float float_array[2];
float2_t float2_array;
} float_values{};
float_values.float2_array =
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, type_convert<float>(scale), 0);
return float_values.float_array[0];
#else
return utils::to_float<f4_t>(scale, x);
#endif
}
// convert fp32 to fp8
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ f8_ocp_t scaled_type_convert<f8_ocp_t, float>(e8m0_bexp_t scale, float x)
#else
inline __host__ f8_ocp_t scaled_type_convert<f8_ocp_t, float>(e8m0_bexp_t scale, float x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<f8_ocp_t>(x, type_convert<float>(scale));
...@@ -231,8 +227,12 @@ inline __host__ __device__ f8_ocp_t scaled_type_convert<f8_ocp_t, float>(e8m0_be
// convert fp32 to bf8
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8_ocp_t scaled_type_convert<bf8_ocp_t, float>(e8m0_bexp_t scale,
                                                                            float x)
#else
inline __host__ bf8_ocp_t scaled_type_convert<bf8_ocp_t, float>(e8m0_bexp_t scale, float x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<bf8_ocp_t>(x, type_convert<float>(scale));
...@@ -243,8 +243,12 @@ inline __host__ __device__ bf8_ocp_t scaled_type_convert<bf8_ocp_t, float>(e8m0_
// convert fp32x2 to fp8x2
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ f8x2_ocp_t scaled_type_convert<f8x2_ocp_t, float2_t>(e8m0_bexp_t scale,
                                                                                 float2_t x)
#else
inline __host__ f8x2_ocp_t scaled_type_convert<f8x2_ocp_t, float2_t>(e8m0_bexp_t scale, float2_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<f8x2_ocp_t>(x, type_convert<float>(scale));
...@@ -254,8 +258,13 @@ inline __host__ __device__ f8x2_ocp_t scaled_type_convert<f8x2_ocp_t, float2_t>(
}
// convert fp32x2 to bf8x2
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8x2_ocp_t scaled_type_convert<bf8x2_ocp_t, float2_t>(e8m0_bexp_t scale,
                                                                                   float2_t x)
#else
inline __host__ bf8x2_ocp_t scaled_type_convert<bf8x2_ocp_t, float2_t>(e8m0_bexp_t scale,
float2_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<bf8x2_ocp_t>(x, type_convert<float>(scale));
...@@ -267,8 +276,13 @@ inline __host__ __device__ bf8x2_ocp_t scaled_type_convert<bf8x2_ocp_t, float2_t
// convert fp32x16 to fp8x16
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ f8x16_ocp_t
scaled_type_convert<f8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
#else
inline __host__ f8x16_ocp_t scaled_type_convert<f8x16_ocp_t, float16_t>(e8m0_bexp_t scale,
float16_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<f8x16_ocp_t>(x, type_convert<float>(scale));
...@@ -280,8 +294,13 @@ scaled_type_convert<f8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
// convert fp32x16 to bf8x16
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8x16_ocp_t
scaled_type_convert<bf8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
#else
inline __host__ bf8x16_ocp_t scaled_type_convert<bf8x16_ocp_t, float16_t>(e8m0_bexp_t scale,
float16_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<bf8x16_ocp_t>(x, type_convert<float>(scale));
...@@ -293,8 +312,13 @@ scaled_type_convert<bf8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
// convert fp32x32 to fp8x32
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ f8x32_ocp_t
scaled_type_convert<f8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
#else
inline __host__ f8x32_ocp_t scaled_type_convert<f8x32_ocp_t, float32_t>(e8m0_bexp_t scale,
float32_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<f8x32_ocp_t>(x, type_convert<float>(scale));
...@@ -306,8 +330,13 @@ scaled_type_convert<f8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
// convert fp32x32 to bf8x32
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8x32_ocp_t
scaled_type_convert<bf8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
#else
inline __host__ bf8x32_ocp_t scaled_type_convert<bf8x32_ocp_t, float32_t>(e8m0_bexp_t scale,
float32_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<bf8x32_ocp_t>(x, type_convert<float>(scale));
...@@ -316,6 +345,26 @@ scaled_type_convert<bf8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
#endif
}
// activate for architectures with native MX support
#if CK_USE_NATIVE_MX_SUPPORT
// convert fp4 to fp32
template <>
inline __host__ __device__ float scaled_type_convert<float, f4_t>(e8m0_bexp_t scale, f4_t x)
{
#if defined(__gfx950__)
union
{
float float_array[2];
float2_t float2_array;
} float_values{};
float_values.float2_array =
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, type_convert<float>(scale), 0);
return float_values.float_array[0];
#else
return utils::to_float<f4_t>(scale, x);
#endif
}
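A tiny usage sketch for the scalar path above (values are mine; assumes e8m0_bexp_t is constructible from float, as it is used elsewhere in these headers):

    ck::e8m0_bexp_t scale{2.0f};                             // biased exponent encoding 2^1
    ck::f4_t q = ck::type_convert<ck::f4_t>(1.5f);           // 1.5 is exactly representable in e2m1
    float    y = ck::scaled_type_convert<float>(scale, q);   // decode and apply the scale -> 3.0f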
// convert vector of 2 fp4 to vector of 2 fp32
template <>
inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_bexp_t scale,
...@@ -330,9 +379,10 @@ inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_b
    value.f4x2_array[0] = x;
    return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, type_convert<float>(scale), 0);
#else
    float2_t ret{utils::to_float<f4_t>(
                     scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})),
                 utils::to_float<f4_t>(
                     scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}))};
    return ret;
#endif
}
...@@ -467,72 +517,104 @@ inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m
    } f4_values{bit_cast<__uint128_t>(x)};
    // TODO: pack in a loop
    float_values.float_array[0] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[1] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[2] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[3] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[4] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[5] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[6] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[7] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[8] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[9] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[10] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[11] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[12] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[13] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[14] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[15] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[16] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[17] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[18] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[19] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[20] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[21] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[22] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[23] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[24] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[25] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[26] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[27] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[28] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[29] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[30] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[31] = utils::to_float<f4_t>(
        scale,
        f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    return float_values.float32_array;
#endif
...@@ -584,8 +666,59 @@ inline __host__ __device__ f4x32_t scaled_type_convert<f4x32_t, float32_t>(e8m0_
template <>
inline __host__ __device__ float scaled_type_convert<float, f6_t>(e8m0_bexp_t scale, f6_t x)
{
#if defined(__gfx950__)
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} in{x};
union
{
float32_t float_vector;
float float_array[32];
} out{};
out.float_vector =
__builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(in.f6_vector, type_convert<float>(scale));
return out.float_array[0];
#else
    return utils::to_float<f6_t>(scale, x);
#endif
}
/**
* @brief Converts a vector of 32 6-bit floating-point values (f6x32_t) to a vector of 32 floats,
* applying the specified scaling factor.
*
* @param scale The exponent scale factor (e8m0_bexp_t).
* @param x The f6x32_t vector to be converted.
* @return The converted float vector representation of the input.
*/
template <>
inline __host__ __device__ float32_t scaled_type_convert<float32_t, f6x32_t>(e8m0_bexp_t scale,
f6x32_t x)
{
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(x, type_convert<float>(scale));
#else
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} in{x};
union
{
float32_t float_vector;
float float_array[32];
} out{};
ck::static_for<0, 32, 1>{}(
[&](auto i) { out.float_array[i] = utils::to_float<f6_t>(scale, in.f6_array[i]); });
return out.float_vector;
#endif
} }
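The scalar f6 path follows the same decode-then-scale pattern; a short hedged sketch (values are mine):

    ck::e8m0_bexp_t scale{4.0f};                             // encodes 2^2
    ck::f6_t q = ck::type_convert<ck::f6_t>(1.75f);          // 1.75 is representable in e2m3
    float    y = ck::scaled_type_convert<float>(scale, q);   // decoded value times scale -> 7.0f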
/**
...@@ -599,8 +732,59 @@ inline __host__ __device__ float scaled_type_convert<float, f6_t>(e8m0_bexp_t sc
template <>
inline __host__ __device__ float scaled_type_convert<float, bf6_t>(e8m0_bexp_t scale, bf6_t x)
{
#if defined(__gfx950__)
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} in{x};
union
{
float32_t float_vector;
float float_array[32];
} out{};
out.float_vector =
__builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(in.bf6_vector, type_convert<float>(scale));
return out.float_array[0];
#else
    return utils::to_float<bf6_t>(scale, x);
#endif
}
/**
 * @brief Converts a vector of 32 6-bit floating-point values (bf6x32_t) to a vector of 32 floats,
* applying the specified scaling factor.
*
* @param scale The exponent scale factor (e8m0_bexp_t).
* @param x The bf6x32_t vector to be converted.
 * @return The converted vector of 32 floats representing the input.
*/
template <>
inline __host__ __device__ float32_t scaled_type_convert<float32_t, bf6x32_t>(e8m0_bexp_t scale,
bf6x32_t x)
{
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(x, type_convert<float>(scale));
#else
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} in{x};
union
{
float32_t float_vector;
float float_array[32];
} out{};
ck::static_for<0, 32, 1>{}(
[&](auto i) { out.float_array[i] = utils::to_float<bf6_t>(scale, in.bf6_array[i]); });
return out.float_vector;
#endif
}
/**
...@@ -624,6 +808,28 @@ inline __host__ __device__ f6_t scaled_type_convert<f6_t, float>(e8m0_bexp_t sca
#endif
}
/**
* @brief Converts a vector of 32 floats to a vector of 32 6-bit floating-point values (f6x32_t),
* applying the specified scale.
*
* Depending on whether CK_USE_SR_F6_CONVERSION is defined, it uses either stochastic rounding
* (f6_convert_sr) or round-to-nearest-even (f6_convert_rne).
*
* @param scale The exponent scale factor (e8m0_bexp_t).
* @param x The float vector to convert.
* @return The converted vector of 6-bit floating-point values (f6x32_t).
*/
template <>
inline __host__ __device__ f6x32_t scaled_type_convert<f6x32_t, float32_t>(e8m0_bexp_t scale,
float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
return f6_convert_sr(x, type_convert<float>(scale));
#else
return f6_convert_rne(x, type_convert<float>(scale));
#endif
}
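Putting the two f6x32 specializations together, a hedged round-trip sketch (variable names are mine; the non-gfx950 fallback path is assumed):

    ck::e8m0_bexp_t scale{4.0f}; // per-block MX scale, encodes 2^2
    ck::float32_t   xs{};        // 32 source values
    for(int i = 0; i < 32; ++i)
        xs[i] = 0.25f * static_cast<float>(i);

    // quantize: each element is divided by the scale, then rounded to 6 bits
    // (RNE unless CK_USE_SR_F6_CONVERSION is set)
    ck::f6x32_t q = ck::scaled_type_convert<ck::f6x32_t>(scale, xs);
    // dequantize: each 6-bit code is decoded and multiplied by the scale again
    ck::float32_t back = ck::scaled_type_convert<ck::float32_t>(scale, q);
    // back[i] approximates xs[i] up to e2m3 rounding error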
/**
 * @brief Converts a 32-bit float to a 6-bit floating-point value (bf6_t), applying the specified
 * scale.
...@@ -645,4 +851,27 @@ inline __host__ __device__ bf6_t scaled_type_convert<bf6_t, float>(e8m0_bexp_t s
#endif
}
/**
* @brief Converts a vector of 32 floats to a vector of 32 6-bit floating-point values (bf6x32_t),
* applying the specified scale.
*
* Depending on whether CK_USE_SR_F6_CONVERSION is defined, it uses either stochastic rounding
* (bf6_convert_sr) or round-to-nearest-even (bf6_convert_rne).
*
* @param scale The exponent scale factor (e8m0_bexp_t).
* @param x The float vector to convert.
* @return The converted 6-bit floating-point vector (bf6x32_t).
*/
template <>
inline __host__ __device__ bf6x32_t scaled_type_convert<bf6x32_t, float32_t>(e8m0_bexp_t scale,
float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
return bf6_convert_sr(x, type_convert<float>(scale));
#else
return bf6_convert_rne(x, type_convert<float>(scale));
#endif
}
#endif // #if CK_USE_NATIVE_MX_SUPPORT
} // namespace ck
...@@ -1146,10 +1146,11 @@ inline __host__ __device__ float2_t type_convert<float2_t, f4x2_t>(f4x2_t x)
    float scale = 1.0f;
    return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, scale, 0);
#else
    float2_t ret{
        utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
                              x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})),
        utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
                              x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}))};
    return ret;
#endif
}
...@@ -1285,103 +1286,103 @@ inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
    // TODO: pack in a loop
    float_values.float_array[0] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[1] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[2] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[3] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[4] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[5] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[6] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[7] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[8] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[9] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[10] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[11] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[12] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[13] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[14] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[15] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[16] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[17] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[18] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[19] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[20] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[21] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[22] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[23] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[24] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[25] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[26] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[27] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[28] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[29] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[30] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[31] = utils::to_float<f4_t>(
        NumericLimits<e8m0_bexp_t>::Binary_1(),
        f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    return float_values.float32_array;
#endif
...@@ -1399,8 +1400,59 @@ inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
 */
inline __host__ __device__ f6_t f6_convert_rne(float x, float scale = 1.0f)
{
#if defined(__gfx950__)
float16_t in1{x};
float16_t in2{};
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} out{};
out.f6_vector = __builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(in1, in2, scale);
return out.f6_array[0];
#else
    return utils::sat_convert_to_type<f6_t>(x / scale);
#endif
}
/**
* @brief Converts a 32-element single-precision float array into a packed 6-bit representation.
*
 * This function divides each input float by the provided scale value, then converts each element
 * with round-to-nearest-even rounding to pack it into 6 bits of precision.
*
* @param x A vector of 32 floats stored in float32_t.
* @param scale A scaling factor for each float before conversion.
* @return An f6x32_t object storing the compressed 6-bit representation.
*/
inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
    // view the 32-float input as two contiguous halves of 16 floats each
    float16_t* in1 = reinterpret_cast<float16_t*>(&x);
    float16_t* in2 = in1 + 1;
return __builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(*in1, *in2, scale);
#else
union
{
float32_t float_vector;
float float_array[32];
} in{x};
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} out{};
ck::static_for<0, 32, 1>{}([&](auto i) {
out.f6_array[i] = utils::sat_convert_to_type<f6_t>(in.float_array[i] / scale);
});
return out.f6_vector;
#endif
} }
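A small worked example of what the fallback branch computes per element (numbers are mine):

    float x     = 5.0f;
    float scale = 2.0f;
    ck::f6_t q  = ck::f6_convert_rne(x, scale); // saturating RNE of x / scale = 2.5, representable in e2m3
    // Re-applying the matching e8m0 scale recovers the value exactly in this case:
    float xr = ck::scaled_type_convert<float>(ck::e8m0_bexp_t{scale}, q); // 2.5 * 2 = 5.0f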
/**
...@@ -1417,15 +1469,75 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx950__)
union
{
float32_t float_vector;
float float_array[32];
} in{x};
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} out{};
out.f6_vector = __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(in.float_vector, rng, scale);
return out.f6_array[0];
#else
    return utils::sat_convert_to_type_sr<f6_t>(x / scale, rng);
#endif
}
/**
* @brief Converts a 32-element single-precision float array into a packed 6-bit representation.
*
* This function divides each input float by the provided scale value, then performs conversion with
* stochastic rounding to pack each element into 6 bits of precision.
*
* @param x A vector of 32 floats stored in float32_t.
* @param scale A scaling factor for each float before conversion.
* @return An f6x32_t object storing the compressed 6-bit representation.
*/
inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f)
{
constexpr int seed = 1254739;
union
{
float32_t float_vector;
float float_array[32];
} float_values{x};
uint32_t rng =
prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), float_values.float_array[0]);
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale);
#else
union
{
float32_t float_vector;
float float_array[32];
} in{x};
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} out{};
ck::static_for<0, 32, 1>{}([&](auto i) {
out.f6_array[i] = utils::sat_convert_to_type_sr<f6_t>(in.float_array[i] / scale, rng);
});
return out.f6_vector;
#endif
} }
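For illustration only, a hedged sketch contrasting the two rounding modes on the same input; the stochastic result may or may not match the RNE one, since the random bits come from prand_generator seeded by the operand's address and value:
// Hypothetical comparison of rounding modes on one value.
float x  = 0.3f;
f6_t rne = f6_convert_rne(x); // deterministic round-to-nearest-even
f6_t sr  = f6_convert_sr(x);  // stochastic rounding; may land on the neighboring 6-bit code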
/** /**
* @brief Specializes the type conversion template for converting a float into the 6-bit float type * @brief Specializes the type conversion template for converting a float into the 6-bit float type
* (f6_t). * (f6_t).
* *
* Depending on the CK_USE_SR_F4_CONVERSION flag, * Depending on the CK_USE_SR_F6_CONVERSION flag,
* the conversion uses stochastic rounding * the conversion uses stochastic rounding
* or round-to-nearest-even. * or round-to-nearest-even.
* *
...@@ -1435,7 +1547,28 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f) ...@@ -1435,7 +1547,28 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
template <> template <>
inline __host__ __device__ f6_t type_convert<f6_t, float>(float x) inline __host__ __device__ f6_t type_convert<f6_t, float>(float x)
{ {
#if CK_USE_SR_F4_CONVERSION #if CK_USE_SR_F6_CONVERSION
return f6_convert_sr(x);
#else
return f6_convert_rne(x);
#endif
}
/**
 * @brief Specializes the type conversion template for converting a vector of 32 floats into a
 * vector of 32 6-bit float values (f6x32_t).
*
* Depending on the CK_USE_SR_F6_CONVERSION flag,
* the conversion uses stochastic rounding
* or round-to-nearest-even.
*
 * @param x Input vector of 32 floats to be converted.
* @return The converted f6x32_t vector.
*/
template <>
inline __host__ __device__ f6x32_t type_convert<f6x32_t, float32_t>(float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
return f6_convert_sr(x); return f6_convert_sr(x);
#else #else
return f6_convert_rne(x); return f6_convert_rne(x);
...@@ -1454,8 +1587,62 @@ inline __host__ __device__ f6_t type_convert<f6_t, float>(float x) ...@@ -1454,8 +1587,62 @@ inline __host__ __device__ f6_t type_convert<f6_t, float>(float x)
template <> template <>
inline __host__ __device__ float type_convert<float, f6_t>(f6_t x) inline __host__ __device__ float type_convert<float, f6_t>(f6_t x)
{ {
// currently there is no native conversion instruction #if defined(__gfx950__)
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} in{x};
union
{
float32_t float_vector;
float float_array[32];
} out{};
out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
in.f6_vector, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
return out.float_array[0];
#else
return utils::to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x); return utils::to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif
}
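On the decode side type_convert assumes a unit scale; when a block scale is available, the tests below use scaled_type_convert instead. A hedged sketch (the exponent code passed to e8m0_bexp_t is a hypothetical placeholder, not a value taken from this change):
// Hypothetical decode sketch with and without an explicit e8m0 block scale.
f6_t q       = f6_convert_rne(0.75f);
float unit   = type_convert<float, f6_t>(q);                    // implicit scale of 1
float scaled = scaled_type_convert<float>(e8m0_bexp_t(127), q); // 127 is a placeholder exponent code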
/**
 * @brief Specializes the type conversion template for converting a vector of 32 6-bit float values
 * (f6x32_t) to a vector of 32 floats.
 *
 * Interprets the packed f6_t values as floats using the default scale factor of 1.
 *
 * @param x The vector of 32 6-bit float (f6x32_t) values to be converted.
 * @return The corresponding vector of 32 floats.
*/
template <>
inline __host__ __device__ float32_t type_convert<float32_t, f6x32_t>(f6x32_t x)
{
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
x, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
#else
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} in{x};
union
{
float32_t float_vector;
float float_array[32];
} out{};
ck::static_for<0, 32, 1>{}([&](auto i) {
out.float_array[i] =
utils::to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), in.f6_array[i]);
});
return out.float_vector;
#endif
} }
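A hedged round-trip sketch for the packed forms; exact equality is not expected, since f6_t carries only 6 bits:
// Hypothetical round trip: 32 floats -> f6x32_t -> 32 floats.
union { float32_t vec; float arr[32]; } src{}, dst{};
for(int i = 0; i < 32; ++i)
    src.arr[i] = 1.0f + 0.1f * static_cast<float>(i);
f6x32_t q = type_convert<f6x32_t, float32_t>(src.vec);
dst.vec   = type_convert<float32_t, f6x32_t>(q); // each dst.arr[i] approximates src.arr[i]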
/** /**
...@@ -1470,8 +1657,60 @@ inline __host__ __device__ float type_convert<float, f6_t>(f6_t x) ...@@ -1470,8 +1657,60 @@ inline __host__ __device__ float type_convert<float, f6_t>(f6_t x)
*/ */
inline __host__ __device__ bf6_t bf6_convert_rne(float x, float scale = 1.0f) inline __host__ __device__ bf6_t bf6_convert_rne(float x, float scale = 1.0f)
{ {
// currently there is no native conversion instruction #if defined(__gfx950__)
float16_t in1{x};
float16_t in2{};
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} out{};
out.bf6_vector = __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(in1, in2, scale);
return out.bf6_array[0];
#else
return utils::sat_convert_to_type<bf6_t>(x / scale); return utils::sat_convert_to_type<bf6_t>(x / scale);
#endif
}
/**
 * @brief Converts a vector of 32 floats to a vector of 32 6-bit BF6 values using
 * round-to-nearest-even.
 *
 * Divides each input element by the specified scale, then saturates and converts
 * it to the 6-bit BF6 floating-point format.
*
* @param x The float vector to be converted.
* @param scale The scaling factor applied to the input before conversion.
* @return The converted bf6x32_t vector.
*/
inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
    float16_t* in1 = reinterpret_cast<float16_t*>(&x);
    float16_t* in2 = in1 + 1; // second half of x: elements 16..31
return __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(*in1, *in2, scale);
#else
union
{
float32_t float_vector;
float float_array[32];
} in{x};
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} out{};
ck::static_for<0, 32, 1>{}([&](auto i) {
out.bf6_array[i] = utils::sat_convert_to_type<bf6_t>(in.float_array[i] / scale);
});
return out.bf6_vector;
#endif
} }
/** /**
...@@ -1489,14 +1728,76 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f) ...@@ -1489,14 +1728,76 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
{ {
constexpr int seed = 1254739; constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
// currently there is no native conversion instruction #if defined(__gfx950__)
union
{
float32_t float_vector;
float float_array[32];
} in{x};
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} out{};
out.bf6_vector = __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(in.float_vector, rng, scale);
return out.bf6_array[0];
#else
return utils::sat_convert_to_type_sr<bf6_t>(x / scale, rng); return utils::sat_convert_to_type_sr<bf6_t>(x / scale, rng);
#endif
}
/**
 * @brief Converts a vector of 32 floats to a vector of 32 6-bit BF6 values using stochastic
 * rounding.
*
* Divides the input by the specified scale,
* and converts the result to a 6-bit BF6 floating-point
* format with stochastic rounding.
*
* @param x The float vector to be converted.
* @param scale The scaling factor applied to the input before conversion.
* @return The converted bf6x32_t vector.
*/
inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1.0f)
{
constexpr int seed = 1254739;
union
{
float32_t float_vector;
float float_array[32];
} float_values{x};
uint32_t rng =
prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), float_values.float_array[0]);
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale);
#else
union
{
float32_t float_vector;
float float_array[32];
} in{x};
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} out{};
ck::static_for<0, 32, 1>{}([&](auto i) {
out.bf6_array[i] = utils::sat_convert_to_type_sr<bf6_t>(in.float_array[i] / scale, rng);
});
return out.bf6_vector;
#endif
} }
/** /**
* @brief Specializes float-to-bf6_t conversion. * @brief Specializes float-to-bf6_t conversion.
* *
* Uses stochastic rounding if CK_USE_SR_F4_CONVERSION is defined, * Uses stochastic rounding if CK_USE_SR_F6_CONVERSION is defined,
* otherwise uses round-to-nearest-even. * otherwise uses round-to-nearest-even.
* *
* @param x Input float value to convert. * @param x Input float value to convert.
...@@ -1505,7 +1806,26 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f) ...@@ -1505,7 +1806,26 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
template <> template <>
inline __host__ __device__ bf6_t type_convert<bf6_t, float>(float x) inline __host__ __device__ bf6_t type_convert<bf6_t, float>(float x)
{ {
#if CK_USE_SR_F4_CONVERSION #if CK_USE_SR_F6_CONVERSION
return bf6_convert_sr(x);
#else
return bf6_convert_rne(x);
#endif
}
/**
 * @brief Specializes the conversion of a vector of 32 floats (float32_t) to bf6x32_t.
*
* Uses stochastic rounding if CK_USE_SR_F6_CONVERSION is defined,
* otherwise uses round-to-nearest-even.
*
* @param x Input float vector to convert.
* @return Converted bf6x32_t vector.
*/
template <>
inline __host__ __device__ bf6x32_t type_convert<bf6x32_t, float32_t>(float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
return bf6_convert_sr(x); return bf6_convert_sr(x);
#else #else
return bf6_convert_rne(x); return bf6_convert_rne(x);
...@@ -1524,8 +1844,63 @@ inline __host__ __device__ bf6_t type_convert<bf6_t, float>(float x) ...@@ -1524,8 +1844,63 @@ inline __host__ __device__ bf6_t type_convert<bf6_t, float>(float x)
template <> template <>
inline __host__ __device__ float type_convert<float, bf6_t>(bf6_t x) inline __host__ __device__ float type_convert<float, bf6_t>(bf6_t x)
{ {
// currently there is no native conversion instruction #if defined(__gfx950__)
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} in{x};
union
{
float32_t float_vector;
float float_array[32];
} out{};
out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
in.bf6_vector, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
return out.float_array[0];
#else
return utils::to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x); return utils::to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif
}
/**
 * @brief Specializes the type conversion template for converting a vector of 32 bf6_t values to a
 * vector of 32 floats.
*
* Interprets the bf6x32_t value using the default scale factor of 1 and returns
* its floating-point representation.
*
* @param x The bf6x32_t value to convert.
* @return The float representation of the given vector.
*/
template <>
inline __host__ __device__ float32_t type_convert<float32_t, bf6x32_t>(bf6x32_t x)
{
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
x, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
#else
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} in{x};
union
{
float32_t float_vector;
float float_array[32];
} out{};
ck::static_for<0, 32, 1>{}([&](auto i) {
out.float_array[i] =
utils::to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), in.bf6_array[i]);
});
return out.float_vector;
#endif
} }
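The BF6 entry points mirror the FP6 ones; a hedged sketch of the packed round trip using the bf6 names (fill values are illustrative only):
// Hypothetical round trip through the bf6 specializations.
union { float32_t vec; float arr[32]; } src{};
for(int i = 0; i < 32; ++i)
    src.arr[i] = -2.0f + 0.125f * static_cast<float>(i);
bf6x32_t q    = type_convert<bf6x32_t, float32_t>(src.vec);
float32_t out = type_convert<float32_t, bf6x32_t>(q); // approximates src.vec element-wise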
template <typename Y, typename X, std::size_t NumElems> template <typename Y, typename X, std::size_t NumElems>
......
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
using ck::bf6_convert_rne; using ck::bf6_convert_rne;
using ck::bf6_convert_sr; using ck::bf6_convert_sr;
using ck::bf6_t; using ck::bf6_t;
using ck::bf6x16_pk_t;
using ck::bf6x32_pk_t;
using ck::e8m0_bexp_t; using ck::e8m0_bexp_t;
using ck::Number; using ck::Number;
using ck::scaled_type_convert; using ck::scaled_type_convert;
...@@ -216,3 +218,171 @@ TEST(BF6, ScaledConvertFP32Stochastic) ...@@ -216,3 +218,171 @@ TEST(BF6, ScaledConvertFP32Stochastic)
scaled_type_convert<float>(e8m0_bexp_t(min_scale), bf6_convert_sr(neg_float)), scaled_type_convert<float>(e8m0_bexp_t(min_scale), bf6_convert_sr(neg_float)),
abs_tol); abs_tol);
} }
TEST(BF6, TestSize)
{
ASSERT_EQ(1, sizeof(bf6_t));
ASSERT_EQ(12, sizeof(bf6x16_pk_t));
ASSERT_EQ(24, sizeof(bf6x32_pk_t));
ASSERT_EQ(16, sizeof(vector_type<bf6x16_pk_t, 1>));
ASSERT_EQ(32, sizeof(vector_type<bf6x16_pk_t, 2>));
ASSERT_EQ(32, sizeof(vector_type<bf6x32_pk_t, 1>));
}
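The expected sizes follow from the packing arithmetic: 16 × 6 = 96 bits, i.e. three 32-bit words (12 bytes), and 32 × 6 = 192 bits (24 bytes); the vector_type wrappers then pad their storage up to the next power of two, which is what the assertions above encode. A compile-time restatement of that arithmetic:
// Restating the arithmetic behind the expected sizes (compile-time only).
static_assert(16 * 6 / 8 == 12, "16 packed 6-bit codes occupy 12 bytes (three uint32_t)");
static_assert(32 * 6 / 8 == 24, "32 packed 6-bit codes occupy 24 bytes (six uint32_t)");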
TEST(BF6, TestAlignment)
{
ASSERT_EQ(1, alignof(bf6_t));
ASSERT_EQ(4, alignof(bf6x16_pk_t));
ASSERT_EQ(4, alignof(bf6x32_pk_t));
ASSERT_EQ(16, alignof(vector_type<bf6x16_pk_t, 1>));
ASSERT_EQ(32, alignof(vector_type<bf6x16_pk_t, 2>));
ASSERT_EQ(32, alignof(vector_type<bf6x32_pk_t, 1>));
}
// test vector of 1 bf6x16_pk_t, contains 16 bf6_t
TEST(BF6, TestAsType16x1)
{
// test size
const int vector_size = 1;
const int packed_size = 16;
typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
test_vec_t test_vec = {bf6_t(0b000000),
bf6_t(0b100000),
bf6_t(0b000001),
bf6_t(0b100001),
bf6_t(0b000010),
bf6_t(0b100010),
bf6_t(0b000011),
bf6_t(0b100011),
bf6_t(0b000100),
bf6_t(0b100100),
bf6_t(0b000101),
bf6_t(0b100101),
bf6_t(0b000110),
bf6_t(0b100110),
bf6_t(0b001011),
bf6_t(0b101011)};
// reference vector
vector_type<bf6x16_pk_t, vector_size> right_vec;
// check default CTOR
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
right_vec.template AsType<bf6x16_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
0);
});
// assign test values to the vector
ck::static_for<0, vector_size, 1>{}([&](auto i) {
right_vec.template AsType<bf6x16_pk_t>()(Number<i>{}) = bf6x16_pk_t{}.pack(test_vec);
});
// copy the vector
vector_type<bf6x16_pk_t, vector_size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
left_vec.template AsType<bf6x16_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
static_cast<bf6_t>(test_vec[static_cast<int>(i)]));
});
}
// test vector of 2 bf6x16_pk_t, contains 32 bf6_t
TEST(BF6, TestAsType16x2)
{
// test size
const int vector_size = 2;
const int packed_size = 16;
typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
test_vec_t test_vec[2];
test_vec[0] = {bf6_t(0b000000),
bf6_t(0b100000),
bf6_t(0b000001),
bf6_t(0b100001),
bf6_t(0b000010),
bf6_t(0b100010),
bf6_t(0b000011),
bf6_t(0b100011),
bf6_t(0b000100),
bf6_t(0b100100),
bf6_t(0b000101),
bf6_t(0b100101),
bf6_t(0b000110),
bf6_t(0b100110),
bf6_t(0b001011),
bf6_t(0b101011)};
test_vec[1] = {bf6_t(0b010000),
bf6_t(0b110000),
bf6_t(0b010001),
bf6_t(0b110001),
bf6_t(0b010010),
bf6_t(0b110010),
bf6_t(0b010011),
bf6_t(0b110011),
bf6_t(0b010100),
bf6_t(0b110100),
bf6_t(0b010101),
bf6_t(0b110101),
bf6_t(0b010110),
bf6_t(0b110110),
bf6_t(0b011011),
bf6_t(0b111011)};
// reference vector
vector_type<bf6x16_pk_t, vector_size> right_vec;
// check default CTOR
ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) {
ck::static_for<0, packed_size, 1>{}([&](auto idx_element) {
ASSERT_EQ(right_vec.template AsType<bf6x16_pk_t>()(Number<idx_vector>{})
.template unpack<>(Number<idx_element>{}),
0);
});
});
// assign test values to the vector
ck::static_for<0, vector_size, 1>{}([&](auto i) {
right_vec.template AsType<bf6x16_pk_t>()(Number<i>{}) = bf6x16_pk_t{}.pack(test_vec[i]);
});
// copy the vector
vector_type<bf6x16_pk_t, vector_size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) {
ck::static_for<0, packed_size, 1>{}([&](auto idx_element) {
ASSERT_EQ(left_vec.template AsType<bf6x16_pk_t>()(Number<idx_vector>{})
.template unpack<>(Number<idx_element>{}),
static_cast<bf6_t>(test_vec[idx_vector][static_cast<int>(idx_element)]));
});
});
}
// test vector of 1 bf6x32_pk_t, contains 32 bf6_t
TEST(BF6, TestAsType32x1)
{
// test size
const int vector_size = 1;
const int packed_size = 32;
typedef int8_t test_vec_t __attribute__((ext_vector_type(32)));
test_vec_t test_vec = {bf6_t(0b000000), bf6_t(0b100000), bf6_t(0b000001), bf6_t(0b100001),
bf6_t(0b000010), bf6_t(0b100010), bf6_t(0b000011), bf6_t(0b100011),
bf6_t(0b000100), bf6_t(0b100100), bf6_t(0b000101), bf6_t(0b100101),
bf6_t(0b000110), bf6_t(0b100110), bf6_t(0b001011), bf6_t(0b101011),
bf6_t(0b010000), bf6_t(0b110000), bf6_t(0b010001), bf6_t(0b110001),
bf6_t(0b010010), bf6_t(0b110010), bf6_t(0b010011), bf6_t(0b110011),
bf6_t(0b010100), bf6_t(0b110100), bf6_t(0b010101), bf6_t(0b110101),
bf6_t(0b010110), bf6_t(0b110110), bf6_t(0b011011), bf6_t(0b111011)};
// reference vector
vector_type<bf6x32_pk_t, vector_size> right_vec;
// check default CTOR
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
right_vec.template AsType<bf6x32_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
0);
});
// assign test values to the vector
ck::static_for<0, vector_size, 1>{}([&](auto i) {
right_vec.template AsType<bf6x32_pk_t>()(Number<i>{}) = bf6x32_pk_t{}.pack(test_vec);
});
// copy the vector
vector_type<bf6x32_pk_t, vector_size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
left_vec.template AsType<bf6x32_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
static_cast<bf6_t>(test_vec[static_cast<int>(i)]));
});
}
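Outside the vector_type wrapper the packed type can also be used directly; a hedged sketch of a pack/unpack round trip for a single bf6x16_pk_t, assuming it mirrors the f6x16_pk_t layout introduced earlier in this change (the raw 6-bit codes are arbitrary):
// Hypothetical direct use of bf6x16_pk_t: pack 16 raw 6-bit codes, read one back.
typedef int8_t raw16_t __attribute__((ext_vector_type(16)));
raw16_t raw = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
bf6x16_pk_t pk{bf6x16_pk_t{}.pack(raw)}; // 16 x 6 bits packed into three uint32_t
bf6_t third = pk.unpack(Number<2>{});    // code at index 2, here the value 3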
...@@ -235,8 +235,10 @@ TEST(FP4, TestAsType1) ...@@ -235,8 +235,10 @@ TEST(FP4, TestAsType1)
vector_type<f4x2_pk_t, size> right_vec; vector_type<f4x2_pk_t, size> right_vec;
// check default CTOR // check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), 0); ASSERT_EQ(
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), 0); right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}), 0);
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}), 0);
}); });
// assign test values to the vector // assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
...@@ -247,9 +249,9 @@ TEST(FP4, TestAsType1) ...@@ -247,9 +249,9 @@ TEST(FP4, TestAsType1)
vector_type<f4x2_pk_t, size> left_vec{right_vec}; vector_type<f4x2_pk_t, size> left_vec{right_vec};
// check if values were copied correctly // check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}),
test_vec.at(i)); test_vec.at(i));
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}),
test_vec.at(i + 1)); test_vec.at(i + 1));
}); });
} }
...@@ -267,8 +269,10 @@ TEST(FP4, TestAsType2) ...@@ -267,8 +269,10 @@ TEST(FP4, TestAsType2)
vector_type<f4x2_pk_t, size> right_vec; vector_type<f4x2_pk_t, size> right_vec;
// check default CTOR // check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), 0); ASSERT_EQ(
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), 0); right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}), 0);
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}), 0);
}); });
// assign test values to the vector // assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
...@@ -279,9 +283,9 @@ TEST(FP4, TestAsType2) ...@@ -279,9 +283,9 @@ TEST(FP4, TestAsType2)
vector_type<f4x2_pk_t, size> left_vec{right_vec}; vector_type<f4x2_pk_t, size> left_vec{right_vec};
// check if values were copied correctly // check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}),
test_vec.at(i)); test_vec.at(i));
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}),
test_vec.at(i + 1)); test_vec.at(i + 1));
}); });
} }
...@@ -303,8 +307,10 @@ TEST(FP4, TestAsType4) ...@@ -303,8 +307,10 @@ TEST(FP4, TestAsType4)
vector_type<f4x2_pk_t, size> right_vec; vector_type<f4x2_pk_t, size> right_vec;
// check default CTOR // check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), 0); ASSERT_EQ(
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), 0); right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}), 0);
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}), 0);
}); });
// assign test values to the vector // assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
...@@ -315,9 +321,9 @@ TEST(FP4, TestAsType4) ...@@ -315,9 +321,9 @@ TEST(FP4, TestAsType4)
vector_type<f4x2_pk_t, size> left_vec{right_vec}; vector_type<f4x2_pk_t, size> left_vec{right_vec};
// check if values were copied correctly // check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}),
test_vec.at(i)); test_vec.at(i));
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}),
test_vec.at(i + 1)); test_vec.at(i + 1));
}); });
} }
...@@ -347,8 +353,10 @@ TEST(FP4, TestAsType8) ...@@ -347,8 +353,10 @@ TEST(FP4, TestAsType8)
vector_type<f4x2_pk_t, size> right_vec; vector_type<f4x2_pk_t, size> right_vec;
// check default CTOR // check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), 0); ASSERT_EQ(
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), 0); right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}), 0);
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}), 0);
}); });
// assign test values to the vector // assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
...@@ -359,9 +367,9 @@ TEST(FP4, TestAsType8) ...@@ -359,9 +367,9 @@ TEST(FP4, TestAsType8)
vector_type<f4x2_pk_t, size> left_vec{right_vec}; vector_type<f4x2_pk_t, size> left_vec{right_vec};
// check if values were copied correctly // check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}),
test_vec.at(i)); test_vec.at(i));
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}),
test_vec.at(i + 1)); test_vec.at(i + 1));
}); });
} }
...@@ -387,8 +395,10 @@ TEST(FP4, TestAsType16) ...@@ -387,8 +395,10 @@ TEST(FP4, TestAsType16)
vector_type<f4x2_pk_t, size> right_vec; vector_type<f4x2_pk_t, size> right_vec;
// check default CTOR // check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), 0); ASSERT_EQ(
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), 0); right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}), 0);
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}), 0);
}); });
// assign test values to the vector // assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
...@@ -399,9 +409,9 @@ TEST(FP4, TestAsType16) ...@@ -399,9 +409,9 @@ TEST(FP4, TestAsType16)
vector_type<f4x2_pk_t, size> left_vec{right_vec}; vector_type<f4x2_pk_t, size> left_vec{right_vec};
// check if values were copied correctly // check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}),
test_vec.at(i)); test_vec.at(i));
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}),
test_vec.at(i + 1)); test_vec.at(i + 1));
}); });
} }
...@@ -438,8 +448,10 @@ TEST(FP4, TestAsType32) ...@@ -438,8 +448,10 @@ TEST(FP4, TestAsType32)
vector_type<f4x2_pk_t, size> right_vec; vector_type<f4x2_pk_t, size> right_vec;
// check default CTOR // check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), 0); ASSERT_EQ(
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), 0); right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}), 0);
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}), 0);
}); });
// assign test values to the vector // assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
...@@ -450,9 +462,9 @@ TEST(FP4, TestAsType32) ...@@ -450,9 +462,9 @@ TEST(FP4, TestAsType32)
vector_type<f4x2_pk_t, size> left_vec{right_vec}; vector_type<f4x2_pk_t, size> left_vec{right_vec};
// check if values were copied correctly // check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}),
test_vec.at(i)); test_vec.at(i));
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}),
test_vec.at(i + 1)); test_vec.at(i + 1));
}); });
} }
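The FP4 test changes above only track the new unpack signature, which now takes the index as a Number argument instead of an explicit template parameter; a hedged before/after sketch on a hypothetical object:
// Call-style change illustrated on a hypothetical default-constructed f4x2_pk_t.
f4x2_pk_t pk{};
// old (removed):  pk.template unpack<0>();
// new (this PR):  pk.template unpack<>(Number<0>{});
auto lo = pk.unpack(Number<0>{}); // low nibble
auto hi = pk.unpack(Number<1>{}); // high nibble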
...@@ -10,6 +10,8 @@ using ck::e8m0_bexp_t; ...@@ -10,6 +10,8 @@ using ck::e8m0_bexp_t;
using ck::f6_convert_rne; using ck::f6_convert_rne;
using ck::f6_convert_sr; using ck::f6_convert_sr;
using ck::f6_t; using ck::f6_t;
using ck::f6x16_pk_t;
using ck::f6x32_pk_t;
using ck::Number; using ck::Number;
using ck::scaled_type_convert; using ck::scaled_type_convert;
using ck::type_convert; using ck::type_convert;
...@@ -215,3 +217,169 @@ TEST(FP6, ScaledConvertFP32Stochastic) ...@@ -215,3 +217,169 @@ TEST(FP6, ScaledConvertFP32Stochastic)
scaled_type_convert<float>(e8m0_bexp_t(min_scale), f6_convert_sr(neg_float)), scaled_type_convert<float>(e8m0_bexp_t(min_scale), f6_convert_sr(neg_float)),
abs_tol); abs_tol);
} }
TEST(FP6, TestSize)
{
ASSERT_EQ(1, sizeof(f6_t));
ASSERT_EQ(12, sizeof(f6x16_pk_t));
ASSERT_EQ(24, sizeof(f6x32_pk_t));
ASSERT_EQ(16, sizeof(vector_type<f6x16_pk_t, 1>));
ASSERT_EQ(32, sizeof(vector_type<f6x16_pk_t, 2>));
ASSERT_EQ(32, sizeof(vector_type<f6x32_pk_t, 1>));
}
TEST(FP6, TestAlignment)
{
ASSERT_EQ(1, alignof(f6_t));
ASSERT_EQ(4, alignof(f6x16_pk_t));
ASSERT_EQ(4, alignof(f6x32_pk_t));
ASSERT_EQ(16, alignof(vector_type<f6x16_pk_t, 1>));
ASSERT_EQ(32, alignof(vector_type<f6x16_pk_t, 2>));
ASSERT_EQ(32, alignof(vector_type<f6x32_pk_t, 1>));
}
// test vector of 1 f6x16_pk_t, contains 16 f6_t
TEST(FP6, TestAsType16x1)
{
// test size
const int vector_size = 1;
const int packed_size = 16;
typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
test_vec_t test_vec = {f6_t(0b000000),
f6_t(0b100000),
f6_t(0b000001),
f6_t(0b100001),
f6_t(0b000010),
f6_t(0b100010),
f6_t(0b000011),
f6_t(0b100011),
f6_t(0b000100),
f6_t(0b100100),
f6_t(0b000101),
f6_t(0b100101),
f6_t(0b000110),
f6_t(0b100110),
f6_t(0b001011),
f6_t(0b101011)};
// reference vector
vector_type<f6x16_pk_t, vector_size> right_vec;
// check default CTOR
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
right_vec.template AsType<f6x16_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}), 0);
});
// assign test values to the vector
ck::static_for<0, vector_size, 1>{}([&](auto i) {
right_vec.template AsType<f6x16_pk_t>()(Number<i>{}) = f6x16_pk_t{}.pack(test_vec);
});
// copy the vector
vector_type<f6x16_pk_t, vector_size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
left_vec.template AsType<f6x16_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
static_cast<f6_t>(test_vec[static_cast<int>(i)]));
});
}
// test vector of 2 f6x16_pk_t, contains 32 f6_t
TEST(FP6, TestAsType16x2)
{
// test size
const int vector_size = 2;
const int packed_size = 16;
typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
test_vec_t test_vec[2];
test_vec[0] = {f6_t(0b000000),
f6_t(0b100000),
f6_t(0b000001),
f6_t(0b100001),
f6_t(0b000010),
f6_t(0b100010),
f6_t(0b000011),
f6_t(0b100011),
f6_t(0b000100),
f6_t(0b100100),
f6_t(0b000101),
f6_t(0b100101),
f6_t(0b000110),
f6_t(0b100110),
f6_t(0b001011),
f6_t(0b101011)};
test_vec[1] = {f6_t(0b010000),
f6_t(0b110000),
f6_t(0b010001),
f6_t(0b110001),
f6_t(0b010010),
f6_t(0b110010),
f6_t(0b010011),
f6_t(0b110011),
f6_t(0b010100),
f6_t(0b110100),
f6_t(0b010101),
f6_t(0b110101),
f6_t(0b010110),
f6_t(0b110110),
f6_t(0b011011),
f6_t(0b111011)};
// reference vector
vector_type<f6x16_pk_t, vector_size> right_vec;
// check default CTOR
ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) {
ck::static_for<0, packed_size, 1>{}([&](auto idx_element) {
ASSERT_EQ(right_vec.template AsType<f6x16_pk_t>()(Number<idx_vector>{})
.template unpack<>(Number<idx_element>{}),
0);
});
});
// assign test values to the vector
ck::static_for<0, vector_size, 1>{}([&](auto i) {
right_vec.template AsType<f6x16_pk_t>()(Number<i>{}) = f6x16_pk_t{}.pack(test_vec[i]);
});
// copy the vector
vector_type<f6x16_pk_t, vector_size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) {
ck::static_for<0, packed_size, 1>{}([&](auto idx_element) {
ASSERT_EQ(left_vec.template AsType<f6x16_pk_t>()(Number<idx_vector>{})
.template unpack<>(Number<idx_element>{}),
static_cast<f6_t>(test_vec[idx_vector][static_cast<int>(idx_element)]));
});
});
}
// test vector of 1 f6x32_pk_t, contains 32 f6_t
TEST(FP6, TestAsType32x1)
{
// test size
const int vector_size = 1;
const int packed_size = 32;
typedef int8_t test_vec_t __attribute__((ext_vector_type(32)));
test_vec_t test_vec = {f6_t(0b000000), f6_t(0b100000), f6_t(0b000001), f6_t(0b100001),
f6_t(0b000010), f6_t(0b100010), f6_t(0b000011), f6_t(0b100011),
f6_t(0b000100), f6_t(0b100100), f6_t(0b000101), f6_t(0b100101),
f6_t(0b000110), f6_t(0b100110), f6_t(0b001011), f6_t(0b101011),
f6_t(0b010000), f6_t(0b110000), f6_t(0b010001), f6_t(0b110001),
f6_t(0b010010), f6_t(0b110010), f6_t(0b010011), f6_t(0b110011),
f6_t(0b010100), f6_t(0b110100), f6_t(0b010101), f6_t(0b110101),
f6_t(0b010110), f6_t(0b110110), f6_t(0b011011), f6_t(0b111011)};
// reference vector
vector_type<f6x32_pk_t, vector_size> right_vec;
// check default CTOR
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
right_vec.template AsType<f6x32_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}), 0);
});
// assign test values to the vector
ck::static_for<0, vector_size, 1>{}([&](auto i) {
right_vec.template AsType<f6x32_pk_t>()(Number<i>{}) = f6x32_pk_t{}.pack(test_vec);
});
// copy the vector
vector_type<f6x32_pk_t, vector_size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
left_vec.template AsType<f6x32_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
static_cast<f6_t>(test_vec[static_cast<int>(i)]));
});
}