Unverified Commit 2bb3ad28 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer Committed by GitHub
Browse files

Merge branch 'develop' into lwpck-583

parents f9dfc99a a2d5ca8e
...@@ -680,6 +680,14 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -680,6 +680,14 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
<< KPerBlock << ", " << KPerBlock << ", "
<< AK1 << ", " << AK1 << ", "
<< BK1 << ", " << BK1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< getGemmSpecializationString(GemmSpec) << getGemmSpecializationString(GemmSpec)
<< ">" << ">"
<< " LoopScheduler: " << " LoopScheduler: "
......
...@@ -822,7 +822,15 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio ...@@ -822,7 +822,15 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
<< NPerBlock << ", " << NPerBlock << ", "
<< KPerBlock << ", " << KPerBlock << ", "
<< AK1 << ", " << AK1 << ", "
<< BK1 << BK1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -114,15 +114,15 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -114,15 +114,15 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static auto MakeBGridDescriptor_K0_N_K1(index_t KRaw, index_t NRaw, index_t StrideB) static auto MakeBGridDescriptor_K0_N_K1(index_t KRaw, index_t NRaw, index_t StrideB)
{ {
const auto b_grid_desc_nraw_kraw = [&]() { const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>) if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{ {
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1)); make_tuple(I1, StrideB));
} }
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>) else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{ {
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB)); make_tuple(StrideB, I1));
} }
}(); }();
......
...@@ -77,8 +77,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -77,8 +77,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
{ {
assert(K % K1 == 0);
const index_t K0 = K / K1; const index_t K0 = K / K1;
const auto a_grid_desc_m_k = [&]() { const auto a_grid_desc_m_k = [&]() {
...@@ -116,8 +114,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -116,8 +114,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
{ {
assert(K % K1 == 0);
const index_t K0 = K / K1; const index_t K0 = K / K1;
const auto b_grid_desc_k_n = [&]() { const auto b_grid_desc_k_n = [&]() {
...@@ -551,7 +547,11 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -551,7 +547,11 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
<< MPerXDL << ", " << MPerXDL << ", "
<< NPerXDL << ", " << NPerXDL << ", "
<< MXdlPerWave << ", " << MXdlPerWave << ", "
<< NXdlPerWave << NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1
<< ">" << ">"
<< " NumPrefetch: " << " NumPrefetch: "
<< NumPrefetch << ", " << NumPrefetch << ", "
......
...@@ -682,7 +682,15 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -682,7 +682,15 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
<< NPerBlock << ", " << NPerBlock << ", "
<< KPerBlock << ", " << KPerBlock << ", "
<< AK1 << ", " << AK1 << ", "
<< BK1 << BK1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">" << ">"
<< " LoopScheduler: " << " LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", " << LoopSchedToString[LoopSched] << ", "
......
...@@ -760,7 +760,15 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator ...@@ -760,7 +760,15 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
<< NPerBlock << ", " << NPerBlock << ", "
<< KPerBlock << ", " << KPerBlock << ", "
<< AK1 << ", " << AK1 << ", "
<< BK1 << BK1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -640,7 +640,16 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -640,7 +640,16 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << K0PerBlock << ", "
<< K1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -1003,7 +1003,15 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ...@@ -1003,7 +1003,15 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<< KPerBlock << ", " << KPerBlock << ", "
<< AK1 << ", " << AK1 << ", "
<< BK1 << ", " << BK1 << ", "
<< getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization) << getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization) << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">"; << ">";
return str.str(); return str.str();
......
...@@ -1203,7 +1203,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1203,7 +1203,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << ", " << K0PerBlock << ", "
<< getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", "
<< K1
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -1231,7 +1231,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle ...@@ -1231,7 +1231,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << ", " << K0PerBlock << ", "
<< getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", "
<< K1 << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1 << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< CBlockTransferScalarPerVector_NWaveNPerXdl
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -1092,7 +1092,15 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle ...@@ -1092,7 +1092,15 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< KPerBlock << ", " << KPerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization) << getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -838,7 +838,19 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -838,7 +838,19 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< KPerBlock << ", " << KPerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization) << getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
<< K1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1 << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< CBlockTransferScalarPerVector_NWaveNPerXdl
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -939,7 +939,15 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle ...@@ -939,7 +939,15 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< KPerBlock << ", " << KPerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization) << getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -688,6 +688,10 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -688,6 +688,10 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
<< NPerXDL << ", " << NPerXDL << ", "
<< MXdlPerWave << ", " << MXdlPerWave << ", "
<< NXdlPerWave << ", " << NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< getGemmSpecializationString(GemmSpec) << getGemmSpecializationString(GemmSpec)
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -238,19 +238,19 @@ struct GridwiseGemmDlMultipleD_km_kn_mn ...@@ -238,19 +238,19 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
using CGridDesc_M0_M10_M11_N0_N10_N11 = using CGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})); decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
using DsGridPointer = decltype(MakeDsGridPointer()); using DsGridPointer = decltype(MakeDsGridPointer());
template <typename DsGridDesc_M0_M10_M11_N0_N10_N11, template <typename DsGridDesc_M0_M10_M11_N0_N10_N11,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop> bool HasDoubleTailKBlockLoop,
typename Block2CTileMap>
__device__ static void __device__ static void
Run(const FloatAB* __restrict__ p_a_grid, Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
DsGridPointer p_ds_grid, DsGridPointer p_ds_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block, void* __restrict__ p_shared_block,
const AElementwiseOperation&, const AElementwiseOperation&,
const BElementwiseOperation&, const BElementwiseOperation&,
const CDEElementwiseOperation& cde_element_op, const CDEElementwiseOperation& cde_element_op,
...@@ -399,8 +399,9 @@ struct GridwiseGemmDlMultipleD_km_kn_mn ...@@ -399,8 +399,9 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
constexpr auto b_block_aligned_space_size = math::integer_least_multiple( constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align); b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block; FloatAB* p_a_block_double = static_cast<FloatAB*>(p_shared_block);
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; FloatAB* p_b_block_double =
static_cast<FloatAB*>(p_shared_block) + 2 * a_block_aligned_space_size;
// register allocation for output // register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>( auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
......
...@@ -676,7 +676,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -676,7 +676,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
auto blockwise_gemm = auto blockwise_gemm =
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize, BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<BlockSize,
ADataType, ADataType,
BDataType, BDataType,
AccDataType, AccDataType,
...@@ -719,7 +719,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -719,7 +719,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
c_thread_buf, c_thread_buf,
K0BlockMainLoop); K0BlockMainLoop);
/*******************************************************************************/ /*******************************************************************************/
//printf("safe 1");
// write out to C, implement shuffle // write out to C, implement shuffle
{ {
constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
......
...@@ -358,7 +358,13 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a, ...@@ -358,7 +358,13 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
// Ranged input operand // Ranged input operand
__device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c) __device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c)
{ {
#if defined(__gfx11__)
asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(c) : "v"(a), "v"(b), "0"(c)); asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(c) : "v"(a), "v"(b), "0"(c));
#else
ignore = a;
ignore = b;
ignore = c;
#endif
} }
} // namespace ck } // namespace ck
......
...@@ -23,11 +23,16 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16> ...@@ -23,11 +23,16 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
{ {
// * Inline assembly need to elimate the duplicated data load, compiler won't help you // * Inline assembly need to elimate the duplicated data load, compiler won't help you
// delete them. // delete them.
amd_assembly_wmma_f32_16x16x16_f16_w32( // amd_assembly_wmma_f32_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{})); // reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
// reg_c.template AsType<float8_t>()(Number<0>{}) = #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
// __builtin_amdgcn_wmma_f32_16x16x16_f16_w32( reg_a, reg_b, reg_c.template reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
// AsType<float8_t>()[Number<0>{}]); reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
} }
}; };
...@@ -41,9 +46,15 @@ struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16> ...@@ -41,9 +46,15 @@ struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16>
template <class FloatC> template <class FloatC>
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c) __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
{ {
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<float8_t>()(Number<0>{}) = reg_c.template AsType<float8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32( __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]); reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
} }
}; };
...@@ -60,8 +71,14 @@ struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel> ...@@ -60,8 +71,14 @@ struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel>
// opsel usage // opsel usage
// false: D0.[0:15] = result // false: D0.[0:15] = result
// true : D0.[16:31]= result // true : D0.[16:31]= result
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<half16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32( reg_c.template AsType<half16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<half16_t>()[Number<0>{}], Opsel); reg_a, reg_b, reg_c.template AsType<half16_t>()[Number<0>{}], Opsel);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
} }
}; };
...@@ -78,9 +95,15 @@ struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel> ...@@ -78,9 +95,15 @@ struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel>
// opsel usage // opsel usage
// false: D0.[0:15] = result // false: D0.[0:15] = result
// true : D0.[16:31]= result // true : D0.[16:31]= result
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<bhalf16_t>()(Number<0>{}) = reg_c.template AsType<bhalf16_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32( __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel); reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
} }
}; };
...@@ -94,6 +117,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp> ...@@ -94,6 +117,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
template <class FloatC> template <class FloatC>
__device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
{ {
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<int32x8_t>()(Number<0>{}) = reg_c.template AsType<int32x8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
neg_a, neg_a,
...@@ -102,6 +126,11 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp> ...@@ -102,6 +126,11 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
bit_cast<int32x4_t>(reg_b), bit_cast<int32x4_t>(reg_b),
reg_c.template AsType<int32x8_t>()[Number<0>{}], reg_c.template AsType<int32x8_t>()[Number<0>{}],
clamp); clamp);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
} }
}; };
...@@ -116,8 +145,14 @@ struct intrin_wmma_f32_16x16x16_f16_w64<16, 16> ...@@ -116,8 +145,14 @@ struct intrin_wmma_f32_16x16x16_f16_w64<16, 16>
template <class FloatC> template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c) __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{ {
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64( reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]); reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
} }
}; };
...@@ -131,9 +166,15 @@ struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16> ...@@ -131,9 +166,15 @@ struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16>
template <class FloatC> template <class FloatC>
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c) __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
{ {
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<float4_t>()(Number<0>{}) = reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64( __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]); reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
} }
}; };
...@@ -150,8 +191,14 @@ struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel> ...@@ -150,8 +191,14 @@ struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel>
// opsel usage // opsel usage
// false: D0.[0:15] = result // false: D0.[0:15] = result
// true : D0.[16:31]= result // true : D0.[16:31]= result
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<half8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64( reg_c.template AsType<half8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64(
reg_a, reg_b, reg_c.template AsType<half8_t>()[Number<0>{}], Opsel); reg_a, reg_b, reg_c.template AsType<half8_t>()[Number<0>{}], Opsel);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
} }
}; };
...@@ -168,9 +215,15 @@ struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel> ...@@ -168,9 +215,15 @@ struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel>
// opsel usage // opsel usage
// false: D0.[0:15] = result // false: D0.[0:15] = result
// true : D0.[16:31]= result // true : D0.[16:31]= result
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<bhalf8_t>()(Number<0>{}) = reg_c.template AsType<bhalf8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64( __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64(
reg_a, reg_b, reg_c.template AsType<bhalf8_t>()[Number<0>{}], Opsel); reg_a, reg_b, reg_c.template AsType<bhalf8_t>()[Number<0>{}], Opsel);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
} }
}; };
...@@ -184,6 +237,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp> ...@@ -184,6 +237,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
template <class FloatC> template <class FloatC>
__device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
{ {
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
reg_c.template AsType<int32x4_t>()(Number<0>{}) = reg_c.template AsType<int32x4_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64( __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64(
neg_a, neg_a,
...@@ -192,6 +246,11 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp> ...@@ -192,6 +246,11 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
bit_cast<int32x4_t>(reg_b), bit_cast<int32x4_t>(reg_b),
reg_c.template AsType<int32x4_t>()[Number<0>{}], reg_c.template AsType<int32x4_t>()[Number<0>{}],
clamp); clamp);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment