Commit 60b885ae authored by Andriy Roshchenko's avatar Andriy Roshchenko
Browse files

Merge branch 'develop' into andriy/lwpck-2788

parents 1a90f021 fd7600ce
...@@ -156,9 +156,9 @@ message("checking which targets are supported") ...@@ -156,9 +156,9 @@ message("checking which targets are supported")
if(NOT ENABLE_ASAN_PACKAGING) if(NOT ENABLE_ASAN_PACKAGING)
if(NOT WIN32 AND ${hip_VERSION_FLAT} LESS 600300000) if(NOT WIN32 AND ${hip_VERSION_FLAT} LESS 600300000)
# WORKAROUND: compiler does not yet fully support gfx12 targets, need to fix version above # WORKAROUND: compiler does not yet fully support gfx12 targets, need to fix version above
set(CK_GPU_TARGETS "gfx950") set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102")
else() else()
set(CK_GPU_TARGETS "gfx950") set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201")
endif() endif()
else() else()
#build CK only for xnack-supported targets when using ASAN #build CK only for xnack-supported targets when using ASAN
...@@ -210,6 +210,10 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx ...@@ -210,6 +210,10 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx
add_definitions(-DCK_USE_FNUZ_FP8) add_definitions(-DCK_USE_FNUZ_FP8)
set(CK_USE_FNUZ_FP8 "ON") set(CK_USE_FNUZ_FP8 "ON")
endif() endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx950")
add_definitions(-DCK_USE_NATIVE_MX_SUPPORT)
set(CK_USE_NATIVE_MX_SUPPORT "ON")
endif()
option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF) option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF)
if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908")) if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908"))
......
...@@ -5,6 +5,6 @@ if(USE_BITINT_EXTENSION_INT4) ...@@ -5,6 +5,6 @@ if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp) add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp)
endif(USE_BITINT_EXTENSION_INT4) endif(USE_BITINT_EXTENSION_INT4)
if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1") if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx95" AND NOT GPU_TARGETS MATCHES "gfx1")
add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp) add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp)
endif() endif()
...@@ -5,6 +5,6 @@ if(USE_BITINT_EXTENSION_INT4) ...@@ -5,6 +5,6 @@ if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp) add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp)
endif(USE_BITINT_EXTENSION_INT4) endif(USE_BITINT_EXTENSION_INT4)
if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1") if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx95" AND NOT GPU_TARGETS MATCHES "gfx1")
add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp) add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp)
endif() endif()
...@@ -131,6 +131,10 @@ ...@@ -131,6 +131,10 @@
#cmakedefine CK_USE_FP8_ON_UNSUPPORTED_ARCH @CK_USE_FP8_ON_UNSUPPORTED_ARCH@ #cmakedefine CK_USE_FP8_ON_UNSUPPORTED_ARCH @CK_USE_FP8_ON_UNSUPPORTED_ARCH@
#endif #endif
#ifndef CK_USE_NATIVE_MX_SUPPORT
#cmakedefine CK_USE_NATIVE_MX_SUPPORT @CK_USE_NATIVE_MX_SUPPORT@
#endif
// clang-format on // clang-format on
#endif // CK_CONFIG_H_IN #endif // CK_CONFIG_H_IN
...@@ -608,14 +608,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle ...@@ -608,14 +608,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size // therefore we may just as well assign Gemm1KPack = group_size
#if defined(__gfx950__)
// TODO: fix logic for gfx950 as it's temporary hack for passing compiling
constexpr index_t Gemm1KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
#else
constexpr index_t Gemm1KPack = constexpr index_t Gemm1KPack =
MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size; MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
#endif
auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2< auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
BlockSize, BlockSize,
......
...@@ -856,11 +856,18 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -856,11 +856,18 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
static_cast<A0B0B1DataType*>(p_shared) + SharedMemTrait::b1_block_space_offset, static_cast<A0B0B1DataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
b1_block_desc_bk0_n_bk1.GetElementSpaceSize()); b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr index_t Gemm1KPack = math::max( // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
math::lcm( // selected_mfma.k_per_blk <= Gemm1KPack
MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.group_size, //
B1K1), // Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.k_per_blk); // multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
// Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
constexpr index_t Gemm1KPack =
MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.group_size;
auto blockwise_gemm1 = BlockwiseGemmXdlops_v2< auto blockwise_gemm1 = BlockwiseGemmXdlops_v2<
BlockSize, BlockSize,
......
...@@ -773,14 +773,10 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle ...@@ -773,14 +773,10 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size // therefore we may just as well assign Gemm1KPack = group_size
#if defined(__gfx950__)
// TODO: fix logic for gfx950 as it's temporary hack for passing compiling
constexpr index_t Gemm1KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
#else
constexpr index_t Gemm1KPack = constexpr index_t Gemm1KPack =
MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size; MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
#endif
auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2< auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
BlockSize, BlockSize,
FloatAB, FloatAB,
......
...@@ -628,14 +628,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -628,14 +628,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size // therefore we may just as well assign Gemm1KPack = group_size
#if defined(__gfx950__)
// TODO: fix logic for gfx950 as it's temporary hack for passing compiling
constexpr index_t Gemm1KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
#else
constexpr index_t Gemm1KPack = constexpr index_t Gemm1KPack =
MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size; MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
#endif
auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2< auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
BlockSize, BlockSize,
FloatAB, FloatAB,
......
...@@ -890,13 +890,15 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4> ...@@ -890,13 +890,15 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
template <typename base_type, template <typename base_type,
index_t MPerXdlops, index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
typename additional_type = base_type> typename additional_type = base_type,
bool is_single_rate_mfma = false>
struct MfmaSelector struct MfmaSelector
{ {
template <typename base_type_, template <typename base_type_,
index_t MPerXdlops_, index_t MPerXdlops_,
index_t NPerXdlops_, index_t NPerXdlops_,
typename additional_type_ = base_type_> typename additional_type_ = base_type_,
bool is_single_rate_mfma_ = false>
static constexpr auto GetMfma(); static constexpr auto GetMfma();
template <> template <>
...@@ -960,7 +962,7 @@ struct MfmaSelector ...@@ -960,7 +962,7 @@ struct MfmaSelector
} }
template <> template <>
constexpr auto GetMfma<half_t, 32, 32>() constexpr auto GetMfma<half_t, 32, 32, half_t, false>()
{ {
#if defined(__gfx950__) #if defined(__gfx950__)
return MfmaInstr::mfma_f32_32x32x16f16; return MfmaInstr::mfma_f32_32x32x16f16;
...@@ -968,9 +970,14 @@ struct MfmaSelector ...@@ -968,9 +970,14 @@ struct MfmaSelector
return MfmaInstr::mfma_f32_32x32x8f16; return MfmaInstr::mfma_f32_32x32x8f16;
#endif #endif
} }
template <>
constexpr auto GetMfma<half_t, 32, 32, half_t, true>()
{
return MfmaInstr::mfma_f32_32x32x8f16;
}
template <> template <>
constexpr auto GetMfma<half_t, 16, 16>() constexpr auto GetMfma<half_t, 16, 16, half_t, false>()
{ {
#if defined(__gfx950__) #if defined(__gfx950__)
return MfmaInstr::mfma_f32_16x16x32f16; return MfmaInstr::mfma_f32_16x16x32f16;
...@@ -979,6 +986,12 @@ struct MfmaSelector ...@@ -979,6 +986,12 @@ struct MfmaSelector
#endif #endif
} }
template <>
constexpr auto GetMfma<half_t, 16, 16, half_t, true>()
{
return MfmaInstr::mfma_f32_16x16x16f16;
}
template <> template <>
constexpr auto GetMfma<half_t, 16, 64>() constexpr auto GetMfma<half_t, 16, 64>()
{ {
...@@ -998,7 +1011,7 @@ struct MfmaSelector ...@@ -998,7 +1011,7 @@ struct MfmaSelector
} }
template <> template <>
constexpr auto GetMfma<bhalf_t, 32, 32>() constexpr auto GetMfma<bhalf_t, 32, 32, bhalf_t, false>()
{ {
#if defined(__gfx950__) #if defined(__gfx950__)
return MfmaInstr::mfma_f32_32x32x16bf16; return MfmaInstr::mfma_f32_32x32x16bf16;
...@@ -1010,7 +1023,17 @@ struct MfmaSelector ...@@ -1010,7 +1023,17 @@ struct MfmaSelector
} }
template <> template <>
constexpr auto GetMfma<bhalf_t, 16, 16>() constexpr auto GetMfma<bhalf_t, 32, 32, bhalf_t, true>()
{
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return MfmaInstr::mfma_f32_32x32x8bf16_1k;
#else
return MfmaInstr::mfma_f32_32x32x4bf16;
#endif
}
template <>
constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, false>()
{ {
#if defined(__gfx950__) #if defined(__gfx950__)
return MfmaInstr::mfma_f32_16x16x32bf16; return MfmaInstr::mfma_f32_16x16x32bf16;
...@@ -1021,6 +1044,16 @@ struct MfmaSelector ...@@ -1021,6 +1044,16 @@ struct MfmaSelector
#endif #endif
} }
template <>
constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, true>()
{
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return MfmaInstr::mfma_f32_16x16x16bf16_1k;
#else
return MfmaInstr::mfma_f32_16x16x8bf16;
#endif
}
#if defined(__gfx950__) #if defined(__gfx950__)
template <> template <>
constexpr auto GetMfma<int8_t, 32, 32>() constexpr auto GetMfma<int8_t, 32, 32>()
...@@ -1104,8 +1137,8 @@ struct MfmaSelector ...@@ -1104,8 +1137,8 @@ struct MfmaSelector
return MfmaInstr::mfma_f32_16x16x32bf8f8; return MfmaInstr::mfma_f32_16x16x32bf8f8;
} }
static constexpr auto selected_mfma = static constexpr auto selected_mfma = mfma_type<
mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type>()>{}; GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type, is_single_rate_mfma>()>{};
__host__ __device__ constexpr MfmaSelector() __host__ __device__ constexpr MfmaSelector()
{ {
...@@ -1407,7 +1440,13 @@ struct XdlopsGemm ...@@ -1407,7 +1440,13 @@ struct XdlopsGemm
return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td}; return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
} }
static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops, additional_type>{}; // Falls back to single rate instruction on gfx950 if KPack <= 4; no change on gfx942-
static constexpr auto
mfma = MfmaSelector < base_type,
MPerXdlops, NPerXdlops, additional_type,
((is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value) && KPack <= 4)
? true
: false > {};
static constexpr auto mfma_instr = mfma.selected_mfma; static constexpr auto mfma_instr = mfma.selected_mfma;
......
...@@ -24,8 +24,9 @@ struct f4x2_pk_t ...@@ -24,8 +24,9 @@ struct f4x2_pk_t
f4x2_pk_t(type init) : data{init} {} f4x2_pk_t(type init) : data{init} {}
template <index_t I> template <index_t I>
__host__ __device__ inline type unpack() const __host__ __device__ inline type unpack(Number<I>) const
{ {
static_assert(I < 2, "Index is out of range.");
if constexpr(I == 0) if constexpr(I == 0)
return data & 0b00001111; return data & 0b00001111;
else else
...@@ -38,6 +39,270 @@ struct f4x2_pk_t ...@@ -38,6 +39,270 @@ struct f4x2_pk_t
} }
}; };
struct f6x16_pk_t
{
// store 16 elements of f6_t in an array of 3 uint32_t
using element_type = uint32_t;
using type = StaticallyIndexedArray_v2<element_type, 3>;
type data;
typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
f6x16_pk_t() : data{type{}} {}
f6x16_pk_t(type init) : data{init} {}
template <index_t I>
__host__ __device__ inline f6_t unpack(Number<I>)
{
static_assert(I < 16, "Index out of range for 16 f6_t elements.");
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 3;
constexpr int bit_pos = I * num_bits_elem;
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
{
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
<< (num_bits_elem - overhang);
}
return static_cast<f6_t>(bits & 0x3F);
}
__host__ __device__ inline type pack(const test_vec_t& x)
{
type packed{};
// for each of the 16 f6_t values, place its 6 bits in the correct position
ck::static_for<0, 16, 1>{}([&](auto i) {
uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 3;
constexpr int bit_pos = i * num_bits_elem;
constexpr int arr_index = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
uint32_t old_value = packed.At(Number<arr_index>{});
// insert bits into the current 32-bit block
old_value |= (bits << bit_offset);
packed.At(Number<arr_index>{}) = old_value;
// if it crosses into the next block, shift the remainder
if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
{
uint32_t next_value = packed.At(Number<arr_index + 1>{});
next_value |= (bits >> (num_bits_elem - overhang));
packed.At(Number<arr_index + 1>{}) = next_value;
}
});
return packed;
}
};
struct f6x32_pk_t
{
// store 32 elements of f6_t in an array of 6 uint32_t
using element_type = uint32_t;
using type = StaticallyIndexedArray_v2<element_type, 6>;
type data;
typedef int8_t test_vec_t __attribute__((ext_vector_type(32)));
f6x32_pk_t() : data{type{}} {}
f6x32_pk_t(type init) : data{init} {}
template <index_t I>
__host__ __device__ inline f6_t unpack(Number<I>)
{
static_assert(I < 32, "Index out of range for 32 f6_t elements.");
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 6;
constexpr int bit_pos = I * num_bits_elem;
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
{
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
<< (num_bits_elem - overhang);
}
return static_cast<f6_t>(bits & 0x3F);
}
__host__ __device__ inline type pack(const test_vec_t& x)
{
type packed{};
// for each of the 32 f6_t values, place its 6 bits in the correct position
ck::static_for<0, 32, 1>{}([&](auto i) {
uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 6;
constexpr int bit_pos = i * num_bits_elem;
constexpr int arr_index = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
uint32_t old_value = packed.At(Number<arr_index>{});
// insert bits into the current 32-bit block
old_value |= (bits << bit_offset);
packed.At(Number<arr_index>{}) = old_value;
// if it crosses into the next block, shift the remainder
if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
{
uint32_t next_value = packed.At(Number<arr_index + 1>{});
next_value |= (bits >> (num_bits_elem - overhang));
packed.At(Number<arr_index + 1>{}) = next_value;
}
});
return packed;
}
};
struct bf6x16_pk_t
{
// store 16 elements of bf6_t in an array of 3 uint32_t
using element_type = uint32_t;
using type = StaticallyIndexedArray_v2<element_type, 3>;
type data;
typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
bf6x16_pk_t() : data{type{}} {}
bf6x16_pk_t(type init) : data{init} {}
template <index_t I>
__host__ __device__ inline bf6_t unpack(Number<I>)
{
static_assert(I < 16, "Index out of range for 16 f6_t elements.");
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 3;
constexpr int bit_pos = I * num_bits_elem;
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
{
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
<< (num_bits_elem - overhang);
}
return static_cast<bf6_t>(bits & 0x3F);
}
__host__ __device__ inline type pack(const test_vec_t& x)
{
type packed{};
// for each of the 16 bf6_t values, place its 6 bits in the correct position
ck::static_for<0, 16, 1>{}([&](auto i) {
uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 3;
constexpr int bit_pos = i * num_bits_elem;
constexpr int arr_index = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
uint32_t old_value = packed.At(Number<arr_index>{});
// insert bits into the current 32-bit block
old_value |= (bits << bit_offset);
packed.At(Number<arr_index>{}) = old_value;
// if it crosses into the next block, shift the remainder
if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
{
uint32_t next_value = packed.At(Number<arr_index + 1>{});
next_value |= (bits >> (num_bits_elem - overhang));
packed.At(Number<arr_index + 1>{}) = next_value;
}
});
return packed;
}
};
struct bf6x32_pk_t
{
// store 32 elements of bf6_t in an array of 6 uint32_t
using element_type = uint32_t;
using type = StaticallyIndexedArray_v2<element_type, 6>;
type data;
typedef int8_t test_vec_t __attribute__((ext_vector_type(32)));
bf6x32_pk_t() : data{type{}} {}
bf6x32_pk_t(type init) : data{init} {}
template <index_t I>
__host__ __device__ inline bf6_t unpack(Number<I>)
{
static_assert(I < 32, "Index out of range for 32 f6_t elements.");
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 6;
constexpr int bit_pos = I * num_bits_elem;
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
{
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
<< (num_bits_elem - overhang);
}
return static_cast<bf6_t>(bits & 0x3F);
}
__host__ __device__ inline type pack(const test_vec_t& x)
{
type packed{};
// for each of the 32 bf6_t values, place its 6 bits in the correct position
ck::static_for<0, 32, 1>{}([&](auto i) {
uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 6;
constexpr int bit_pos = i * num_bits_elem;
constexpr int arr_index = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
uint32_t old_value = packed.At(Number<arr_index>{});
// insert bits into the current 32-bit block
old_value |= (bits << bit_offset);
packed.At(Number<arr_index>{}) = old_value;
// if it crosses into the next block, shift the remainder
if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
{
uint32_t next_value = packed.At(Number<arr_index + 1>{});
next_value |= (bits >> (num_bits_elem - overhang));
packed.At(Number<arr_index + 1>{}) = next_value;
}
});
return packed;
}
};
// custom data type - pack int4 data // custom data type - pack int4 data
struct pk_i4_t struct pk_i4_t
{ {
...@@ -56,7 +321,7 @@ inline constexpr auto next_pow2(uint32_t x) ...@@ -56,7 +321,7 @@ inline constexpr auto next_pow2(uint32_t x)
} }
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t, // native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t,
// native types: bool // native types: bool, f4_t, f6_t, bf6_t
template <typename T> template <typename T>
inline constexpr bool is_native_type() inline constexpr bool is_native_type()
{ {
...@@ -1387,12 +1652,37 @@ struct nnvb_data_t_selector<f8_ocp_t> ...@@ -1387,12 +1652,37 @@ struct nnvb_data_t_selector<f8_ocp_t>
{ {
using type = f8_ocp_t::data_type; using type = f8_ocp_t::data_type;
}; };
template <> template <>
struct nnvb_data_t_selector<bf8_ocp_t> struct nnvb_data_t_selector<bf8_ocp_t>
{ {
using type = bf8_ocp_t::data_type; using type = bf8_ocp_t::data_type;
}; };
template <>
struct nnvb_data_t_selector<f6x16_pk_t>
{
using type = f6x16_pk_t::type;
};
template <>
struct nnvb_data_t_selector<f6x32_pk_t>
{
using type = f6x32_pk_t::type;
};
template <>
struct nnvb_data_t_selector<bf6x16_pk_t>
{
using type = bf6x16_pk_t::type;
};
template <>
struct nnvb_data_t_selector<bf6x32_pk_t>
{
using type = bf6x32_pk_t::type;
};
template <> template <>
struct nnvb_data_t_selector<pk_i4_t> struct nnvb_data_t_selector<pk_i4_t>
{ {
...@@ -1499,6 +1789,63 @@ struct non_native_vector_base< ...@@ -1499,6 +1789,63 @@ struct non_native_vector_base<
} }
}; };
// implementation for f6x16 and f6x32
template <typename T, index_t N>
struct non_native_vector_base<T, N, std::enable_if_t<sizeof(T) == 12 || sizeof(T) == 24>>
{
using data_t =
typename nnvb_data_t_selector<T>::type; // select data_t based on declared base type
using element_t = typename T::element_type; // select element_t based on declared element type
static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch");
static constexpr size_t size_factor =
sizeof(data_t) / sizeof(element_t); // f6x16: 12/4 = 3, f6x32: 24/4 = 6
using data_v = element_t __attribute__((ext_vector_type(N * size_factor)));
using type = non_native_vector_base<T, N>;
union alignas(next_pow2(N * sizeof(T)))
{
data_v dN; // storage vector;
StaticallyIndexedArray<data_t, N> dxN;
StaticallyIndexedArray<T, N> dTxN;
StaticallyIndexedArray<data_v, 1> dNx1;
} data_;
__host__ __device__ constexpr non_native_vector_base(data_t a)
: data_{data_v(a.At(Number<0>{}))}
{
}
__host__ __device__ constexpr non_native_vector_base(T f)
: non_native_vector_base(bit_cast<data_t>(f))
{
}
__host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){};
__host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {}
__host__ __device__ constexpr operator data_v() const { return data_.dN; }
__host__ __device__ constexpr operator data_t() const
{
if constexpr(N == 1)
{
return data_.dxN[Number<0>{}];
}
else
{
return data_.dxN; // XXX this should cause an error
}
}
__host__ __device__ constexpr operator T() const
{
if constexpr(N == 1)
{
return data_.dTxN[Number<0>{}];
}
else
{
return data_.dTxN; // XXX this should cause an error
}
}
};
template <typename T, index_t N> template <typename T, index_t N>
struct scalar_type<non_native_vector_base<T, N>>; struct scalar_type<non_native_vector_base<T, N>>;
...@@ -2242,6 +2589,14 @@ using f4x16_t = typename vector_type<f4x2_pk_t, 8>::type; ...@@ -2242,6 +2589,14 @@ using f4x16_t = typename vector_type<f4x2_pk_t, 8>::type;
using f4x32_t = typename vector_type<f4x2_pk_t, 16>::type; using f4x32_t = typename vector_type<f4x2_pk_t, 16>::type;
using f4x64_t = typename vector_type<f4x2_pk_t, 32>::type; using f4x64_t = typename vector_type<f4x2_pk_t, 32>::type;
// f6
using f6x16_t = typename vector_type<f6x16_pk_t, 1>::type;
using f6x32_t = typename vector_type<f6x32_pk_t, 1>::type;
// bf6
using bf6x16_t = typename vector_type<bf6x16_pk_t, 1>::type;
using bf6x32_t = typename vector_type<bf6x32_pk_t, 1>::type;
// pack int4 // pack int4
using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type; using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type;
using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type; using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type;
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
This diff is collapsed.
This diff is collapsed.
...@@ -824,4 +824,4 @@ ...@@ -824,4 +824,4 @@
#undef _UK_PK_CVT_ #undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_ #undef _UK_ATOMIC_ADD_
#undef CK_TILE_FLATMM_UK_MFMA #undef CK_TILE_FLATMM_UK_MFMA
// clang-format on // clang-format on
...@@ -722,4 +722,4 @@ ...@@ -722,4 +722,4 @@
#undef _UK_PK_CVT_ #undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_ #undef _UK_ATOMIC_ADD_
#undef CK_TILE_FLATMM_UK_MFMA #undef CK_TILE_FLATMM_UK_MFMA
// clang-format on // clang-format on
...@@ -771,4 +771,4 @@ ...@@ -771,4 +771,4 @@
#undef _UK_MFMA_ #undef _UK_MFMA_
#undef CK_TILE_FLATMM_UK_2B #undef CK_TILE_FLATMM_UK_2B
#undef CK_TILE_FLATMM_UK_MFMA #undef CK_TILE_FLATMM_UK_MFMA
// clang-format on // clang-format on
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
using ck::bf6_convert_rne; using ck::bf6_convert_rne;
using ck::bf6_convert_sr; using ck::bf6_convert_sr;
using ck::bf6_t; using ck::bf6_t;
using ck::bf6x16_pk_t;
using ck::bf6x32_pk_t;
using ck::e8m0_bexp_t; using ck::e8m0_bexp_t;
using ck::Number; using ck::Number;
using ck::scaled_type_convert; using ck::scaled_type_convert;
...@@ -216,3 +218,171 @@ TEST(BF6, ScaledConvertFP32Stochastic) ...@@ -216,3 +218,171 @@ TEST(BF6, ScaledConvertFP32Stochastic)
scaled_type_convert<float>(e8m0_bexp_t(min_scale), bf6_convert_sr(neg_float)), scaled_type_convert<float>(e8m0_bexp_t(min_scale), bf6_convert_sr(neg_float)),
abs_tol); abs_tol);
} }
TEST(BF6, TestSize)
{
ASSERT_EQ(1, sizeof(bf6_t));
ASSERT_EQ(12, sizeof(bf6x16_pk_t));
ASSERT_EQ(24, sizeof(bf6x32_pk_t));
ASSERT_EQ(16, sizeof(vector_type<bf6x16_pk_t, 1>));
ASSERT_EQ(32, sizeof(vector_type<bf6x16_pk_t, 2>));
ASSERT_EQ(32, sizeof(vector_type<bf6x32_pk_t, 1>));
}
TEST(BF6, TestAlignment)
{
ASSERT_EQ(1, alignof(bf6_t));
ASSERT_EQ(4, alignof(bf6x16_pk_t));
ASSERT_EQ(4, alignof(bf6x32_pk_t));
ASSERT_EQ(16, alignof(vector_type<bf6x16_pk_t, 1>));
ASSERT_EQ(32, alignof(vector_type<bf6x16_pk_t, 2>));
ASSERT_EQ(32, alignof(vector_type<bf6x32_pk_t, 1>));
}
// test vector of 1 bf6x16_pk_t, contains 16 bf6_t
TEST(BF6, TestAsType16x1)
{
// test size
const int vector_size = 1;
const int packed_size = 16;
typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
test_vec_t test_vec = {bf6_t(0b000000),
bf6_t(0b100000),
bf6_t(0b000001),
bf6_t(0b100001),
bf6_t(0b000010),
bf6_t(0b100010),
bf6_t(0b000011),
bf6_t(0b100011),
bf6_t(0b000100),
bf6_t(0b100100),
bf6_t(0b000101),
bf6_t(0b100101),
bf6_t(0b000110),
bf6_t(0b100110),
bf6_t(0b001011),
bf6_t(0b101011)};
// reference vector
vector_type<bf6x16_pk_t, vector_size> right_vec;
// check default CTOR
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
right_vec.template AsType<bf6x16_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
0);
});
// assign test values to the vector
ck::static_for<0, vector_size, 1>{}([&](auto i) {
right_vec.template AsType<bf6x16_pk_t>()(Number<i>{}) = bf6x16_pk_t{}.pack(test_vec);
});
// copy the vector
vector_type<bf6x16_pk_t, vector_size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, packed_size, 1>{}([&](auto i) {
ASSERT_EQ(
left_vec.template AsType<bf6x16_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
static_cast<bf6_t>(test_vec[static_cast<int>(i)]));
});
}
// A vector_type holding two bf6x16_pk_t carriers (32 bf6_t total):
// default construction zero-fills, pack() stores each 16-element set,
// and copy construction preserves every packed element.
TEST(BF6, TestAsType16x2)
{
    // number of packed carriers in the vector / bf6_t values per carrier
    constexpr int num_packs      = 2;
    constexpr int elems_per_pack = 16;
    typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
    // two independent sets of 16 raw 6-bit patterns (MSB toggled pairwise)
    test_vec_t test_vec[2];
    test_vec[0] = {bf6_t(0b000000), bf6_t(0b100000), bf6_t(0b000001), bf6_t(0b100001),
                   bf6_t(0b000010), bf6_t(0b100010), bf6_t(0b000011), bf6_t(0b100011),
                   bf6_t(0b000100), bf6_t(0b100100), bf6_t(0b000101), bf6_t(0b100101),
                   bf6_t(0b000110), bf6_t(0b100110), bf6_t(0b001011), bf6_t(0b101011)};
    test_vec[1] = {bf6_t(0b010000), bf6_t(0b110000), bf6_t(0b010001), bf6_t(0b110001),
                   bf6_t(0b010010), bf6_t(0b110010), bf6_t(0b010011), bf6_t(0b110011),
                   bf6_t(0b010100), bf6_t(0b110100), bf6_t(0b010101), bf6_t(0b110101),
                   bf6_t(0b010110), bf6_t(0b110110), bf6_t(0b011011), bf6_t(0b111011)};
    // source vector; a default-constructed vector must read back all zeros
    vector_type<bf6x16_pk_t, num_packs> right_vec;
    ck::static_for<0, num_packs, 1>{}([&](auto pk) {
        ck::static_for<0, elems_per_pack, 1>{}([&](auto el) {
            ASSERT_EQ(right_vec.template AsType<bf6x16_pk_t>()(Number<pk>{})
                          .template unpack<>(Number<el>{}),
                      0);
        });
    });
    // pack each test set into its carrier
    ck::static_for<0, num_packs, 1>{}([&](auto pk) {
        right_vec.template AsType<bf6x16_pk_t>()(Number<pk>{}) = bf6x16_pk_t{}.pack(test_vec[pk]);
    });
    // copy-construct and confirm every unpacked value matches the source data
    vector_type<bf6x16_pk_t, num_packs> left_vec{right_vec};
    ck::static_for<0, num_packs, 1>{}([&](auto pk) {
        ck::static_for<0, elems_per_pack, 1>{}([&](auto el) {
            ASSERT_EQ(left_vec.template AsType<bf6x16_pk_t>()(Number<pk>{})
                          .template unpack<>(Number<el>{}),
                      static_cast<bf6_t>(test_vec[pk][static_cast<int>(el)]));
        });
    });
}
// A vector_type holding one bf6x32_pk_t carrier (32 bf6_t total):
// default construction zero-fills, pack() stores the full 32-element set,
// and copy construction preserves every packed element.
TEST(BF6, TestAsType32x1)
{
    // number of packed carriers in the vector / bf6_t values per carrier
    constexpr int num_packs      = 1;
    constexpr int elems_per_pack = 32;
    typedef int8_t test_vec_t __attribute__((ext_vector_type(32)));
    // 32 distinct raw 6-bit patterns (MSB toggled pairwise)
    test_vec_t test_vec = {bf6_t(0b000000), bf6_t(0b100000), bf6_t(0b000001), bf6_t(0b100001),
                           bf6_t(0b000010), bf6_t(0b100010), bf6_t(0b000011), bf6_t(0b100011),
                           bf6_t(0b000100), bf6_t(0b100100), bf6_t(0b000101), bf6_t(0b100101),
                           bf6_t(0b000110), bf6_t(0b100110), bf6_t(0b001011), bf6_t(0b101011),
                           bf6_t(0b010000), bf6_t(0b110000), bf6_t(0b010001), bf6_t(0b110001),
                           bf6_t(0b010010), bf6_t(0b110010), bf6_t(0b010011), bf6_t(0b110011),
                           bf6_t(0b010100), bf6_t(0b110100), bf6_t(0b010101), bf6_t(0b110101),
                           bf6_t(0b010110), bf6_t(0b110110), bf6_t(0b011011), bf6_t(0b111011)};
    // source vector; a default-constructed vector must read back all zeros
    vector_type<bf6x32_pk_t, num_packs> right_vec;
    ck::static_for<0, elems_per_pack, 1>{}([&](auto el) {
        ASSERT_EQ(
            right_vec.template AsType<bf6x32_pk_t>()(Number<0>{}).template unpack<>(Number<el>{}),
            0);
    });
    // pack the test patterns into the carrier
    ck::static_for<0, num_packs, 1>{}([&](auto pk) {
        right_vec.template AsType<bf6x32_pk_t>()(Number<pk>{}) = bf6x32_pk_t{}.pack(test_vec);
    });
    // copy-construct and confirm every unpacked value matches the source data
    vector_type<bf6x32_pk_t, num_packs> left_vec{right_vec};
    ck::static_for<0, elems_per_pack, 1>{}([&](auto el) {
        ASSERT_EQ(
            left_vec.template AsType<bf6x32_pk_t>()(Number<0>{}).template unpack<>(Number<el>{}),
            static_cast<bf6_t>(test_vec[static_cast<int>(el)]));
    });
}
...@@ -235,8 +235,10 @@ TEST(FP4, TestAsType1) ...@@ -235,8 +235,10 @@ TEST(FP4, TestAsType1)
vector_type<f4x2_pk_t, size> right_vec; vector_type<f4x2_pk_t, size> right_vec;
// check default CTOR // check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), 0); ASSERT_EQ(
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), 0); right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}), 0);
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}), 0);
}); });
// assign test values to the vector // assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
...@@ -247,9 +249,9 @@ TEST(FP4, TestAsType1) ...@@ -247,9 +249,9 @@ TEST(FP4, TestAsType1)
vector_type<f4x2_pk_t, size> left_vec{right_vec}; vector_type<f4x2_pk_t, size> left_vec{right_vec};
// check if values were copied correctly // check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}),
test_vec.at(i)); test_vec.at(i));
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}),
test_vec.at(i + 1)); test_vec.at(i + 1));
}); });
} }
...@@ -267,8 +269,10 @@ TEST(FP4, TestAsType2) ...@@ -267,8 +269,10 @@ TEST(FP4, TestAsType2)
vector_type<f4x2_pk_t, size> right_vec; vector_type<f4x2_pk_t, size> right_vec;
// check default CTOR // check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), 0); ASSERT_EQ(
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), 0); right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}), 0);
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}), 0);
}); });
// assign test values to the vector // assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
...@@ -279,9 +283,9 @@ TEST(FP4, TestAsType2) ...@@ -279,9 +283,9 @@ TEST(FP4, TestAsType2)
vector_type<f4x2_pk_t, size> left_vec{right_vec}; vector_type<f4x2_pk_t, size> left_vec{right_vec};
// check if values were copied correctly // check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}),
test_vec.at(i)); test_vec.at(i));
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}),
test_vec.at(i + 1)); test_vec.at(i + 1));
}); });
} }
...@@ -303,8 +307,10 @@ TEST(FP4, TestAsType4) ...@@ -303,8 +307,10 @@ TEST(FP4, TestAsType4)
vector_type<f4x2_pk_t, size> right_vec; vector_type<f4x2_pk_t, size> right_vec;
// check default CTOR // check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), 0); ASSERT_EQ(
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), 0); right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}), 0);
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}), 0);
}); });
// assign test values to the vector // assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
...@@ -315,9 +321,9 @@ TEST(FP4, TestAsType4) ...@@ -315,9 +321,9 @@ TEST(FP4, TestAsType4)
vector_type<f4x2_pk_t, size> left_vec{right_vec}; vector_type<f4x2_pk_t, size> left_vec{right_vec};
// check if values were copied correctly // check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}),
test_vec.at(i)); test_vec.at(i));
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}),
test_vec.at(i + 1)); test_vec.at(i + 1));
}); });
} }
...@@ -347,8 +353,10 @@ TEST(FP4, TestAsType8) ...@@ -347,8 +353,10 @@ TEST(FP4, TestAsType8)
vector_type<f4x2_pk_t, size> right_vec; vector_type<f4x2_pk_t, size> right_vec;
// check default CTOR // check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), 0); ASSERT_EQ(
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), 0); right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}), 0);
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}), 0);
}); });
// assign test values to the vector // assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
...@@ -359,9 +367,9 @@ TEST(FP4, TestAsType8) ...@@ -359,9 +367,9 @@ TEST(FP4, TestAsType8)
vector_type<f4x2_pk_t, size> left_vec{right_vec}; vector_type<f4x2_pk_t, size> left_vec{right_vec};
// check if values were copied correctly // check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}),
test_vec.at(i)); test_vec.at(i));
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}),
test_vec.at(i + 1)); test_vec.at(i + 1));
}); });
} }
...@@ -387,8 +395,10 @@ TEST(FP4, TestAsType16) ...@@ -387,8 +395,10 @@ TEST(FP4, TestAsType16)
vector_type<f4x2_pk_t, size> right_vec; vector_type<f4x2_pk_t, size> right_vec;
// check default CTOR // check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), 0); ASSERT_EQ(
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), 0); right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}), 0);
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}), 0);
}); });
// assign test values to the vector // assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
...@@ -399,9 +409,9 @@ TEST(FP4, TestAsType16) ...@@ -399,9 +409,9 @@ TEST(FP4, TestAsType16)
vector_type<f4x2_pk_t, size> left_vec{right_vec}; vector_type<f4x2_pk_t, size> left_vec{right_vec};
// check if values were copied correctly // check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}),
test_vec.at(i)); test_vec.at(i));
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}),
test_vec.at(i + 1)); test_vec.at(i + 1));
}); });
} }
...@@ -438,8 +448,10 @@ TEST(FP4, TestAsType32) ...@@ -438,8 +448,10 @@ TEST(FP4, TestAsType32)
vector_type<f4x2_pk_t, size> right_vec; vector_type<f4x2_pk_t, size> right_vec;
// check default CTOR // check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), 0); ASSERT_EQ(
ASSERT_EQ(right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), 0); right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}), 0);
ASSERT_EQ(
right_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}), 0);
}); });
// assign test values to the vector // assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
...@@ -450,9 +462,9 @@ TEST(FP4, TestAsType32) ...@@ -450,9 +462,9 @@ TEST(FP4, TestAsType32)
vector_type<f4x2_pk_t, size> left_vec{right_vec}; vector_type<f4x2_pk_t, size> left_vec{right_vec};
// check if values were copied correctly // check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) { ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<0>(), ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<0>{}),
test_vec.at(i)); test_vec.at(i));
ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<1>(), ASSERT_EQ(left_vec.template AsType<f4x2_pk_t>()(Number<i>{}).template unpack<>(Number<1>{}),
test_vec.at(i + 1)); test_vec.at(i + 1));
}); });
} }
...@@ -10,6 +10,8 @@ using ck::e8m0_bexp_t; ...@@ -10,6 +10,8 @@ using ck::e8m0_bexp_t;
using ck::f6_convert_rne; using ck::f6_convert_rne;
using ck::f6_convert_sr; using ck::f6_convert_sr;
using ck::f6_t; using ck::f6_t;
using ck::f6x16_pk_t;
using ck::f6x32_pk_t;
using ck::Number; using ck::Number;
using ck::scaled_type_convert; using ck::scaled_type_convert;
using ck::type_convert; using ck::type_convert;
...@@ -215,3 +217,169 @@ TEST(FP6, ScaledConvertFP32Stochastic) ...@@ -215,3 +217,169 @@ TEST(FP6, ScaledConvertFP32Stochastic)
scaled_type_convert<float>(e8m0_bexp_t(min_scale), f6_convert_sr(neg_float)), scaled_type_convert<float>(e8m0_bexp_t(min_scale), f6_convert_sr(neg_float)),
abs_tol); abs_tol);
} }
// Storage footprint of the f6 scalar, the packed carriers, and their
// vector_type wrappers: 16 x 6 bit = 12 bytes, 32 x 6 bit = 24 bytes;
// vector_type rounds the carriers up to 16/32 bytes.
TEST(FP6, TestSize)
{
    ASSERT_EQ(1, sizeof(f6_t));                        // scalar occupies one byte
    ASSERT_EQ(12, sizeof(f6x16_pk_t));                 // 16 elements, 6 bits each
    ASSERT_EQ(24, sizeof(f6x32_pk_t));                 // 32 elements, 6 bits each
    ASSERT_EQ(16, sizeof(vector_type<f6x16_pk_t, 1>)); // padded from 12 to 16 bytes
    ASSERT_EQ(32, sizeof(vector_type<f6x16_pk_t, 2>)); // 2 x 12 padded to 32 bytes
    ASSERT_EQ(32, sizeof(vector_type<f6x32_pk_t, 1>)); // padded from 24 to 32 bytes
}
// Alignment requirements of the f6 scalar, the packed carriers, and their
// vector_type wrappers (the wrappers align to their padded size; see TestSize).
TEST(FP6, TestAlignment)
{
    ASSERT_EQ(1, alignof(f6_t));                        // plain byte alignment
    ASSERT_EQ(4, alignof(f6x16_pk_t));                  // packed carriers are 4-byte aligned
    ASSERT_EQ(4, alignof(f6x32_pk_t));
    ASSERT_EQ(16, alignof(vector_type<f6x16_pk_t, 1>)); // wrapper alignment equals its size
    ASSERT_EQ(32, alignof(vector_type<f6x16_pk_t, 2>));
    ASSERT_EQ(32, alignof(vector_type<f6x32_pk_t, 1>));
}
// A vector_type holding one f6x16_pk_t carrier (16 f6_t total):
// default construction zero-fills, pack() stores the 16-element set,
// and copy construction preserves every packed element.
TEST(FP6, TestAsType16x1)
{
    // number of packed carriers in the vector / f6_t values per carrier
    constexpr int num_packs      = 1;
    constexpr int elems_per_pack = 16;
    typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
    // 16 distinct raw 6-bit patterns (MSB toggled pairwise)
    test_vec_t test_vec = {f6_t(0b000000), f6_t(0b100000), f6_t(0b000001), f6_t(0b100001),
                           f6_t(0b000010), f6_t(0b100010), f6_t(0b000011), f6_t(0b100011),
                           f6_t(0b000100), f6_t(0b100100), f6_t(0b000101), f6_t(0b100101),
                           f6_t(0b000110), f6_t(0b100110), f6_t(0b001011), f6_t(0b101011)};
    // source vector; a default-constructed vector must read back all zeros
    vector_type<f6x16_pk_t, num_packs> right_vec;
    ck::static_for<0, elems_per_pack, 1>{}([&](auto el) {
        ASSERT_EQ(
            right_vec.template AsType<f6x16_pk_t>()(Number<0>{}).template unpack<>(Number<el>{}),
            0);
    });
    // pack the test patterns into the carrier
    ck::static_for<0, num_packs, 1>{}([&](auto pk) {
        right_vec.template AsType<f6x16_pk_t>()(Number<pk>{}) = f6x16_pk_t{}.pack(test_vec);
    });
    // copy-construct and confirm every unpacked value matches the source data
    vector_type<f6x16_pk_t, num_packs> left_vec{right_vec};
    ck::static_for<0, elems_per_pack, 1>{}([&](auto el) {
        ASSERT_EQ(
            left_vec.template AsType<f6x16_pk_t>()(Number<0>{}).template unpack<>(Number<el>{}),
            static_cast<f6_t>(test_vec[static_cast<int>(el)]));
    });
}
// A vector_type holding two f6x16_pk_t carriers (32 f6_t total):
// default construction zero-fills, pack() stores each 16-element set,
// and copy construction preserves every packed element.
TEST(FP6, TestAsType16x2)
{
    // number of packed carriers in the vector / f6_t values per carrier
    constexpr int num_packs      = 2;
    constexpr int elems_per_pack = 16;
    typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
    // two independent sets of 16 raw 6-bit patterns (MSB toggled pairwise)
    test_vec_t test_vec[2];
    test_vec[0] = {f6_t(0b000000), f6_t(0b100000), f6_t(0b000001), f6_t(0b100001),
                   f6_t(0b000010), f6_t(0b100010), f6_t(0b000011), f6_t(0b100011),
                   f6_t(0b000100), f6_t(0b100100), f6_t(0b000101), f6_t(0b100101),
                   f6_t(0b000110), f6_t(0b100110), f6_t(0b001011), f6_t(0b101011)};
    test_vec[1] = {f6_t(0b010000), f6_t(0b110000), f6_t(0b010001), f6_t(0b110001),
                   f6_t(0b010010), f6_t(0b110010), f6_t(0b010011), f6_t(0b110011),
                   f6_t(0b010100), f6_t(0b110100), f6_t(0b010101), f6_t(0b110101),
                   f6_t(0b010110), f6_t(0b110110), f6_t(0b011011), f6_t(0b111011)};
    // source vector; a default-constructed vector must read back all zeros
    vector_type<f6x16_pk_t, num_packs> right_vec;
    ck::static_for<0, num_packs, 1>{}([&](auto pk) {
        ck::static_for<0, elems_per_pack, 1>{}([&](auto el) {
            ASSERT_EQ(right_vec.template AsType<f6x16_pk_t>()(Number<pk>{})
                          .template unpack<>(Number<el>{}),
                      0);
        });
    });
    // pack each test set into its carrier
    ck::static_for<0, num_packs, 1>{}([&](auto pk) {
        right_vec.template AsType<f6x16_pk_t>()(Number<pk>{}) = f6x16_pk_t{}.pack(test_vec[pk]);
    });
    // copy-construct and confirm every unpacked value matches the source data
    vector_type<f6x16_pk_t, num_packs> left_vec{right_vec};
    ck::static_for<0, num_packs, 1>{}([&](auto pk) {
        ck::static_for<0, elems_per_pack, 1>{}([&](auto el) {
            ASSERT_EQ(left_vec.template AsType<f6x16_pk_t>()(Number<pk>{})
                          .template unpack<>(Number<el>{}),
                      static_cast<f6_t>(test_vec[pk][static_cast<int>(el)]));
        });
    });
}
// A vector_type holding one f6x32_pk_t carrier (32 f6_t total):
// default construction zero-fills, pack() stores the full 32-element set,
// and copy construction preserves every packed element.
TEST(FP6, TestAsType32x1)
{
    // number of packed carriers in the vector / f6_t values per carrier
    constexpr int num_packs      = 1;
    constexpr int elems_per_pack = 32;
    typedef int8_t test_vec_t __attribute__((ext_vector_type(32)));
    // 32 distinct raw 6-bit patterns (MSB toggled pairwise)
    test_vec_t test_vec = {f6_t(0b000000), f6_t(0b100000), f6_t(0b000001), f6_t(0b100001),
                           f6_t(0b000010), f6_t(0b100010), f6_t(0b000011), f6_t(0b100011),
                           f6_t(0b000100), f6_t(0b100100), f6_t(0b000101), f6_t(0b100101),
                           f6_t(0b000110), f6_t(0b100110), f6_t(0b001011), f6_t(0b101011),
                           f6_t(0b010000), f6_t(0b110000), f6_t(0b010001), f6_t(0b110001),
                           f6_t(0b010010), f6_t(0b110010), f6_t(0b010011), f6_t(0b110011),
                           f6_t(0b010100), f6_t(0b110100), f6_t(0b010101), f6_t(0b110101),
                           f6_t(0b010110), f6_t(0b110110), f6_t(0b011011), f6_t(0b111011)};
    // source vector; a default-constructed vector must read back all zeros
    vector_type<f6x32_pk_t, num_packs> right_vec;
    ck::static_for<0, elems_per_pack, 1>{}([&](auto el) {
        ASSERT_EQ(
            right_vec.template AsType<f6x32_pk_t>()(Number<0>{}).template unpack<>(Number<el>{}),
            0);
    });
    // pack the test patterns into the carrier
    ck::static_for<0, num_packs, 1>{}([&](auto pk) {
        right_vec.template AsType<f6x32_pk_t>()(Number<pk>{}) = f6x32_pk_t{}.pack(test_vec);
    });
    // copy-construct and confirm every unpacked value matches the source data
    vector_type<f6x32_pk_t, num_packs> left_vec{right_vec};
    ck::static_for<0, elems_per_pack, 1>{}([&](auto el) {
        ASSERT_EQ(
            left_vec.template AsType<f6x32_pk_t>()(Number<0>{}).template unpack<>(Number<el>{}),
            static_cast<f6_t>(test_vec[static_cast<int>(el)]));
    });
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment