Unverified Commit ddbeddb4 authored by zjing14's avatar zjing14 Committed by GitHub
Browse files

Merge branch 'develop' into lwpck-586

parents 91e611e8 c10a6e82
...@@ -114,15 +114,15 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -114,15 +114,15 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static auto MakeBGridDescriptor_K0_N_K1(index_t KRaw, index_t NRaw, index_t StrideB) static auto MakeBGridDescriptor_K0_N_K1(index_t KRaw, index_t NRaw, index_t StrideB)
{ {
const auto b_grid_desc_nraw_kraw = [&]() { const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>) if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{ {
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1)); make_tuple(I1, StrideB));
} }
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>) else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{ {
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB)); make_tuple(StrideB, I1));
} }
}(); }();
......
...@@ -77,8 +77,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -77,8 +77,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
{ {
assert(K % K1 == 0);
const index_t K0 = K / K1; const index_t K0 = K / K1;
const auto a_grid_desc_m_k = [&]() { const auto a_grid_desc_m_k = [&]() {
...@@ -116,8 +114,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -116,8 +114,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
{ {
assert(K % K1 == 0);
const index_t K0 = K / K1; const index_t K0 = K / K1;
const auto b_grid_desc_k_n = [&]() { const auto b_grid_desc_k_n = [&]() {
...@@ -551,7 +547,11 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -551,7 +547,11 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
<< MPerXDL << ", " << MPerXDL << ", "
<< NPerXDL << ", " << NPerXDL << ", "
<< MXdlPerWave << ", " << MXdlPerWave << ", "
<< NXdlPerWave << NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1
<< ">" << ">"
<< " NumPrefetch: " << " NumPrefetch: "
<< NumPrefetch << ", " << NumPrefetch << ", "
......
...@@ -682,7 +682,15 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -682,7 +682,15 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
<< NPerBlock << ", " << NPerBlock << ", "
<< KPerBlock << ", " << KPerBlock << ", "
<< AK1 << ", " << AK1 << ", "
<< BK1 << BK1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">" << ">"
<< " LoopScheduler: " << " LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", " << LoopSchedToString[LoopSched] << ", "
......
...@@ -760,7 +760,15 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator ...@@ -760,7 +760,15 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
<< NPerBlock << ", " << NPerBlock << ", "
<< KPerBlock << ", " << KPerBlock << ", "
<< AK1 << ", " << AK1 << ", "
<< BK1 << BK1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -640,7 +640,16 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -640,7 +640,16 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << K0PerBlock << ", "
<< K1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -1003,7 +1003,15 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -1003,7 +1003,15 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<< KPerBlock << ", " << KPerBlock << ", "
<< AK1 << ", " << AK1 << ", "
<< BK1 << ", " << BK1 << ", "
<< getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization) << getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization) << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">"; << ">";
return str.str(); return str.str();
......
...@@ -1203,7 +1203,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1203,7 +1203,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << ", " << K0PerBlock << ", "
<< getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", "
<< K1
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -1231,7 +1231,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle ...@@ -1231,7 +1231,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << ", " << K0PerBlock << ", "
<< getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", "
<< K1 << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1 << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< CBlockTransferScalarPerVector_NWaveNPerXdl
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -1092,7 +1092,15 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle ...@@ -1092,7 +1092,15 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< KPerBlock << ", " << KPerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization) << getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -838,7 +838,19 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -838,7 +838,19 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< KPerBlock << ", " << KPerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization) << getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
<< K1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1 << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< CBlockTransferScalarPerVector_NWaveNPerXdl
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -939,7 +939,15 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle ...@@ -939,7 +939,15 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< KPerBlock << ", " << KPerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization) << getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -688,6 +688,10 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -688,6 +688,10 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
<< NPerXDL << ", " << NPerXDL << ", "
<< MXdlPerWave << ", " << MXdlPerWave << ", "
<< NXdlPerWave << ", " << NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< getGemmSpecializationString(GemmSpec) << getGemmSpecializationString(GemmSpec)
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -676,7 +676,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -676,7 +676,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
auto blockwise_gemm = auto blockwise_gemm =
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize, BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<BlockSize,
ADataType, ADataType,
BDataType, BDataType,
AccDataType, AccDataType,
...@@ -719,7 +719,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -719,7 +719,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
c_thread_buf, c_thread_buf,
K0BlockMainLoop); K0BlockMainLoop);
/*******************************************************************************/ /*******************************************************************************/
//printf("safe 1");
// write out to C, implement shuffle // write out to C, implement shuffle
{ {
constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
......
...@@ -1030,7 +1030,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, ...@@ -1030,7 +1030,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
constexpr index_t vector_size = scalar_type<vector_t>::vector_size; constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x7fffffff; uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;
return amd_buffer_load_impl<scalar_t, vector_size>( return amd_buffer_load_impl<scalar_t, vector_size>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
...@@ -1091,7 +1091,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t ...@@ -1091,7 +1091,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
constexpr index_t vector_size = scalar_type<vector_t>::vector_size; constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff; uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
amd_buffer_store_impl<scalar_t, vector_size>( amd_buffer_store_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
...@@ -1126,7 +1126,7 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr ...@@ -1126,7 +1126,7 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr
constexpr index_t vector_size = scalar_type<vector_t>::vector_size; constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff; uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
amd_buffer_atomic_add_impl<scalar_t, vector_size>( amd_buffer_atomic_add_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
...@@ -1161,7 +1161,7 @@ amd_buffer_atomic_max(const typename vector_type_maker<T, N>::type::type src_thr ...@@ -1161,7 +1161,7 @@ amd_buffer_atomic_max(const typename vector_type_maker<T, N>::type::type src_thr
constexpr index_t vector_size = scalar_type<vector_t>::vector_size; constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff; uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
amd_buffer_atomic_max_impl<scalar_t, vector_size>( amd_buffer_atomic_max_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
......
...@@ -220,8 +220,8 @@ amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0 ...@@ -220,8 +220,8 @@ amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0
"0"(c0), "0"(c0),
"1"(c1)); "1"(c1));
#else #else
c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false); c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false); c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
#endif #endif
} }
...@@ -257,10 +257,10 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a, ...@@ -257,10 +257,10 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a,
"2"(c2), "2"(c2),
"3"(c3)); "3"(c3));
#else #else
c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false); c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false); c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
c2 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b2), c2, false); c2 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b2), c2, false);
c3 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b3), c3, false); c3 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b3), c3, false);
#endif #endif
} }
...@@ -358,7 +358,13 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a, ...@@ -358,7 +358,13 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
// Ranged input operand // Ranged input operand
__device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c) __device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c)
{ {
#if defined(__gfx11__)
asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(c) : "v"(a), "v"(b), "0"(c)); asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(c) : "v"(a), "v"(b), "0"(c));
#else
ignore = a;
ignore = b;
ignore = c;
#endif
} }
} // namespace ck } // namespace ck
......
...@@ -23,11 +23,16 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16> ...@@ -23,11 +23,16 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
{ {
// * Inline assembly need to elimate the duplicated data load, compiler won't help you // * Inline assembly need to elimate the duplicated data load, compiler won't help you
// delete them. // delete them.
amd_assembly_wmma_f32_16x16x16_f16_w32( // amd_assembly_wmma_f32_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{})); // reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
// reg_c.template AsType<float8_t>()(Number<0>{}) = #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
// __builtin_amdgcn_wmma_f32_16x16x16_f16_w32( reg_a, reg_b, reg_c.template reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
// AsType<float8_t>()[Number<0>{}]); reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
} }
}; };
...@@ -41,9 +46,15 @@ struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16> ...@@ -41,9 +46,15 @@ struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16>
template <class FloatC> template <class FloatC>
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c) __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
{ {
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<float8_t>()(Number<0>{}) = reg_c.template AsType<float8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32( __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]); reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
} }
}; };
...@@ -60,8 +71,14 @@ struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel> ...@@ -60,8 +71,14 @@ struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel>
// opsel usage // opsel usage
// false: D0.[0:15] = result // false: D0.[0:15] = result
// true : D0.[16:31]= result // true : D0.[16:31]= result
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<half16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32( reg_c.template AsType<half16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<half16_t>()[Number<0>{}], Opsel); reg_a, reg_b, reg_c.template AsType<half16_t>()[Number<0>{}], Opsel);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
} }
}; };
...@@ -78,9 +95,15 @@ struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel> ...@@ -78,9 +95,15 @@ struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel>
// opsel usage // opsel usage
// false: D0.[0:15] = result // false: D0.[0:15] = result
// true : D0.[16:31]= result // true : D0.[16:31]= result
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<bhalf16_t>()(Number<0>{}) = reg_c.template AsType<bhalf16_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32( __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel); reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
} }
}; };
...@@ -94,6 +117,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp> ...@@ -94,6 +117,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
template <class FloatC> template <class FloatC>
__device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
{ {
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<int32x8_t>()(Number<0>{}) = reg_c.template AsType<int32x8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
neg_a, neg_a,
...@@ -102,6 +126,11 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp> ...@@ -102,6 +126,11 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
bit_cast<int32x4_t>(reg_b), bit_cast<int32x4_t>(reg_b),
reg_c.template AsType<int32x8_t>()[Number<0>{}], reg_c.template AsType<int32x8_t>()[Number<0>{}],
clamp); clamp);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
} }
}; };
...@@ -116,8 +145,14 @@ struct intrin_wmma_f32_16x16x16_f16_w64<16, 16> ...@@ -116,8 +145,14 @@ struct intrin_wmma_f32_16x16x16_f16_w64<16, 16>
template <class FloatC> template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c) __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{ {
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64( reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]); reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
} }
}; };
...@@ -131,9 +166,15 @@ struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16> ...@@ -131,9 +166,15 @@ struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16>
template <class FloatC> template <class FloatC>
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c) __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
{ {
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<float4_t>()(Number<0>{}) = reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64( __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]); reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
} }
}; };
...@@ -150,8 +191,14 @@ struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel> ...@@ -150,8 +191,14 @@ struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel>
// opsel usage // opsel usage
// false: D0.[0:15] = result // false: D0.[0:15] = result
// true : D0.[16:31]= result // true : D0.[16:31]= result
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<half8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64( reg_c.template AsType<half8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64(
reg_a, reg_b, reg_c.template AsType<half8_t>()[Number<0>{}], Opsel); reg_a, reg_b, reg_c.template AsType<half8_t>()[Number<0>{}], Opsel);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
} }
}; };
...@@ -168,9 +215,15 @@ struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel> ...@@ -168,9 +215,15 @@ struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel>
// opsel usage // opsel usage
// false: D0.[0:15] = result // false: D0.[0:15] = result
// true : D0.[16:31]= result // true : D0.[16:31]= result
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<bhalf8_t>()(Number<0>{}) = reg_c.template AsType<bhalf8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64( __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64(
reg_a, reg_b, reg_c.template AsType<bhalf8_t>()[Number<0>{}], Opsel); reg_a, reg_b, reg_c.template AsType<bhalf8_t>()[Number<0>{}], Opsel);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
} }
}; };
...@@ -184,6 +237,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp> ...@@ -184,6 +237,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
template <class FloatC> template <class FloatC>
__device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
{ {
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<int32x4_t>()(Number<0>{}) = reg_c.template AsType<int32x4_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64( __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64(
neg_a, neg_a,
...@@ -192,6 +246,11 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp> ...@@ -192,6 +246,11 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
bit_cast<int32x4_t>(reg_b), bit_cast<int32x4_t>(reg_b),
reg_c.template AsType<int32x4_t>()[Number<0>{}], reg_c.template AsType<int32x4_t>()[Number<0>{}],
clamp); clamp);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment