Unverified Commit 47b3e10b authored by zjing14's avatar zjing14 Committed by GitHub
Browse files

Merge branch 'develop' into conv_dlops/quantization

parents c6683eea c10a6e82
...@@ -948,7 +948,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -948,7 +948,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << ", " << K0PerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization) << getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
<< K1
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -824,7 +824,19 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS ...@@ -824,7 +824,19 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << ", " << K0PerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization) << getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
<< K1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1 << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< CBlockTransferScalarPerVector_NWaveNPerXdl
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -864,7 +864,16 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -864,7 +864,16 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
<< "B0Spec" << getTensorSpecializationString(BSpec) << ", " << "B0Spec" << getTensorSpecializationString(BSpec) << ", "
<< "B1Spec" << getTensorSpecializationString(B1Spec) << ", " << "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
<< "CSpec" << getTensorSpecializationString(CSpec) << ", " << "CSpec" << getTensorSpecializationString(CSpec) << ", "
<< getMaskingSpecializationString(MaskingSpec) << ">"; << getMaskingSpecializationString(MaskingSpec) << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">";
// clang-format on // clang-format on
return str.str(); return str.str();
......
...@@ -788,6 +788,20 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -788,6 +788,20 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
return true; return true;
} }
// check if DsLayout is supported
template <typename RefLayout, typename DsLayout, const index_t NumDTensor>
static bool CheckDLayout()
{
static bool valid = true;
// iterate over DLayout tuple
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
// if RefLayout and DLayout are same, keep valid true, otherwise false
valid = valid && is_same_v<RefLayout, DLayout>;
});
return valid;
}
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
...@@ -795,6 +809,26 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -795,6 +809,26 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
return false; return false;
} }
// Check supported layouts
// A0 - Row
// B0 - Col
// D0s - Rows
// B1 - Row or Col
// D1s - Rows
// E1 - Row
if(!(is_same_v<tensor_layout::gemm::RowMajor, A0Layout> &&
is_same_v<tensor_layout::gemm::ColumnMajor, B0Layout> &&
CheckDLayout<tensor_layout::gemm::RowMajor, D0sLayout, NumD0Tensor>() &&
(is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ||
is_same_v<tensor_layout::gemm::ColumnMajor,
B1Layout>)&&CheckDLayout<tensor_layout::gemm::RowMajor,
D1sLayout,
NumD1Tensor>() &&
is_same_v<tensor_layout::gemm::RowMajor, E1Layout>))
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a0_grid_desc_m_k_, return GridwiseGemm::CheckValidity(arg.a0_grid_desc_m_k_,
arg.b0_grid_desc_n_k_, arg.b0_grid_desc_n_k_,
arg.b1_grid_desc_n_k_, arg.b1_grid_desc_n_k_,
......
...@@ -777,7 +777,19 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -777,7 +777,19 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << K0PerBlock << ", "
<< K1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1 << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< CBlockTransferScalarPerVector_NWaveNPerXdl
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -822,7 +822,14 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -822,7 +822,14 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << K0PerBlock << ", "
<< K1 << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -956,7 +956,19 @@ struct ...@@ -956,7 +956,19 @@ struct
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << K0PerBlock << ", "
<< K1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1 << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< CBlockTransferScalarPerVector_NWaveNPerXdl
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -913,7 +913,19 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -913,7 +913,19 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << K0PerBlock << ", "
<< K1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1 << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< CBlockTransferScalarPerVector_NWaveNPerXdl
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -880,7 +880,17 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -880,7 +880,17 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << ", " << K0PerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization) << getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
<< K1 << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1 << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< CBlockTransferScalarPerVector_NWaveNPerXdl
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -720,7 +720,16 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -720,7 +720,16 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << ", " << K0PerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization) << getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
<< K1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -630,7 +630,19 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ ...@@ -630,7 +630,19 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << K0PerBlock << ", "
<< K1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1 << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< CBlockTransferScalarPerVector_NWaveNPerXdl
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -1567,7 +1567,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl ...@@ -1567,7 +1567,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << K0PerBlock << ", "
<< K1
<< ">"; << ">";
if constexpr(ConvBackwardDataSpecialization == if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0){ ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0){
......
...@@ -1552,7 +1552,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1552,7 +1552,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << K0PerBlock << ", "
<< K1 << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1
<< ">"; << ">";
if constexpr(ConvBackwardDataSpecialization == if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0){ ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0){
......
...@@ -862,7 +862,15 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO ...@@ -862,7 +862,15 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO
<< NPerBlock << ", " << NPerBlock << ", "
<< KPerBlock << ", " << KPerBlock << ", "
<< AK1 << ", " << AK1 << ", "
<< BK1 << BK1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -559,7 +559,19 @@ struct DeviceGemmBiasEPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp ...@@ -559,7 +559,19 @@ struct DeviceGemmBiasEPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
<< NPerBlock << ", " << NPerBlock << ", "
<< KPerBlock << ", " << KPerBlock << ", "
<< AK1 << ", " << AK1 << ", "
<< BK1 << BK1 << ", "
<< K1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1 << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< CBlockTransferScalarPerVector_NWaveNPerXdl
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -669,7 +669,15 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle ...@@ -669,7 +669,15 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
<< KPerBlock << ", " << KPerBlock << ", "
<< AK1 << ", " << AK1 << ", "
<< BK1 << ", " << BK1 << ", "
<< getGemmSpecializationString(GemmSpec) << getGemmSpecializationString(GemmSpec) << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -680,6 +680,14 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -680,6 +680,14 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
<< KPerBlock << ", " << KPerBlock << ", "
<< AK1 << ", " << AK1 << ", "
<< BK1 << ", " << BK1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< getGemmSpecializationString(GemmSpec) << getGemmSpecializationString(GemmSpec)
<< ">" << ">"
<< " LoopScheduler: " << " LoopScheduler: "
......
...@@ -822,7 +822,15 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio ...@@ -822,7 +822,15 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
<< NPerBlock << ", " << NPerBlock << ", "
<< KPerBlock << ", " << KPerBlock << ", "
<< AK1 << ", " << AK1 << ", "
<< BK1 << BK1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -547,7 +547,11 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -547,7 +547,11 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
<< MPerXDL << ", " << MPerXDL << ", "
<< NPerXDL << ", " << NPerXDL << ", "
<< MXdlPerWave << ", " << MXdlPerWave << ", "
<< NXdlPerWave << NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1
<< ">" << ">"
<< " NumPrefetch: " << " NumPrefetch: "
<< NumPrefetch << ", " << NumPrefetch << ", "
......
...@@ -682,7 +682,15 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -682,7 +682,15 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
<< NPerBlock << ", " << NPerBlock << ", "
<< KPerBlock << ", " << KPerBlock << ", "
<< AK1 << ", " << AK1 << ", "
<< BK1 << BK1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">" << ">"
<< " LoopScheduler: " << " LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", " << LoopSchedToString[LoopSched] << ", "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment