Commit 60b885ae authored by Andriy Roshchenko's avatar Andriy Roshchenko
Browse files

Merge branch 'develop' into andriy/lwpck-2788

parents 1a90f021 fd7600ce
......@@ -156,9 +156,9 @@ message("checking which targets are supported")
if(NOT ENABLE_ASAN_PACKAGING)
if(NOT WIN32 AND ${hip_VERSION_FLAT} LESS 600300000)
# WORKAROUND: compiler does not yet fully support gfx12 targets, need to fix version above
set(CK_GPU_TARGETS "gfx950")
set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102")
else()
set(CK_GPU_TARGETS "gfx950")
set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201")
endif()
else()
#build CK only for xnack-supported targets when using ASAN
......@@ -210,6 +210,10 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx
add_definitions(-DCK_USE_FNUZ_FP8)
set(CK_USE_FNUZ_FP8 "ON")
endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx950")
add_definitions(-DCK_USE_NATIVE_MX_SUPPORT)
set(CK_USE_NATIVE_MX_SUPPORT "ON")
endif()
option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF)
if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908"))
......
......@@ -5,6 +5,6 @@ if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp)
endif(USE_BITINT_EXTENSION_INT4)
if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1")
if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx95" AND NOT GPU_TARGETS MATCHES "gfx1")
add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp)
endif()
......@@ -5,6 +5,6 @@ if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp)
endif(USE_BITINT_EXTENSION_INT4)
if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1")
if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx95" AND NOT GPU_TARGETS MATCHES "gfx1")
add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp)
endif()
......@@ -131,6 +131,10 @@
#cmakedefine CK_USE_FP8_ON_UNSUPPORTED_ARCH @CK_USE_FP8_ON_UNSUPPORTED_ARCH@
#endif
#ifndef CK_USE_NATIVE_MX_SUPPORT
#cmakedefine CK_USE_NATIVE_MX_SUPPORT @CK_USE_NATIVE_MX_SUPPORT@
#endif
// clang-format on
#endif // CK_CONFIG_H_IN
......@@ -608,14 +608,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
#if defined(__gfx950__)
// TODO: fix logic for gfx950 as it's temporary hack for passing compiling
constexpr index_t Gemm1KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
#else
constexpr index_t Gemm1KPack =
MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
#endif
auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
BlockSize,
......
......@@ -856,11 +856,18 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
static_cast<A0B0B1DataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr index_t Gemm1KPack = math::max(
math::lcm(
MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.group_size,
B1K1),
MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.k_per_blk);
// selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
// selected_mfma.k_per_blk <= Gemm1KPack
//
// Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
// multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
// Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
constexpr index_t Gemm1KPack =
MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.group_size;
auto blockwise_gemm1 = BlockwiseGemmXdlops_v2<
BlockSize,
......
......@@ -773,14 +773,10 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
#if defined(__gfx950__)
// TODO: fix logic for gfx950 as it's temporary hack for passing compiling
constexpr index_t Gemm1KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
#else
constexpr index_t Gemm1KPack =
MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
#endif
auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
BlockSize,
FloatAB,
......
......@@ -628,14 +628,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
#if defined(__gfx950__)
// TODO: fix logic for gfx950 as it's temporary hack for passing compiling
constexpr index_t Gemm1KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
#else
constexpr index_t Gemm1KPack =
MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
#endif
auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
BlockSize,
FloatAB,
......
......@@ -890,13 +890,15 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
template <typename base_type,
index_t MPerXdlops,
index_t NPerXdlops,
typename additional_type = base_type>
typename additional_type = base_type,
bool is_single_rate_mfma = false>
struct MfmaSelector
{
template <typename base_type_,
index_t MPerXdlops_,
index_t NPerXdlops_,
typename additional_type_ = base_type_>
typename additional_type_ = base_type_,
bool is_single_rate_mfma_ = false>
static constexpr auto GetMfma();
template <>
......@@ -960,7 +962,7 @@ struct MfmaSelector
}
template <>
constexpr auto GetMfma<half_t, 32, 32>()
constexpr auto GetMfma<half_t, 32, 32, half_t, false>()
{
#if defined(__gfx950__)
return MfmaInstr::mfma_f32_32x32x16f16;
......@@ -968,9 +970,14 @@ struct MfmaSelector
return MfmaInstr::mfma_f32_32x32x8f16;
#endif
}
template <>
constexpr auto GetMfma<half_t, 32, 32, half_t, true>()
{
return MfmaInstr::mfma_f32_32x32x8f16;
}
template <>
constexpr auto GetMfma<half_t, 16, 16>()
constexpr auto GetMfma<half_t, 16, 16, half_t, false>()
{
#if defined(__gfx950__)
return MfmaInstr::mfma_f32_16x16x32f16;
......@@ -979,6 +986,12 @@ struct MfmaSelector
#endif
}
template <>
constexpr auto GetMfma<half_t, 16, 16, half_t, true>()
{
return MfmaInstr::mfma_f32_16x16x16f16;
}
template <>
constexpr auto GetMfma<half_t, 16, 64>()
{
......@@ -998,7 +1011,7 @@ struct MfmaSelector
}
template <>
constexpr auto GetMfma<bhalf_t, 32, 32>()
constexpr auto GetMfma<bhalf_t, 32, 32, bhalf_t, false>()
{
#if defined(__gfx950__)
return MfmaInstr::mfma_f32_32x32x16bf16;
......@@ -1010,7 +1023,17 @@ struct MfmaSelector
}
template <>
constexpr auto GetMfma<bhalf_t, 16, 16>()
constexpr auto GetMfma<bhalf_t, 32, 32, bhalf_t, true>()
{
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return MfmaInstr::mfma_f32_32x32x8bf16_1k;
#else
return MfmaInstr::mfma_f32_32x32x4bf16;
#endif
}
template <>
constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, false>()
{
#if defined(__gfx950__)
return MfmaInstr::mfma_f32_16x16x32bf16;
......@@ -1021,6 +1044,16 @@ struct MfmaSelector
#endif
}
template <>
constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, true>()
{
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return MfmaInstr::mfma_f32_16x16x16bf16_1k;
#else
return MfmaInstr::mfma_f32_16x16x8bf16;
#endif
}
#if defined(__gfx950__)
template <>
constexpr auto GetMfma<int8_t, 32, 32>()
......@@ -1104,8 +1137,8 @@ struct MfmaSelector
return MfmaInstr::mfma_f32_16x16x32bf8f8;
}
static constexpr auto selected_mfma =
mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type>()>{};
static constexpr auto selected_mfma = mfma_type<
GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type, is_single_rate_mfma>()>{};
__host__ __device__ constexpr MfmaSelector()
{
......@@ -1407,7 +1440,13 @@ struct XdlopsGemm
return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
}
static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops, additional_type>{};
// Falls back to single rate instruction on gfx950 if KPack <= 4; no change on gfx942-
static constexpr auto
mfma = MfmaSelector < base_type,
MPerXdlops, NPerXdlops, additional_type,
((is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value) && KPack <= 4)
? true
: false > {};
static constexpr auto mfma_instr = mfma.selected_mfma;
......
......@@ -24,8 +24,9 @@ struct f4x2_pk_t
f4x2_pk_t(type init) : data{init} {}
template <index_t I>
__host__ __device__ inline type unpack() const
__host__ __device__ inline type unpack(Number<I>) const
{
static_assert(I < 2, "Index is out of range.");
if constexpr(I == 0)
return data & 0b00001111;
else
......@@ -38,6 +39,270 @@ struct f4x2_pk_t
}
};
struct f6x16_pk_t
{
// store 16 elements of f6_t in an array of 3 uint32_t
using element_type = uint32_t;
using type = StaticallyIndexedArray_v2<element_type, 3>;
type data;
typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
f6x16_pk_t() : data{type{}} {}
f6x16_pk_t(type init) : data{init} {}
template <index_t I>
__host__ __device__ inline f6_t unpack(Number<I>)
{
static_assert(I < 16, "Index out of range for 16 f6_t elements.");
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 3;
constexpr int bit_pos = I * num_bits_elem;
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
{
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
<< (num_bits_elem - overhang);
}
return static_cast<f6_t>(bits & 0x3F);
}
__host__ __device__ inline type pack(const test_vec_t& x)
{
type packed{};
// for each of the 16 f6_t values, place its 6 bits in the correct position
ck::static_for<0, 16, 1>{}([&](auto i) {
uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 3;
constexpr int bit_pos = i * num_bits_elem;
constexpr int arr_index = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
uint32_t old_value = packed.At(Number<arr_index>{});
// insert bits into the current 32-bit block
old_value |= (bits << bit_offset);
packed.At(Number<arr_index>{}) = old_value;
// if it crosses into the next block, shift the remainder
if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
{
uint32_t next_value = packed.At(Number<arr_index + 1>{});
next_value |= (bits >> (num_bits_elem - overhang));
packed.At(Number<arr_index + 1>{}) = next_value;
}
});
return packed;
}
};
struct f6x32_pk_t
{
// store 32 elements of f6_t in an array of 6 uint32_t
using element_type = uint32_t;
using type = StaticallyIndexedArray_v2<element_type, 6>;
type data;
typedef int8_t test_vec_t __attribute__((ext_vector_type(32)));
f6x32_pk_t() : data{type{}} {}
f6x32_pk_t(type init) : data{init} {}
template <index_t I>
__host__ __device__ inline f6_t unpack(Number<I>)
{
static_assert(I < 32, "Index out of range for 32 f6_t elements.");
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 6;
constexpr int bit_pos = I * num_bits_elem;
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
{
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
<< (num_bits_elem - overhang);
}
return static_cast<f6_t>(bits & 0x3F);
}
__host__ __device__ inline type pack(const test_vec_t& x)
{
type packed{};
// for each of the 32 f6_t values, place its 6 bits in the correct position
ck::static_for<0, 32, 1>{}([&](auto i) {
uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 6;
constexpr int bit_pos = i * num_bits_elem;
constexpr int arr_index = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
uint32_t old_value = packed.At(Number<arr_index>{});
// insert bits into the current 32-bit block
old_value |= (bits << bit_offset);
packed.At(Number<arr_index>{}) = old_value;
// if it crosses into the next block, shift the remainder
if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
{
uint32_t next_value = packed.At(Number<arr_index + 1>{});
next_value |= (bits >> (num_bits_elem - overhang));
packed.At(Number<arr_index + 1>{}) = next_value;
}
});
return packed;
}
};
struct bf6x16_pk_t
{
// store 16 elements of bf6_t in an array of 3 uint32_t
using element_type = uint32_t;
using type = StaticallyIndexedArray_v2<element_type, 3>;
type data;
typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
bf6x16_pk_t() : data{type{}} {}
bf6x16_pk_t(type init) : data{init} {}
template <index_t I>
__host__ __device__ inline bf6_t unpack(Number<I>)
{
static_assert(I < 16, "Index out of range for 16 f6_t elements.");
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 3;
constexpr int bit_pos = I * num_bits_elem;
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
{
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
<< (num_bits_elem - overhang);
}
return static_cast<bf6_t>(bits & 0x3F);
}
__host__ __device__ inline type pack(const test_vec_t& x)
{
type packed{};
// for each of the 16 bf6_t values, place its 6 bits in the correct position
ck::static_for<0, 16, 1>{}([&](auto i) {
uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 3;
constexpr int bit_pos = i * num_bits_elem;
constexpr int arr_index = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
uint32_t old_value = packed.At(Number<arr_index>{});
// insert bits into the current 32-bit block
old_value |= (bits << bit_offset);
packed.At(Number<arr_index>{}) = old_value;
// if it crosses into the next block, shift the remainder
if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
{
uint32_t next_value = packed.At(Number<arr_index + 1>{});
next_value |= (bits >> (num_bits_elem - overhang));
packed.At(Number<arr_index + 1>{}) = next_value;
}
});
return packed;
}
};
struct bf6x32_pk_t
{
// store 32 elements of bf6_t in an array of 6 uint32_t
using element_type = uint32_t;
using type = StaticallyIndexedArray_v2<element_type, 6>;
type data;
typedef int8_t test_vec_t __attribute__((ext_vector_type(32)));
bf6x32_pk_t() : data{type{}} {}
bf6x32_pk_t(type init) : data{init} {}
template <index_t I>
__host__ __device__ inline bf6_t unpack(Number<I>)
{
static_assert(I < 32, "Index out of range for 32 f6_t elements.");
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 6;
constexpr int bit_pos = I * num_bits_elem;
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
{
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
<< (num_bits_elem - overhang);
}
return static_cast<bf6_t>(bits & 0x3F);
}
__host__ __device__ inline type pack(const test_vec_t& x)
{
type packed{};
// for each of the 32 bf6_t values, place its 6 bits in the correct position
ck::static_for<0, 32, 1>{}([&](auto i) {
uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 6;
constexpr int bit_pos = i * num_bits_elem;
constexpr int arr_index = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
uint32_t old_value = packed.At(Number<arr_index>{});
// insert bits into the current 32-bit block
old_value |= (bits << bit_offset);
packed.At(Number<arr_index>{}) = old_value;
// if it crosses into the next block, shift the remainder
if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
{
uint32_t next_value = packed.At(Number<arr_index + 1>{});
next_value |= (bits >> (num_bits_elem - overhang));
packed.At(Number<arr_index + 1>{}) = next_value;
}
});
return packed;
}
};
// custom data type - pack int4 data
struct pk_i4_t
{
......@@ -56,7 +321,7 @@ inline constexpr auto next_pow2(uint32_t x)
}
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t,
// native types: bool
// native types: bool, f4_t, f6_t, bf6_t
template <typename T>
inline constexpr bool is_native_type()
{
......@@ -1387,12 +1652,37 @@ struct nnvb_data_t_selector<f8_ocp_t>
{
using type = f8_ocp_t::data_type;
};
template <>
struct nnvb_data_t_selector<bf8_ocp_t>
{
using type = bf8_ocp_t::data_type;
};
template <>
struct nnvb_data_t_selector<f6x16_pk_t>
{
using type = f6x16_pk_t::type;
};
template <>
struct nnvb_data_t_selector<f6x32_pk_t>
{
using type = f6x32_pk_t::type;
};
template <>
struct nnvb_data_t_selector<bf6x16_pk_t>
{
using type = bf6x16_pk_t::type;
};
template <>
struct nnvb_data_t_selector<bf6x32_pk_t>
{
using type = bf6x32_pk_t::type;
};
template <>
struct nnvb_data_t_selector<pk_i4_t>
{
......@@ -1499,6 +1789,63 @@ struct non_native_vector_base<
}
};
// implementation for f6x16 and f6x32
template <typename T, index_t N>
struct non_native_vector_base<T, N, std::enable_if_t<sizeof(T) == 12 || sizeof(T) == 24>>
{
using data_t =
typename nnvb_data_t_selector<T>::type; // select data_t based on declared base type
using element_t = typename T::element_type; // select element_t based on declared element type
static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch");
static constexpr size_t size_factor =
sizeof(data_t) / sizeof(element_t); // f6x16: 12/4 = 3, f6x32: 24/4 = 6
using data_v = element_t __attribute__((ext_vector_type(N * size_factor)));
using type = non_native_vector_base<T, N>;
union alignas(next_pow2(N * sizeof(T)))
{
data_v dN; // storage vector;
StaticallyIndexedArray<data_t, N> dxN;
StaticallyIndexedArray<T, N> dTxN;
StaticallyIndexedArray<data_v, 1> dNx1;
} data_;
__host__ __device__ constexpr non_native_vector_base(data_t a)
: data_{data_v(a.At(Number<0>{}))}
{
}
__host__ __device__ constexpr non_native_vector_base(T f)
: non_native_vector_base(bit_cast<data_t>(f))
{
}
__host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){};
__host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {}
__host__ __device__ constexpr operator data_v() const { return data_.dN; }
__host__ __device__ constexpr operator data_t() const
{
if constexpr(N == 1)
{
return data_.dxN[Number<0>{}];
}
else
{
return data_.dxN; // XXX this should cause an error
}
}
__host__ __device__ constexpr operator T() const
{
if constexpr(N == 1)
{
return data_.dTxN[Number<0>{}];
}
else
{
return data_.dTxN; // XXX this should cause an error
}
}
};
template <typename T, index_t N>
struct scalar_type<non_native_vector_base<T, N>>;
......@@ -2242,6 +2589,14 @@ using f4x16_t = typename vector_type<f4x2_pk_t, 8>::type;
using f4x32_t = typename vector_type<f4x2_pk_t, 16>::type;
using f4x64_t = typename vector_type<f4x2_pk_t, 32>::type;
// f6
using f6x16_t = typename vector_type<f6x16_pk_t, 1>::type;
using f6x32_t = typename vector_type<f6x32_pk_t, 1>::type;
// bf6
using bf6x16_t = typename vector_type<bf6x16_pk_t, 1>::type;
using bf6x32_t = typename vector_type<bf6x32_pk_t, 1>::type;
// pack int4
using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type;
using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
......@@ -6,11 +6,21 @@
#include "ck/utility/type_convert.hpp"
#include "ck/utility/mxf8_utils.hpp"
#ifdef CK_USE_NATIVE_MX_SUPPORT
#define CK_USE_NATIVE_MX_SUPPORT 1
#else
#define CK_USE_NATIVE_MX_SUPPORT 0
#endif
namespace ck {
// Declare a template function for scaled conversion
template <typename Y, typename X>
#if CK_USE_OCP_FP8
__host__ __device__ constexpr Y scaled_type_convert(e8m0_bexp_t scale, X x);
#else
__host__ constexpr Y scaled_type_convert(e8m0_bexp_t scale, X x);
#endif
// convert f8_ocp_t to fp32
template <>
......@@ -200,27 +210,13 @@ inline __host__ float32_t scaled_type_convert<float32_t, bf8x32_ocp_t>(e8m0_bexp
return out.float_1x32;
}
// convert fp4 to fp32
template <>
inline __host__ __device__ float scaled_type_convert<float, f4_t>(e8m0_bexp_t scale, f4_t x)
{
#if defined(__gfx950__)
union
{
float float_array[2];
float2_t float2_array;
} float_values{};
float_values.float2_array =
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, type_convert<float>(scale), 0);
return float_values.float_array[0];
#else
return utils::to_float<f4_t>(scale, x);
#endif
}
// convert fp32 to fp8
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ f8_ocp_t scaled_type_convert<f8_ocp_t, float>(e8m0_bexp_t scale, float x)
#else
inline __host__ f8_ocp_t scaled_type_convert<f8_ocp_t, float>(e8m0_bexp_t scale, float x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<f8_ocp_t>(x, type_convert<float>(scale));
......@@ -231,8 +227,12 @@ inline __host__ __device__ f8_ocp_t scaled_type_convert<f8_ocp_t, float>(e8m0_be
// convert fp32 to bf8
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8_ocp_t scaled_type_convert<bf8_ocp_t, float>(e8m0_bexp_t scale,
float x)
#else
inline __host__ bf8_ocp_t scaled_type_convert<bf8_ocp_t, float>(e8m0_bexp_t scale, float x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<bf8_ocp_t>(x, type_convert<float>(scale));
......@@ -243,8 +243,12 @@ inline __host__ __device__ bf8_ocp_t scaled_type_convert<bf8_ocp_t, float>(e8m0_
// convert fp32x2 to fp8x2
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ f8x2_ocp_t scaled_type_convert<f8x2_ocp_t, float2_t>(e8m0_bexp_t scale,
float2_t x)
#else
inline __host__ f8x2_ocp_t scaled_type_convert<f8x2_ocp_t, float2_t>(e8m0_bexp_t scale, float2_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<f8x2_ocp_t>(x, type_convert<float>(scale));
......@@ -254,8 +258,13 @@ inline __host__ __device__ f8x2_ocp_t scaled_type_convert<f8x2_ocp_t, float2_t>(
}
// convert fp32x2 to bf8x2
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8x2_ocp_t scaled_type_convert<bf8x2_ocp_t, float2_t>(e8m0_bexp_t scale,
float2_t x)
#else
inline __host__ bf8x2_ocp_t scaled_type_convert<bf8x2_ocp_t, float2_t>(e8m0_bexp_t scale,
float2_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<bf8x2_ocp_t>(x, type_convert<float>(scale));
......@@ -267,8 +276,13 @@ inline __host__ __device__ bf8x2_ocp_t scaled_type_convert<bf8x2_ocp_t, float2_t
// convert fp32x16 to fp8x16
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ f8x16_ocp_t
scaled_type_convert<f8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
#else
inline __host__ f8x16_ocp_t scaled_type_convert<f8x16_ocp_t, float16_t>(e8m0_bexp_t scale,
float16_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<f8x16_ocp_t>(x, type_convert<float>(scale));
......@@ -280,8 +294,13 @@ scaled_type_convert<f8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
// convert fp32x16 to bf8x16
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8x16_ocp_t
scaled_type_convert<bf8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
#else
inline __host__ bf8x16_ocp_t scaled_type_convert<bf8x16_ocp_t, float16_t>(e8m0_bexp_t scale,
float16_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<bf8x16_ocp_t>(x, type_convert<float>(scale));
......@@ -293,8 +312,13 @@ scaled_type_convert<bf8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
// convert fp32x32 to fp8x32
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ f8x32_ocp_t
scaled_type_convert<f8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
#else
inline __host__ f8x32_ocp_t scaled_type_convert<f8x32_ocp_t, float32_t>(e8m0_bexp_t scale,
float32_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<f8x32_ocp_t>(x, type_convert<float>(scale));
......@@ -306,8 +330,13 @@ scaled_type_convert<f8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
// convert fp32x32 to bf8x32
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8x32_ocp_t
scaled_type_convert<bf8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
#else
inline __host__ bf8x32_ocp_t scaled_type_convert<bf8x32_ocp_t, float32_t>(e8m0_bexp_t scale,
float32_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
return mxf8_convert_sr<bf8x32_ocp_t>(x, type_convert<float>(scale));
......@@ -316,6 +345,26 @@ scaled_type_convert<bf8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
#endif
}
// activate for architectures with native MX support
#if CK_USE_NATIVE_MX_SUPPORT
// convert fp4 to fp32
template <>
inline __host__ __device__ float scaled_type_convert<float, f4_t>(e8m0_bexp_t scale, f4_t x)
{
#if defined(__gfx950__)
union
{
float float_array[2];
float2_t float2_array;
} float_values{};
float_values.float2_array =
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, type_convert<float>(scale), 0);
return float_values.float_array[0];
#else
return utils::to_float<f4_t>(scale, x);
#endif
}
// convert vector of 2 fp4 to vector of 2 fp32
template <>
inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_bexp_t scale,
......@@ -330,9 +379,10 @@ inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_b
value.f4x2_array[0] = x;
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, type_convert<float>(scale), 0);
#else
float2_t ret{
utils::to_float<f4_t>(scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>()),
utils::to_float<f4_t>(scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>())};
float2_t ret{utils::to_float<f4_t>(
scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})),
utils::to_float<f4_t>(
scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}))};
return ret;
#endif
}
......@@ -467,72 +517,104 @@ inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m
} f4_values{bit_cast<__uint128_t>(x)};
// TODO: pack in a loop
float_values.float_array[0] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
scale,
f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[1] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
scale,
f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[2] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
scale,
f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[3] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
scale,
f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[4] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
scale,
f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[5] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
scale,
f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[6] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
scale,
f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[7] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
scale,
f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[0] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
scale,
f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[1] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
scale,
f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[2] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
scale,
f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[3] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
scale,
f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[4] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
scale,
f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[5] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
scale,
f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[6] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
scale,
f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[7] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
scale,
f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[0] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
scale,
f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[1] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
scale,
f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[2] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
scale,
f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[3] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
scale,
f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[4] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
scale,
f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[5] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
scale,
f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[6] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
scale,
f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[7] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
scale,
f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[0] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
scale,
f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[1] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
scale,
f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[2] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
scale,
f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[3] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
scale,
f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[4] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
scale,
f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[5] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
scale,
f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[6] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
scale,
f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[7] = utils::to_float<f4_t>(
scale, f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
scale,
f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
return float_values.float32_array;
#endif
......@@ -584,8 +666,59 @@ inline __host__ __device__ f4x32_t scaled_type_convert<f4x32_t, float32_t>(e8m0_
template <>
inline __host__ __device__ float scaled_type_convert<float, f6_t>(e8m0_bexp_t scale, f6_t x)
{
// currently there is no native conversion instruction
#if defined(__gfx950__)
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} in{x};
union
{
float32_t float_vector;
float float_array[32];
} out{};
out.float_vector =
__builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(in.f6_vector, type_convert<float>(scale));
return out.float_array[0];
#else
return utils::to_float<f6_t>(scale, x);
#endif
}
/**
* @brief Converts a vector of 32 6-bit floating-point values (f6x32_t) to a vector of 32 floats,
* applying the specified scaling factor.
*
* @param scale The exponent scale factor (e8m0_bexp_t).
* @param x The f6x32_t vector to be converted.
* @return The converted float vector representation of the input.
*/
template <>
inline __host__ __device__ float32_t scaled_type_convert<float32_t, f6x32_t>(e8m0_bexp_t scale,
f6x32_t x)
{
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(x, type_convert<float>(scale));
#else
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} in{x};
union
{
float32_t float_vector;
float float_array[32];
} out{};
ck::static_for<0, 32, 1>{}(
[&](auto i) { out.float_array[i] = utils::to_float<f6_t>(scale, in.f6_array[i]); });
return out.float_vector;
#endif
}
/**
......@@ -599,8 +732,59 @@ inline __host__ __device__ float scaled_type_convert<float, f6_t>(e8m0_bexp_t sc
template <>
inline __host__ __device__ float scaled_type_convert<float, bf6_t>(e8m0_bexp_t scale, bf6_t x)
{
// currently there is no native conversion instruction
#if defined(__gfx950__)
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} in{x};
union
{
float32_t float_vector;
float float_array[32];
} out{};
out.float_vector =
__builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(in.bf6_vector, type_convert<float>(scale));
return out.float_array[0];
#else
return utils::to_float<bf6_t>(scale, x);
#endif
}
/**
* @brief Converts a vector of 6-bit floating-point values (bf6x32_t) to a vector of 32 floats,
* applying the specified scaling factor.
*
* @param scale The exponent scale factor (e8m0_bexp_t).
* @param x The bf6x32_t vector to be converted.
* @return The converted vector of 32 float representation of the input.
*/
template <>
inline __host__ __device__ float32_t scaled_type_convert<float32_t, bf6x32_t>(e8m0_bexp_t scale,
bf6x32_t x)
{
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(x, type_convert<float>(scale));
#else
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} in{x};
union
{
float32_t float_vector;
float float_array[32];
} out{};
ck::static_for<0, 32, 1>{}(
[&](auto i) { out.float_array[i] = utils::to_float<bf6_t>(scale, in.bf6_array[i]); });
return out.float_vector;
#endif
}
/**
......@@ -624,6 +808,28 @@ inline __host__ __device__ f6_t scaled_type_convert<f6_t, float>(e8m0_bexp_t sca
#endif
}
/**
* @brief Converts a vector of 32 floats to a vector of 32 6-bit floating-point values (f6x32_t),
* applying the specified scale.
*
* Depending on whether CK_USE_SR_F6_CONVERSION is defined, it uses either stochastic rounding
* (f6_convert_sr) or round-to-nearest-even (f6_convert_rne).
*
* @param scale The exponent scale factor (e8m0_bexp_t).
* @param x The float vector to convert.
* @return The converted vector of 6-bit floating-point values (f6x32_t).
*/
template <>
inline __host__ __device__ f6x32_t scaled_type_convert<f6x32_t, float32_t>(e8m0_bexp_t scale,
float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
return f6_convert_sr(x, type_convert<float>(scale));
#else
return f6_convert_rne(x, type_convert<float>(scale));
#endif
}
/**
* @brief Converts a 32-bit float to a 6-bit floating-point value (bf6_t), applying the specified
* scale.
......@@ -645,4 +851,27 @@ inline __host__ __device__ bf6_t scaled_type_convert<bf6_t, float>(e8m0_bexp_t s
#endif
}
/**
* @brief Converts a vector of 32 floats to a vector of 32 6-bit floating-point values (bf6x32_t),
* applying the specified scale.
*
* Depending on whether CK_USE_SR_F6_CONVERSION is defined, it uses either stochastic rounding
* (bf6_convert_sr) or round-to-nearest-even (bf6_convert_rne).
*
* @param scale The exponent scale factor (e8m0_bexp_t).
* @param x The float vector to convert.
* @return The converted 6-bit floating-point vector (bf6x32_t).
*/
template <>
inline __host__ __device__ bf6x32_t scaled_type_convert<bf6x32_t, float32_t>(e8m0_bexp_t scale,
float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
return bf6_convert_sr(x, type_convert<float>(scale));
#else
return bf6_convert_rne(x, type_convert<float>(scale));
#endif
}
#endif // #if CK_USE_NATIVE_MX_SUPPORT
} // namespace ck
......@@ -1146,10 +1146,11 @@ inline __host__ __device__ float2_t type_convert<float2_t, f4x2_t>(f4x2_t x)
float scale = 1.0f;
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, scale, 0);
#else
float2_t ret{utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>()),
utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>())};
float2_t ret{
utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})),
utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}))};
return ret;
#endif
}
......@@ -1285,103 +1286,103 @@ inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
// TODO: pack in a loop
float_values.float_array[0] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[1] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[2] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[3] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[4] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[5] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[6] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[7] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[0] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[1] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[2] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[3] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[4] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[5] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[6] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[7] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[0] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[1] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[2] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[3] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[4] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[5] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[6] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[7] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[0] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[1] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[2] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[3] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[4] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[5] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[6] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<0>());
f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[7] = utils::to_float<f4_t>(
NumericLimits<e8m0_bexp_t>::Binary_1(),
f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<1>());
f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
return float_values.float32_array;
#endif
......@@ -1399,8 +1400,59 @@ inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
*/
inline __host__ __device__ f6_t f6_convert_rne(float x, float scale = 1.0f)
{
// currently there is no native conversion instruction
#if defined(__gfx950__)
float16_t in1{x};
float16_t in2{};
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} out{};
out.f6_vector = __builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(in1, in2, scale);
return out.f6_array[0];
#else
return utils::sat_convert_to_type<f6_t>(x / scale);
#endif
}
/**
* @brief Converts a 32-element single-precision float array into a packed 6-bit representation.
*
* This function divides each input float by the provided scale value, then performs conversion with
* rounding to nearest / even to pack each element into 6 bits of precision.
*
* @param x A vector of 32 floats stored in float32_t.
* @param scale A scaling factor for each float before conversion.
* @return An f6x32_t object storing the compressed 6-bit representation.
*/
inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
float16_t* in1 = reinterpret_cast<float16_t*>(&x);
float16_t* in2 = reinterpret_cast<float16_t*>(&x + 16);
return __builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(*in1, *in2, scale);
#else
union
{
float32_t float_vector;
float float_array[32];
} in{x};
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} out{};
ck::static_for<0, 32, 1>{}([&](auto i) {
out.f6_array[i] = utils::sat_convert_to_type<f6_t>(in.float_array[i] / scale);
});
return out.f6_vector;
#endif
}
/**
......@@ -1417,15 +1469,75 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
{
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
// currently there is no native conversion instruction
#if defined(__gfx950__)
union
{
float32_t float_vector;
float float_array[32];
} in{x};
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} out{};
out.f6_vector = __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(in.float_vector, rng, scale);
return out.f6_array[0];
#else
return utils::sat_convert_to_type_sr<f6_t>(x / scale, rng);
#endif
}
/**
* @brief Converts a 32-element single-precision float array into a packed 6-bit representation.
*
* This function divides each input float by the provided scale value, then performs conversion with
* stochastic rounding to pack each element into 6 bits of precision.
*
* @param x A vector of 32 floats stored in float32_t.
* @param scale A scaling factor for each float before conversion.
* @return An f6x32_t object storing the compressed 6-bit representation.
*/
inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f)
{
constexpr int seed = 1254739;
union
{
float32_t float_vector;
float float_array[32];
} float_values{x};
uint32_t rng =
prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), float_values.float_array[0]);
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale);
#else
union
{
float32_t float_vector;
float float_array[32];
} in{x};
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} out{};
ck::static_for<0, 32, 1>{}([&](auto i) {
out.f6_array[i] = utils::sat_convert_to_type_sr<f6_t>(in.float_array[i] / scale, rng);
});
return out.f6_vector;
#endif
}
/**
* @brief Specializes the type conversion template for converting a float into the 6-bit float type
* (f6_t).
*
* Depending on the CK_USE_SR_F4_CONVERSION flag,
* Depending on the CK_USE_SR_F6_CONVERSION flag,
* the conversion uses stochastic rounding
* or round-to-nearest-even.
*
......@@ -1435,7 +1547,28 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
template <>
inline __host__ __device__ f6_t type_convert<f6_t, float>(float x)
{
#if CK_USE_SR_F4_CONVERSION
#if CK_USE_SR_F6_CONVERSION
return f6_convert_sr(x);
#else
return f6_convert_rne(x);
#endif
}
/**
* @brief Specializes the type conversion template for converting a vector of 32 floats into the
* vector of 32 6-bit float types (f6x32_t).
*
* Depending on the CK_USE_SR_F6_CONVERSION flag,
* the conversion uses stochastic rounding
* or round-to-nearest-even.
*
* @param x Input float value to be converted.
* @return The converted f6x32_t vector.
*/
template <>
inline __host__ __device__ f6x32_t type_convert<f6x32_t, float32_t>(float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
return f6_convert_sr(x);
#else
return f6_convert_rne(x);
......@@ -1454,8 +1587,62 @@ inline __host__ __device__ f6_t type_convert<f6_t, float>(float x)
template <>
inline __host__ __device__ float type_convert<float, f6_t>(f6_t x)
{
// currently there is no native conversion instruction
#if defined(__gfx950__)
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} in{x};
union
{
float32_t float_vector;
float float_array[32];
} out{};
out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
in.f6_vector, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
return out.float_array[0];
#else
return utils::to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif
}
/**
* @brief Specializes the type conversion template for converting the vector of 32 6-bit float types
* (f6x32_t) to vector of 32 floats.
*
* Interprets an f6_t values as floats using the default scale factor of 1.
*
* @param x The vector of 32 6-bit float (f6x32_t) values to be converted.
* @return The corresponding float representation.
*/
template <>
inline __host__ __device__ float32_t type_convert<float32_t, f6x32_t>(f6x32_t x)
{
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
x, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
#else
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} in{x};
union
{
float32_t float_vector;
float float_array[32];
} out{};
ck::static_for<0, 32, 1>{}([&](auto i) {
out.float_array[i] =
utils::to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), in.f6_array[i]);
});
return out.float_vector;
#endif
}
/**
......@@ -1470,8 +1657,60 @@ inline __host__ __device__ float type_convert<float, f6_t>(f6_t x)
*/
inline __host__ __device__ bf6_t bf6_convert_rne(float x, float scale = 1.0f)
{
// currently there is no native conversion instruction
#if defined(__gfx950__)
float16_t in1{x};
float16_t in2{};
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} out{};
out.bf6_vector = __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(in1, in2, scale);
return out.bf6_array[0];
#else
return utils::sat_convert_to_type<bf6_t>(x / scale);
#endif
}
/**
* @brief Converts a vector of 32 floats to the vector of 32 6-bit BF6 types using
* round-to-nearest-even.
*
* Divides the input by the specified scale, then saturates and converts
* it to a 6-bit BF6 floating-point format.
*
* @param x The float vector to be converted.
* @param scale The scaling factor applied to the input before conversion.
* @return The converted bf6x32_t vector.
*/
inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
float16_t* in1 = reinterpret_cast<float16_t*>(&x);
float16_t* in2 = reinterpret_cast<float16_t*>(&x + 16);
return __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(*in1, *in2, scale);
#else
union
{
float32_t float_vector;
float float_array[32];
} in{x};
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} out{};
ck::static_for<0, 32, 1>{}([&](auto i) {
out.bf6_array[i] = utils::sat_convert_to_type<bf6_t>(in.float_array[i] / scale);
});
return out.bf6_vector;
#endif
}
/**
......@@ -1489,14 +1728,76 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
{
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
// currently there is no native conversion instruction
#if defined(__gfx950__)
union
{
float32_t float_vector;
float float_array[32];
} in{x};
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} out{};
out.bf6_vector = __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(in.float_vector, rng, scale);
return out.bf6_array[0];
#else
return utils::sat_convert_to_type_sr<bf6_t>(x / scale, rng);
#endif
}
/**
* @brief Converts a vector of 32 floats to the vector of 32 6-bit BF6 types using stochastic
* rounding.
*
* Divides the input by the specified scale,
* and converts the result to a 6-bit BF6 floating-point
* format with stochastic rounding.
*
* @param x The float vector to be converted.
* @param scale The scaling factor applied to the input before conversion.
* @return The converted bf6x32_t vector.
*/
inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1.0f)
{
constexpr int seed = 1254739;
union
{
float32_t float_vector;
float float_array[32];
} float_values{x};
uint32_t rng =
prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), float_values.float_array[0]);
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale);
#else
union
{
float32_t float_vector;
float float_array[32];
} in{x};
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} out{};
ck::static_for<0, 32, 1>{}([&](auto i) {
out.bf6_array[i] = utils::sat_convert_to_type_sr<bf6_t>(in.float_array[i] / scale, rng);
});
return out.bf6_vector;
#endif
}
/**
* @brief Specializes float-to-bf6_t conversion.
*
* Uses stochastic rounding if CK_USE_SR_F4_CONVERSION is defined,
* Uses stochastic rounding if CK_USE_SR_F6_CONVERSION is defined,
* otherwise uses round-to-nearest-even.
*
* @param x Input float value to convert.
......@@ -1505,7 +1806,26 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
template <>
inline __host__ __device__ bf6_t type_convert<bf6_t, float>(float x)
{
#if CK_USE_SR_F4_CONVERSION
#if CK_USE_SR_F6_CONVERSION
return bf6_convert_sr(x);
#else
return bf6_convert_rne(x);
#endif
}
/**
* @brief Specializes vector of 32 float-to-bf6_t conversion.
*
* Uses stochastic rounding if CK_USE_SR_F6_CONVERSION is defined,
* otherwise uses round-to-nearest-even.
*
* @param x Input float vector to convert.
* @return Converted bf6x32_t vector.
*/
template <>
inline __host__ __device__ bf6x32_t type_convert<bf6x32_t, float32_t>(float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
return bf6_convert_sr(x);
#else
return bf6_convert_rne(x);
......@@ -1524,8 +1844,63 @@ inline __host__ __device__ bf6_t type_convert<bf6_t, float>(float x)
template <>
inline __host__ __device__ float type_convert<float, bf6_t>(bf6_t x)
{
// currently there is no native conversion instruction
#if defined(__gfx950__)
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} in{x};
union
{
float32_t float_vector;
float float_array[32];
} out{};
out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
in.bf6_vector, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
return out.float_array[0];
#else
return utils::to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif
}
/**
* @brief Specializes the type conversion template for converting a vector of 32 bf6_t values to
* vector of 32 floats.
*
* Interprets the bf6x32_t value using the default scale factor of 1 and returns
* its floating-point representation.
*
* @param x The bf6x32_t value to convert.
* @return The float representation of the given vector.
*/
template <>
inline __host__ __device__ float32_t type_convert<float32_t, bf6x32_t>(bf6x32_t x)
{
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
x, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
#else
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} in{x};
union
{
float32_t float_vector;
float float_array[32];
} out{};
ck::static_for<0, 32, 1>{}([&](auto i) {
out.float_array[i] =
utils::to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), in.bf6_array[i]);
});
return out.float_vector;
#endif
}
template <typename Y, typename X, std::size_t NumElems>
......
......@@ -824,4 +824,4 @@
#undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_
#undef CK_TILE_FLATMM_UK_MFMA
// clang-format on
// clang-format on
......@@ -722,4 +722,4 @@
#undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_
#undef CK_TILE_FLATMM_UK_MFMA
// clang-format on
// clang-format on
......@@ -771,4 +771,4 @@
#undef _UK_MFMA_
#undef CK_TILE_FLATMM_UK_2B
#undef CK_TILE_FLATMM_UK_MFMA
// clang-format on
// clang-format on
......@@ -9,6 +9,8 @@
using ck::bf6_convert_rne;
using ck::bf6_convert_sr;
using ck::bf6_t;
using ck::bf6x16_pk_t;
using ck::bf6x32_pk_t;
using ck::e8m0_bexp_t;
using ck::Number;
using ck::scaled_type_convert;
......@@ -216,3 +218,171 @@ TEST(BF6, ScaledConvertFP32Stochastic)
scaled_type_convert<float>(e8m0_bexp_t(min_scale), bf6_convert_sr(neg_float)),
abs_tol);
}
TEST(BF6, TestSize)
{
ASSERT_EQ(1, sizeof(bf6_t));
ASSERT_EQ(12, sizeof(bf6x16_pk_t));
ASSERT_EQ(24, sizeof(bf6x32_pk_t));
ASSERT_EQ(16, sizeof(vector_type<bf6x16_pk_t, 1>));
ASSERT_EQ(32, sizeof(vector_type<bf6x16_pk_t, 2>));
ASSERT_EQ(32, sizeof(vector_type<bf6x32_pk_t, 1>));
}
TEST(BF6, TestAlignment)
{
ASSERT_EQ(1, alignof(bf6_t));
ASSERT_EQ(4, alignof(bf6x16_pk_t));
ASSERT_EQ(4, alignof(bf6x32_pk_t));
ASSERT_EQ(16, alignof(vector_type<bf6x16_pk_t, 1>));
ASSERT_EQ(32, alignof(vector_type<bf6x16_pk_t, 2>));
ASSERT_EQ(32, alignof(vector_type<bf6x32_pk_t, 1>));
}
// test vector of 1 bf6x16_pk_t, contains 16 bf6_t
TEST(BF6, TestAsType16x1)
{
// test size
const int vector_size = 1;
const int packed_size = 16;
typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
test_vec_t test_vec = {bf6_t(0b000000),
bf6_t(0b100000),
bf6_t(0b000001),
bf6_t(0b100001),
bf6_t(0b000010),
bf6_t(0b100010),
bf6_t(0b000011),
bf6_t(0b100011),
bf6_t(0b000100),
bf6_t(0b100100),
bf6_t(0b000101),
bf6_t(0b100101),
bf6_t(0b000110),
bf6_t(0b100110),
bf6_t(0b001011),
bf6_t(0b101011)};
// reference vector
vector_type<bf6x16_pk_t, vector_size> right_vec;
// check default CTOR
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
right_vec.template AsType<bf6x16_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
0);
});
// assign test values to the vector
ck::static_for<0, vector_size, 1>{}([&](auto i) {
right_vec.template AsType<bf6x16_pk_t>()(Number<i>{}) = bf6x16_pk_t{}.pack(test_vec);
});
// copy the vector
vector_type<bf6x16_pk_t, vector_size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
left_vec.template AsType<bf6x16_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
static_cast<bf6_t>(test_vec[static_cast<int>(i)]));
});
}
// test vector of 2 bf6x16_pk_t, contains 32 bf6_t
TEST(BF6, TestAsType16x2)
{
// test size
const int vector_size = 2;
const int packed_size = 16;
typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
test_vec_t test_vec[2];
test_vec[0] = {bf6_t(0b000000),
bf6_t(0b100000),
bf6_t(0b000001),
bf6_t(0b100001),
bf6_t(0b000010),
bf6_t(0b100010),
bf6_t(0b000011),
bf6_t(0b100011),
bf6_t(0b000100),
bf6_t(0b100100),
bf6_t(0b000101),
bf6_t(0b100101),
bf6_t(0b000110),
bf6_t(0b100110),
bf6_t(0b001011),
bf6_t(0b101011)};
test_vec[1] = {bf6_t(0b010000),
bf6_t(0b110000),
bf6_t(0b010001),
bf6_t(0b110001),
bf6_t(0b010010),
bf6_t(0b110010),
bf6_t(0b010011),
bf6_t(0b110011),
bf6_t(0b010100),
bf6_t(0b110100),
bf6_t(0b010101),
bf6_t(0b110101),
bf6_t(0b010110),
bf6_t(0b110110),
bf6_t(0b011011),
bf6_t(0b111011)};
// reference vector
vector_type<bf6x16_pk_t, vector_size> right_vec;
// check default CTOR
ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) {
ck::static_for<0, packed_size, 1>{}([&](auto idx_element) {
ASSERT_EQ(right_vec.template AsType<bf6x16_pk_t>()(Number<idx_vector>{})
.template unpack<>(Number<idx_element>{}),
0);
});
});
// assign test values to the vector
ck::static_for<0, vector_size, 1>{}([&](auto i) {
right_vec.template AsType<bf6x16_pk_t>()(Number<i>{}) = bf6x16_pk_t{}.pack(test_vec[i]);
});
// copy the vector
vector_type<bf6x16_pk_t, vector_size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) {
ck::static_for<0, packed_size, 1>{}([&](auto idx_element) {
ASSERT_EQ(left_vec.template AsType<bf6x16_pk_t>()(Number<idx_vector>{})
.template unpack<>(Number<idx_element>{}),
static_cast<bf6_t>(test_vec[idx_vector][static_cast<int>(idx_element)]));
});
});
}
// test vector of 1 bf6x32_pk_t, contains 32 bf6_t
TEST(BF6, TestAsType32x1)
{
// test size
const int vector_size = 1;
const int packed_size = 32;
typedef int8_t test_vec_t __attribute__((ext_vector_type(32)));
test_vec_t test_vec = {bf6_t(0b000000), bf6_t(0b100000), bf6_t(0b000001), bf6_t(0b100001),
bf6_t(0b000010), bf6_t(0b100010), bf6_t(0b000011), bf6_t(0b100011),
bf6_t(0b000100), bf6_t(0b100100), bf6_t(0b000101), bf6_t(0b100101),
bf6_t(0b000110), bf6_t(0b100110), bf6_t(0b001011), bf6_t(0b101011),
bf6_t(0b010000), bf6_t(0b110000), bf6_t(0b010001), bf6_t(0b110001),
bf6_t(0b010010), bf6_t(0b110010), bf6_t(0b010011), bf6_t(0b110011),
bf6_t(0b010100), bf6_t(0b110100), bf6_t(0b010101), bf6_t(0b110101),
bf6_t(0b010110), bf6_t(0b110110), bf6_t(0b011011), bf6_t(0b111011)};
// reference vector
vector_type<bf6x32_pk_t, vector_size> right_vec;
// check default CTOR
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
right_vec.template AsType<bf6x32_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
0);
});
// assign test values to the vector
ck::static_for<0, vector_size, 1>{}([&](auto i) {
right_vec.template AsType<bf6x32_pk_t>()(Number<i>{}) = bf6x32_pk_t{}.pack(test_vec);
});
// copy the vector
vector_type<bf6x32_pk_t, vector_size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
left_vec.template AsType<bf6x32_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
static_cast<bf6_t>(test_vec[static_cast<int>(i)]));
});
}
......@@ -235,8 +235,10 @@ TEST(FP4, TestAsType1)
vector_type<f4x2_pk_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), 0);
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), 0);
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}), 0);
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}), 0);
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
......@@ -247,9 +249,9 @@ TEST(FP4, TestAsType1)
vector_type<f4x2_pk_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(),
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}),
test_vec.at(i));
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(),
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}),
test_vec.at(i + 1));
});
}
......@@ -267,8 +269,10 @@ TEST(FP4, TestAsType2)
vector_type<f4x2_pk_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), 0);
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), 0);
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}), 0);
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}), 0);
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
......@@ -279,9 +283,9 @@ TEST(FP4, TestAsType2)
vector_type<f4x2_pk_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(),
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}),
test_vec.at(i));
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(),
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}),
test_vec.at(i + 1));
});
}
......@@ -303,8 +307,10 @@ TEST(FP4, TestAsType4)
vector_type<f4x2_pk_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), 0);
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), 0);
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}), 0);
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}), 0);
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
......@@ -315,9 +321,9 @@ TEST(FP4, TestAsType4)
vector_type<f4x2_pk_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(),
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}),
test_vec.at(i));
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(),
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}),
test_vec.at(i + 1));
});
}
......@@ -347,8 +353,10 @@ TEST(FP4, TestAsType8)
vector_type<f4x2_pk_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), 0);
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), 0);
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}), 0);
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}), 0);
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
......@@ -359,9 +367,9 @@ TEST(FP4, TestAsType8)
vector_type<f4x2_pk_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(),
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}),
test_vec.at(i));
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(),
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}),
test_vec.at(i + 1));
});
}
......@@ -387,8 +395,10 @@ TEST(FP4, TestAsType16)
vector_type<f4x2_pk_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), 0);
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), 0);
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}), 0);
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}), 0);
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
......@@ -399,9 +409,9 @@ TEST(FP4, TestAsType16)
vector_type<f4x2_pk_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(),
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}),
test_vec.at(i));
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(),
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}),
test_vec.at(i + 1));
});
}
......@@ -438,8 +448,10 @@ TEST(FP4, TestAsType32)
vector_type<f4x2_pk_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), 0);
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), 0);
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}), 0);
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}), 0);
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
......@@ -450,9 +462,9 @@ TEST(FP4, TestAsType32)
vector_type<f4x2_pk_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(),
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}),
test_vec.at(i));
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(),
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}),
test_vec.at(i + 1));
});
}
......@@ -10,6 +10,8 @@ using ck::e8m0_bexp_t;
using ck::f6_convert_rne;
using ck::f6_convert_sr;
using ck::f6_t;
using ck::f6x16_pk_t;
using ck::f6x32_pk_t;
using ck::Number;
using ck::scaled_type_convert;
using ck::type_convert;
......@@ -215,3 +217,169 @@ TEST(FP6, ScaledConvertFP32Stochastic)
scaled_type_convert<float>(e8m0_bexp_t(min_scale), f6_convert_sr(neg_float)),
abs_tol);
}
TEST(FP6, TestSize)
{
ASSERT_EQ(1, sizeof(f6_t));
ASSERT_EQ(12, sizeof(f6x16_pk_t));
ASSERT_EQ(24, sizeof(f6x32_pk_t));
ASSERT_EQ(16, sizeof(vector_type<f6x16_pk_t, 1>));
ASSERT_EQ(32, sizeof(vector_type<f6x16_pk_t, 2>));
ASSERT_EQ(32, sizeof(vector_type<f6x32_pk_t, 1>));
}
TEST(FP6, TestAlignment)
{
ASSERT_EQ(1, alignof(f6_t));
ASSERT_EQ(4, alignof(f6x16_pk_t));
ASSERT_EQ(4, alignof(f6x32_pk_t));
ASSERT_EQ(16, alignof(vector_type<f6x16_pk_t, 1>));
ASSERT_EQ(32, alignof(vector_type<f6x16_pk_t, 2>));
ASSERT_EQ(32, alignof(vector_type<f6x32_pk_t, 1>));
}
// test vector of 1 f6x16_pk_t, contains 16 f6_t
TEST(FP6, TestAsType16x1)
{
// test size
const int vector_size = 1;
const int packed_size = 16;
typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
test_vec_t test_vec = {f6_t(0b000000),
f6_t(0b100000),
f6_t(0b000001),
f6_t(0b100001),
f6_t(0b000010),
f6_t(0b100010),
f6_t(0b000011),
f6_t(0b100011),
f6_t(0b000100),
f6_t(0b100100),
f6_t(0b000101),
f6_t(0b100101),
f6_t(0b000110),
f6_t(0b100110),
f6_t(0b001011),
f6_t(0b101011)};
// reference vector
vector_type<f6x16_pk_t, vector_size> right_vec;
// check default CTOR
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
right_vec.template AsType<f6x16_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}), 0);
});
// assign test values to the vector
ck::static_for<0, vector_size, 1>{}([&](auto i) {
right_vec.template AsType<f6x16_pk_t>()(Number<i>{}) = f6x16_pk_t{}.pack(test_vec);
});
// copy the vector
vector_type<f6x16_pk_t, vector_size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
left_vec.template AsType<f6x16_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
static_cast<f6_t>(test_vec[static_cast<int>(i)]));
});
}
// test vector of 2 f6x16_pk_t, contains 32 f6_t
TEST(FP6, TestAsType16x2)
{
// test size
const int vector_size = 2;
const int packed_size = 16;
typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
test_vec_t test_vec[2];
test_vec[0] = {f6_t(0b000000),
f6_t(0b100000),
f6_t(0b000001),
f6_t(0b100001),
f6_t(0b000010),
f6_t(0b100010),
f6_t(0b000011),
f6_t(0b100011),
f6_t(0b000100),
f6_t(0b100100),
f6_t(0b000101),
f6_t(0b100101),
f6_t(0b000110),
f6_t(0b100110),
f6_t(0b001011),
f6_t(0b101011)};
test_vec[1] = {f6_t(0b010000),
f6_t(0b110000),
f6_t(0b010001),
f6_t(0b110001),
f6_t(0b010010),
f6_t(0b110010),
f6_t(0b010011),
f6_t(0b110011),
f6_t(0b010100),
f6_t(0b110100),
f6_t(0b010101),
f6_t(0b110101),
f6_t(0b010110),
f6_t(0b110110),
f6_t(0b011011),
f6_t(0b111011)};
// reference vector
vector_type<f6x16_pk_t, vector_size> right_vec;
// check default CTOR
ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) {
ck::static_for<0, packed_size, 1>{}([&](auto idx_element) {
ASSERT_EQ(right_vec.template AsType<f6x16_pk_t>()(Number<idx_vector>{})
.template unpack<>(Number<idx_element>{}),
0);
});
});
// assign test values to the vector
ck::static_for<0, vector_size, 1>{}([&](auto i) {
right_vec.template AsType<f6x16_pk_t>()(Number<i>{}) = f6x16_pk_t{}.pack(test_vec[i]);
});
// copy the vector
vector_type<f6x16_pk_t, vector_size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) {
ck::static_for<0, packed_size, 1>{}([&](auto idx_element) {
ASSERT_EQ(left_vec.template AsType<f6x16_pk_t>()(Number<idx_vector>{})
.template unpack<>(Number<idx_element>{}),
static_cast<f6_t>(test_vec[idx_vector][static_cast<int>(idx_element)]));
});
});
}
// test vector of 1 f6x32_pk_t, contains 32 f6_t
TEST(FP6, TestAsType32x1)
{
// test size
const int vector_size = 1;
const int packed_size = 32;
typedef int8_t test_vec_t __attribute__((ext_vector_type(32)));
test_vec_t test_vec = {f6_t(0b000000), f6_t(0b100000), f6_t(0b000001), f6_t(0b100001),
f6_t(0b000010), f6_t(0b100010), f6_t(0b000011), f6_t(0b100011),
f6_t(0b000100), f6_t(0b100100), f6_t(0b000101), f6_t(0b100101),
f6_t(0b000110), f6_t(0b100110), f6_t(0b001011), f6_t(0b101011),
f6_t(0b010000), f6_t(0b110000), f6_t(0b010001), f6_t(0b110001),
f6_t(0b010010), f6_t(0b110010), f6_t(0b010011), f6_t(0b110011),
f6_t(0b010100), f6_t(0b110100), f6_t(0b010101), f6_t(0b110101),
f6_t(0b010110), f6_t(0b110110), f6_t(0b011011), f6_t(0b111011)};
// reference vector
vector_type<f6x32_pk_t, vector_size> right_vec;
// check default CTOR
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
right_vec.template AsType<f6x32_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}), 0);
});
// assign test values to the vector
ck::static_for<0, vector_size, 1>{}([&](auto i) {
right_vec.template AsType<f6x32_pk_t>()(Number<i>{}) = f6x32_pk_t{}.pack(test_vec);
});
// copy the vector
vector_type<f6x32_pk_t, vector_size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
left_vec.template AsType<f6x32_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
static_cast<f6_t>(test_vec[static_cast<int>(i)]));
});
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment