Unverified Commit bd09b5c5 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer Committed by GitHub
Browse files

Add fp8 @ bf8 gemm support and example (#933)

* Add f8 bf8 gemm example

* Add element-wise ops

* Add intrinsics

* Update reference calculation

* Add an additional type option for xdlops gemm

* Fix build process

* Add bf8 to buffer addressing

* Update blockwise op, split typeA and typeB

* Update for compatibility

* Uppdate naming to f8->fp8

* Update naming

* Format
parent 59dbb01f
...@@ -494,6 +494,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle ...@@ -494,6 +494,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1< auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
TileMathThreadGroupSize, TileMathThreadGroupSize,
ABDataType, ABDataType,
ABDataType,
FloatGemmAcc, FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
......
...@@ -737,6 +737,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -737,6 +737,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
auto blockwise_gemm = auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatABAdjusted,
FloatABAdjusted, FloatABAdjusted,
FloatAcc, FloatAcc,
decltype(a_k0_m_k1_block_desc), decltype(a_k0_m_k1_block_desc),
......
...@@ -490,6 +490,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk ...@@ -490,6 +490,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
auto blockwise_gemm = auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
decltype(a_block_desc_k0_m_k1), decltype(a_block_desc_k0_m_k1),
......
...@@ -424,6 +424,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -424,6 +424,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
using BlockwiseGemm = using BlockwiseGemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatABAdjusted,
FloatABAdjusted, FloatABAdjusted,
FloatAcc, FloatAcc,
decltype(a_block_desc_k0_m_k1), decltype(a_block_desc_k0_m_k1),
...@@ -569,6 +570,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -569,6 +570,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize, BlockSize,
FloatABAdjusted, FloatABAdjusted,
FloatABAdjusted,
FloatAcc, FloatAcc,
decltype(a_block_desc_k0_m_k1), decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1), decltype(b_block_desc_k0_n_k1),
......
...@@ -762,6 +762,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -762,6 +762,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize, BlockSize,
ComputeType, ComputeType,
ComputeType,
FloatAcc, FloatAcc,
decltype(a_k0_m_k1_block_desc), decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc), decltype(b_k0_n_k1_block_desc),
......
...@@ -451,6 +451,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 ...@@ -451,6 +451,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
auto blockwise_gemm = auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
......
...@@ -471,6 +471,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 ...@@ -471,6 +471,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
auto blockwise_gemm = auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
decltype(a_block_desc_k0_m_k1), decltype(a_block_desc_k0_m_k1),
......
...@@ -489,6 +489,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 ...@@ -489,6 +489,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
auto blockwise_gemm = auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
decltype(a_block_desc_k0_m_k1), decltype(a_block_desc_k0_m_k1),
......
...@@ -31,7 +31,9 @@ enum struct MfmaInstr ...@@ -31,7 +31,9 @@ enum struct MfmaInstr
mfma_i32_16x16x32i8, mfma_i32_16x16x32i8,
mfma_f64_16x16x4f64, mfma_f64_16x16x4f64,
mfma_f32_32x32x16f8f8, mfma_f32_32x32x16f8f8,
mfma_f32_16x16x32f8f8 mfma_f32_16x16x32f8f8,
mfma_f32_32x32x16f8bf8,
mfma_f32_16x16x32f8bf8
}; };
template <MfmaInstr instr> template <MfmaInstr instr>
...@@ -502,10 +504,62 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8> ...@@ -502,10 +504,62 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
}; };
#endif #endif
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops> #if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8bf8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_32x32x16f8bf8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8bf8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_16x16x32f8bf8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
#endif
template <typename base_type,
index_t MPerXdlops,
index_t NPerXdlops,
typename additional_type = base_type>
struct MfmaSelector struct MfmaSelector
{ {
template <typename base_type_, index_t MPerXdlops_, index_t NPerXdlops_> template <typename base_type_,
index_t MPerXdlops_,
index_t NPerXdlops_,
typename additional_type_ = base_type_>
static constexpr auto GetMfma(); static constexpr auto GetMfma();
template <> template <>
...@@ -656,7 +710,22 @@ struct MfmaSelector ...@@ -656,7 +710,22 @@ struct MfmaSelector
} }
#endif #endif
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{}; #if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <>
static constexpr auto GetMfma<f8_t, 32, 32, bf8_t>()
{
return MfmaInstr::mfma_f32_32x32x16f8bf8;
}
template <>
static constexpr auto GetMfma<f8_t, 16, 16, bf8_t>()
{
return MfmaInstr::mfma_f32_16x16x32f8bf8;
}
#endif
static constexpr auto selected_mfma =
mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type>()>{};
__host__ __device__ constexpr MfmaSelector() __host__ __device__ constexpr MfmaSelector()
{ {
...@@ -703,6 +772,7 @@ template <typename base_type, ...@@ -703,6 +772,7 @@ template <typename base_type,
index_t MPerXdlops, index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
index_t KPack, index_t KPack,
typename additional_type = base_type,
bool TransposeC = false> bool TransposeC = false>
struct XdlopsGemm struct XdlopsGemm
{ {
...@@ -854,14 +924,18 @@ struct XdlopsGemm ...@@ -854,14 +924,18 @@ struct XdlopsGemm
template <class FloatA, class FloatB, class FloatC> template <class FloatA, class FloatB, class FloatC>
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{ {
static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value || static_assert(
is_same<base_type, double>::value || is_same<base_type, float>::value ||
is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value || is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
is_same<base_type, int8_t>::value is_same<base_type, int8_t>::value
#if defined CK_ENABLE_FP8 #if defined CK_ENABLE_FP8
|| is_same<base_type, f8_t>::value || is_same<base_type, f8_t>::value
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|| (is_same<base_type, f8_t>::value && is_same<additional_type, bf8_t>::value)
#endif #endif
, ,
"base base_type must be double, float, half, bfloat16, and int8_t!"); "base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!");
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
if constexpr(!TransposeC) if constexpr(!TransposeC)
...@@ -957,7 +1031,7 @@ struct XdlopsGemm ...@@ -957,7 +1031,7 @@ struct XdlopsGemm
return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td}; return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
} }
static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops>{}; static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops, additional_type>{};
static constexpr auto mfma_instr = mfma.selected_mfma; static constexpr auto mfma_instr = mfma.selected_mfma;
......
...@@ -1127,8 +1127,16 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, ...@@ -1127,8 +1127,16 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000; uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;
#if defined CK_ENABLE_FP8 #if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, f8_t>::value || is_same<scalar_t, bf8_t>::value)
#endif
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, f8_t>::value) if constexpr(is_same<scalar_t, f8_t>::value)
#endif
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, bf8_t>::value)
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
{ {
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>( auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
...@@ -1139,12 +1147,20 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, ...@@ -1139,12 +1147,20 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#endif #endif
return amd_buffer_load_impl<scalar_t, vector_size, coherence>( return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#if defined CK_ENABLE_FP8 #if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
} }
#endif #endif
#else #else
#if defined CK_ENABLE_FP8 #if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, f8_t>::value || is_same<scalar_t, bf8_t>::value)
#endif
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, f8_t>::value) if constexpr(is_same<scalar_t, f8_t>::value)
#endif
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, bf8_t>::value)
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
{ {
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>( auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, 0);
...@@ -1156,7 +1172,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, ...@@ -1156,7 +1172,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>( vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, 0);
return src_thread_element_valid ? tmp : vector_t(0); return src_thread_element_valid ? tmp : vector_t(0);
#if defined CK_ENABLE_FP8 #if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
} }
#endif #endif
#endif #endif
...@@ -1216,29 +1232,50 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t ...@@ -1216,29 +1232,50 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
#if defined CK_ENABLE_FP8
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, f8_t>::value || is_same<scalar_t, bf8_t>::value)
#endif
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, f8_t>::value) if constexpr(is_same<scalar_t, f8_t>::value)
#endif
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, bf8_t>::value)
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
{ {
auto tmp = auto tmp = bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(
bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(src_thread_data); src_thread_data);
amd_buffer_store_impl<int8_t, vector_size, coherence>( amd_buffer_store_impl<int8_t, vector_size, coherence>(
tmp, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); tmp, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
} }
else else
{ {
#endif #endif
amd_buffer_store_impl<scalar_t, vector_size, coherence>( amd_buffer_store_impl<scalar_t, vector_size, coherence>(src_thread_data,
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); dst_wave_buffer_resource,
#if defined CK_ENABLE_FP8 dst_addr_shift +
dst_thread_addr_offset,
0);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
} }
#endif #endif
#else #else
if(dst_thread_element_valid) if(dst_thread_element_valid)
{ {
#if defined CK_ENABLE_FP8 #if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, f8_t>::value || is_same<scalar_t, bf8_t>::value)
#endif
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, f8_t>::value) if constexpr(is_same<scalar_t, f8_t>::value)
#endif
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if constexpr(is_same<scalar_t, bf8_t>::value)
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
{ {
auto tmp = bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>( auto tmp =
bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(
src_thread_data); src_thread_data);
amd_buffer_store_impl<int8_t, vector_size, coherence>( amd_buffer_store_impl<int8_t, vector_size, coherence>(
tmp, dst_wave_buffer_resource, dst_thread_addr_offset, 0); tmp, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
...@@ -1248,7 +1285,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t ...@@ -1248,7 +1285,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#endif #endif
amd_buffer_store_impl<scalar_t, vector_size, coherence>( amd_buffer_store_impl<scalar_t, vector_size, coherence>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
#if defined CK_ENABLE_FP8 #if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
} }
#endif #endif
} }
......
...@@ -419,5 +419,70 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16> ...@@ -419,5 +419,70 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
} }
}; };
#endif #endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8bf8;
template <>
struct intrin_mfma_f32_32x32x16f8bf8<32, 32>
{
template <class FloatC>
__device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_c.template AsType<float16_t>()[Number<0>{}],
0,
0,
0);
#else
vector_type<f8_t, 8> reg_a_v(reg_a);
vector_type<bf8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
});
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32f8bf8;
template <>
struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
{
template <class FloatC>
__device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_c.template AsType<float4_t>()[Number<0>{}],
0,
0,
0);
#else
vector_type<f8_t, 8> reg_a_v(reg_a);
vector_type<bf8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
});
#endif
}
};
#endif
} // namespace ck } // namespace ck
#endif #endif
...@@ -21,7 +21,8 @@ template <typename ADataType, ...@@ -21,7 +21,8 @@ template <typename ADataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename ComputType = ADataType> typename ComputeTypeA = ADataType,
typename ComputeTypeB = ComputeTypeA>
struct ReferenceGemm : public device::BaseOperator struct ReferenceGemm : public device::BaseOperator
{ {
// Argument // Argument
...@@ -65,8 +66,8 @@ struct ReferenceGemm : public device::BaseOperator ...@@ -65,8 +66,8 @@ struct ReferenceGemm : public device::BaseOperator
for(int k = 0; k < K; ++k) for(int k = 0; k < K; ++k)
{ {
ComputType v_a; ComputeTypeA v_a;
ComputType v_b; ComputeTypeB v_b;
// use PassThrough instead of ConvertBF16RTN for reference calculation // use PassThrough instead of ConvertBF16RTN for reference calculation
if constexpr(is_same_v<AElementwiseOperation, if constexpr(is_same_v<AElementwiseOperation,
......
...@@ -95,7 +95,7 @@ struct GeneratorTensor_2<int8_t> ...@@ -95,7 +95,7 @@ struct GeneratorTensor_2<int8_t>
} }
}; };
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8
template <> template <>
struct GeneratorTensor_2<ck::f8_t> struct GeneratorTensor_2<ck::f8_t>
{ {
...@@ -143,7 +143,7 @@ struct GeneratorTensor_3<ck::bhalf_t> ...@@ -143,7 +143,7 @@ struct GeneratorTensor_3<ck::bhalf_t>
} }
}; };
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 #if defined CK_ENABLE_FP8
template <> template <>
struct GeneratorTensor_3<ck::f8_t> struct GeneratorTensor_3<ck::f8_t>
{ {
......
add_instance_library(device_grouped_gemm_fixed_nk_instance set(GROUPED_GEMM_FIXED_NK_INSTANCES)
device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_kn_mn_instance.cpp if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_nk_mn_instance.cpp list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp)
list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp)
endif()
device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp)
) list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp)
endif()
if((DTYPES MATCHES "int8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp)
list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp)
endif()
add_instance_library(device_grouped_gemm_fixed_nk_instance ${GROUPED_GEMM_FIXED_NK_INSTANCES})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment