Commit bfa4c686 authored by carlushuang

refactor: remove the gemm_k_spec template parameter in favor of a dynamic tunable

parent 9cefc261
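
In short: gemm_k_spec stops being a compile-time template parameter on the device ops and threadwise copies and becomes a runtime field of DeviceConvFwdDynamicTunable, so the `if constexpr` dispatch in the diff below turns into a plain runtime `if`. A minimal sketch of the pattern, using simplified hypothetical stand-ins for the CK types (not the actual declarations):

```cpp
#include <cstdio>

// Hypothetical, simplified stand-ins for the CK types (illustration only).
enum class ConvolutionForwardGemmKSpecialization_t
{
    DefaultGemmKLoop,
    NHWC_GemmKLoopOverC
};

struct DeviceConvFwdDynamicTunable_Sketch
{
    int m_per_block, n_per_block, k_per_block;
    ConvolutionForwardGemmKSpecialization_t gemm_k_spec; // was a template parameter
};

// Before: if constexpr(GemmKSpecialization == ...) with GemmKSpecialization a
// template parameter. After: the same check against the runtime tunable field.
bool is_supported_sketch(const DeviceConvFwdDynamicTunable_Sketch& t, int conv_c)
{
    if(t.gemm_k_spec == ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
        return conv_c % t.k_per_block == 0; // mirrors the arg.Conv_C_ % k_per_block check
    return true;
}

int main()
{
    DeviceConvFwdDynamicTunable_Sketch t{
        256, 128, 64, ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC};
    std::printf("supported: %d\n", is_supported_sketch(t, 192)); // 1 (192 % 64 == 0)
}
```

The trade-off appears to be one fewer template axis to instantiate (fewer compiled kernel variants) in exchange for a runtime branch inside the copy path.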
@@ -92,7 +92,7 @@ struct DeviceConvFwdDynamicTunable
     // bool use_c_local_buffer;
     // ConvolutionForwardSpecialization_t forward_spec;
-    // ConvolutionForwardGemmKSpecialization_t gemm_k_spec;
+    ConvolutionForwardGemmKSpecialization_t gemm_k_spec;
     ConvolutionForwardBlockLoopOverSpecialization_t loop_over_spec;
 };
...
@@ -29,7 +29,6 @@ template <typename InDataType,
           typename WeiElementwiseOperation,
           typename OutElementwiseOperation,
           ConvolutionForwardSpecialization_t ConvForwardSpecialization,
-          ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
           ck::index_t NumDimSpatial,
           ck::index_t MPerThread,
           ck::index_t NPerThread,
@@ -580,8 +579,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
             decltype(GetInputBlockDescriptor()),
             InElementwiseOperation,
             !UseALocalBuffer,
-            ConvForwardSpecialization,
-            GemmKSpecialization>;
+            ConvForwardSpecialization>;
 
     using BThreadwiseCopy =
         ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXC<
@@ -591,8 +589,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
             decltype(GetWeightBlockDescriptor()),
             WeiElementwiseOperation,
             !UseBLocalBuffer,
-            ConvForwardSpecialization,
-            GemmKSpecialization>;
+            ConvForwardSpecialization>;
 
     using CThreadwiseCopy = ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN<
         OutDataType,
@@ -601,8 +598,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
         decltype(GetOutputBlockDescriptor()),
         OutElementwiseOperation,
         !UseCLocalBuffer,
-        ConvForwardSpecialization,
-        GemmKSpecialization>;
+        ConvForwardSpecialization>;
 
     using GridwiseGemm =
         ck::cpu::GridwiseGemmAvx2_MxN<InDataType, // InDataType,
@@ -804,10 +800,9 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
             }
         }
 
-        if constexpr(GemmKSpecialization ==
-                         ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
-                     ConvForwardSpecialization !=
-                         ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+        if(gridwise_gemm.dynamic_tunable.gemm_k_spec ==
+               ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
+           ConvForwardSpecialization != ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
         {
             if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0))
                 return false;
@@ -922,7 +917,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
         str << "DeviceConv" << std::to_string(NumDimSpatial)
             << "DFwdAvx2_NHWC_KYXC"
             <<"_FS"<< static_cast<int>(ConvForwardSpecialization)
-            <<"_KS"<< static_cast<int>(GemmKSpecialization)
+            <<"_KS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.gemm_k_spec)
             <<"_BS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.loop_over_spec)
             << "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block
             << "_TT" << MPerThread << "x" << NPerThread
...
@@ -29,8 +29,6 @@ template <typename InDataType,
           typename WeiElementwiseOperation,
           typename OutElementwiseOperation,
           ConvolutionForwardSpecialization_t ConvForwardSpecialization,
-          ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
-          // ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
           ck::index_t NumDimSpatial,
           ck::index_t MPerThread,
           ck::index_t NPerThread,
@@ -558,8 +556,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
             decltype(GetInputBlockDescriptor()),
             InElementwiseOperation,
             !UseALocalBuffer,
-            ConvForwardSpecialization,
-            GemmKSpecialization>;
+            ConvForwardSpecialization>;
 
     using BThreadwiseCopy =
         ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8<
@@ -569,8 +566,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
             decltype(GetWeightBlockDescriptor()),
             WeiElementwiseOperation,
             !UseBLocalBuffer,
-            ConvForwardSpecialization,
-            GemmKSpecialization>;
+            ConvForwardSpecialization>;
 
     using CThreadwiseCopy = ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN<
         OutDataType,
@@ -579,8 +575,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
         decltype(GetOutputBlockDescriptor()),
         OutElementwiseOperation,
         !UseCLocalBuffer,
-        ConvForwardSpecialization,
-        GemmKSpecialization>;
+        ConvForwardSpecialization>;
 
     using GridwiseGemm =
         ck::cpu::GridwiseGemmAvx2_MxN<InDataType, // InDataType,
@@ -781,10 +776,9 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
             }
         }
 
-        if constexpr(GemmKSpecialization ==
-                         ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
-                     ConvForwardSpecialization !=
-                         ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+        if(gridwise_gemm.dynamic_tunable.gemm_k_spec ==
+               ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
+           ConvForwardSpecialization != ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0))
                return false;
@@ -902,7 +896,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
         str << "DeviceConv" << std::to_string(NumDimSpatial)
             << "DFwdAvx2_NHWC_KYXCK8"
             <<"_FS"<< static_cast<int>(ConvForwardSpecialization)
-            <<"_KS"<< static_cast<int>(GemmKSpecialization)
+            <<"_KS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.gemm_k_spec)
             <<"_BS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.loop_over_spec)
             << "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block
             << "_TT" << MPerThread << "x" << NPerThread
...
@@ -28,7 +28,6 @@ template <typename InDataType,
           typename WeiElementwiseOperation,
           typename OutElementwiseOperation,
           ConvolutionForwardSpecialization_t ConvForwardSpecialization,
-          ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
           ck::index_t NumDimSpatial,
           ck::index_t MPerThread,
           ck::index_t NPerThread,
@@ -550,8 +549,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
             decltype(GetInputBlockDescriptor()),
             InElementwiseOperation,
             !UseALocalBuffer,
-            ConvForwardSpecialization,
-            GemmKSpecialization>;
+            ConvForwardSpecialization>;
 
     using BThreadwiseCopy =
         ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_YXCK<
@@ -561,8 +559,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
             decltype(GetWeightBlockDescriptor()),
             WeiElementwiseOperation,
             !UseBLocalBuffer,
-            ConvForwardSpecialization,
-            GemmKSpecialization>;
+            ConvForwardSpecialization>;
 
     using CThreadwiseCopy = ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN<
         OutDataType,
@@ -571,8 +568,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
         decltype(GetOutputBlockDescriptor()),
         OutElementwiseOperation,
         !UseCLocalBuffer,
-        ConvForwardSpecialization,
-        GemmKSpecialization>;
+        ConvForwardSpecialization>;
 
     using GridwiseGemm =
         ck::cpu::GridwiseGemmAvx2_MxN<InDataType, // InDataType,
@@ -773,10 +769,9 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
             }
         }
 
-        if constexpr(GemmKSpecialization ==
-                         ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
-                     ConvForwardSpecialization !=
-                         ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+        if(gridwise_gemm.dynamic_tunable.gemm_k_spec ==
+               ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
+           ConvForwardSpecialization != ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0))
                return false;
@@ -897,7 +892,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
         str << "DeviceConv" << std::to_string(NumDimSpatial)
             << "DFwdAvx2_NHWC_YXCK"
             <<"_FS"<< static_cast<int>(ConvForwardSpecialization)
-            <<"_KS"<< static_cast<int>(GemmKSpecialization)
+            <<"_KS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.gemm_k_spec)
             <<"_BS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.loop_over_spec)
             << "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block
             << "_A" << string_local_buffer(UseALocalBuffer)
...
@@ -31,7 +31,6 @@ template <typename InDataType,
           typename WeiElementwiseOperation,
           typename OutElementwiseOperation,
           ConvolutionForwardSpecialization_t ConvForwardSpecialization,
-          ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
           ck::index_t NumDimSpatial,
           ck::index_t MPerThread,
           ck::index_t NPerThread,
@@ -596,8 +595,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
             decltype(GetInputBlockDescriptor()),
             InElementwiseOperation,
             !UseALocalBuffer,
-            ConvForwardSpecialization,
-            GemmKSpecialization>;
+            ConvForwardSpecialization>;
 
     using BThreadwiseCopy =
         ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXC<
@@ -607,8 +605,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
             decltype(GetWeightBlockDescriptor()),
             WeiElementwiseOperation,
             !UseBLocalBuffer,
-            ConvForwardSpecialization,
-            GemmKSpecialization>;
+            ConvForwardSpecialization>;
 
     using CThreadwiseCopy =
         ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN<
@@ -855,10 +852,9 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
             }
         }
 
-        if constexpr(GemmKSpecialization ==
-                         ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
-                     ConvForwardSpecialization !=
-                         ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+        if(gridwise_gemm.dynamic_tunable.gemm_k_spec ==
+               ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
+           ConvForwardSpecialization != ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0))
                return false;
@@ -981,7 +977,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
         str << "DeviceConv" << std::to_string(NumDimSpatial)
             << "DFwd_BAA_Avx2_NHWC_KYXC"
             <<"_FS"<< static_cast<int>(ConvForwardSpecialization)
-            <<"_KS"<< static_cast<int>(GemmKSpecialization)
+            <<"_KS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.gemm_k_spec)
             <<"_BS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.loop_over_spec)
             << "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block
             << "_TT" << MPerThread << "x" << NPerThread
...
@@ -31,7 +31,6 @@ template <typename InDataType,
           typename WeiElementwiseOperation,
           typename OutElementwiseOperation,
           ConvolutionForwardSpecialization_t ConvForwardSpecialization,
-          ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
           ck::index_t NumDimSpatial,
           ck::index_t MPerThread,
           ck::index_t NPerThread,
@@ -573,8 +572,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
             decltype(GetInputBlockDescriptor()),
             InElementwiseOperation,
             !UseALocalBuffer,
-            ConvForwardSpecialization,
-            GemmKSpecialization>;
+            ConvForwardSpecialization>;
 
     using BThreadwiseCopy =
         ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8<
@@ -584,8 +582,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
             decltype(GetWeightBlockDescriptor()),
             WeiElementwiseOperation,
             !UseBLocalBuffer,
-            ConvForwardSpecialization,
-            GemmKSpecialization>;
+            ConvForwardSpecialization>;
 
     using CThreadwiseCopy =
         ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN<
@@ -832,10 +829,9 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
             }
         }
 
-        if constexpr(GemmKSpecialization ==
-                         ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
-                     ConvForwardSpecialization !=
-                         ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+        if(gridwise_gemm.dynamic_tunable.gemm_k_spec ==
+               ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
+           ConvForwardSpecialization != ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0))
                return false;
@@ -961,7 +957,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
         str << "DeviceConv" << std::to_string(NumDimSpatial)
             << "DFwd_BAA_Avx2_NHWC_KYXCK8"
             <<"_FS"<< static_cast<int>(ConvForwardSpecialization)
-            <<"_KS"<< static_cast<int>(GemmKSpecialization)
+            <<"_KS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.gemm_k_spec)
             <<"_BS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.loop_over_spec)
             << "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block
             << "_TT" << MPerThread << "x" << NPerThread
...
@@ -30,7 +30,6 @@ template <typename InDataType,
           typename WeiElementwiseOperation,
           typename OutElementwiseOperation,
           ConvolutionForwardSpecialization_t ConvForwardSpecialization,
-          ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
           ck::index_t NumDimSpatial,
           ck::index_t MPerThread,
           ck::index_t NPerThread,
@@ -569,8 +568,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
             decltype(GetInputBlockDescriptor()),
             InElementwiseOperation,
             !UseALocalBuffer,
-            ConvForwardSpecialization,
-            GemmKSpecialization>;
+            ConvForwardSpecialization>;
 
     using BThreadwiseCopy =
         ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_YXCK<
@@ -580,8 +578,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
             decltype(GetWeightBlockDescriptor()),
             WeiElementwiseOperation,
             !UseBLocalBuffer,
-            ConvForwardSpecialization,
-            GemmKSpecialization>;
+            ConvForwardSpecialization>;
 
     using CThreadwiseCopy =
         ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN<
@@ -828,10 +825,9 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
             }
         }
 
-        if constexpr(GemmKSpecialization ==
-                         ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
-                     ConvForwardSpecialization !=
-                         ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+        if(gridwise_gemm.dynamic_tunable.gemm_k_spec ==
+               ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
+           ConvForwardSpecialization != ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0))
                return false;
@@ -960,7 +956,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
         str << "DeviceConv" << std::to_string(NumDimSpatial)
             << "DFwd_BAA_Avx2_NHWC_YXCK"
             <<"_FS"<< static_cast<int>(ConvForwardSpecialization)
-            <<"_KS"<< static_cast<int>(GemmKSpecialization)
+            <<"_KS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.gemm_k_spec)
             <<"_BS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.loop_over_spec)
             << "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block
             << "_TT" << MPerThread << "x" << NPerThread
...
@@ -352,7 +352,8 @@ struct GridwiseGemmAvx2_MxN
                             ck::make_zero_multi_index<a_block_copy_dim>(),
                             GetABlockDescriptor(m_per_block, k_per_block, a_grid_desc),
                             ck::make_zero_multi_index<a_block_copy_dim>(),
-                            AElementwiseOperation{});
+                            AElementwiseOperation{},
+                            dynamic_tunable.gemm_k_spec);
 
         auto b_threadwise_copy =
             BThreadwiseCopy(b_grid_desc,
@@ -495,7 +496,8 @@ struct GridwiseGemmAvx2_MxN
                             ck::make_zero_multi_index<a_block_copy_dim>(),
                             GetABlockDescriptor(m_per_block, k_per_block, a_grid_desc),
                             ck::make_zero_multi_index<a_block_copy_dim>(),
-                            AElementwiseOperation{});
+                            AElementwiseOperation{},
+                            dynamic_tunable.gemm_k_spec);
 
         auto b_threadwise_copy =
             BThreadwiseCopy(b_grid_desc,
...
@@ -378,7 +378,8 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
                             ck::make_zero_multi_index<a_block_copy_dim>(),
                             GetABlockDescriptor(m_per_block, k_per_block, a_grid_desc),
                             ck::make_zero_multi_index<a_block_copy_dim>(),
-                            AElementwiseOperation{});
+                            AElementwiseOperation{},
+                            dynamic_tunable.gemm_k_spec);
 
         auto b_threadwise_copy =
             BThreadwiseCopy(b_grid_desc,
@@ -533,7 +534,8 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
                             ck::make_zero_multi_index<a_block_copy_dim>(),
                             GetABlockDescriptor(m_per_block, k_per_block, a_grid_desc),
                             ck::make_zero_multi_index<a_block_copy_dim>(),
-                            AElementwiseOperation{});
+                            AElementwiseOperation{},
+                            dynamic_tunable.gemm_k_spec);
 
         auto b_threadwise_copy =
             BThreadwiseCopy(b_grid_desc,
...
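
Both gridwise GEMMs above now forward dynamic_tunable.gemm_k_spec into the A-side threadwise copy as a trailing constructor argument; the copy stores it in a member and branches on it when it runs (see the constructor and member changes in the next file). A minimal sketch of that handshake, again with hypothetical simplified types:

```cpp
#include <cstdio>

// Illustrative stand-in for the CK enum (not the real declaration).
enum class GemmKSpec
{
    DefaultGemmKLoop,
    NHWC_GemmKLoopOverC
};

// Sketch of the widened copy: the spec is now a stored member rather than a
// template parameter, so one compiled copy type serves both loop flavors.
struct AThreadwiseCopySketch
{
    GemmKSpec gemm_k_spec_;

    explicit AThreadwiseCopySketch(GemmKSpec spec) : gemm_k_spec_(spec) {}

    void Run() const
    {
        // runtime `if` where the old code used `if constexpr`
        if(gemm_k_spec_ == GemmKSpec::NHWC_GemmKLoopOverC)
            std::puts("gemm-K loops over C");
        else
            std::puts("default gemm-K loop");
    }
};

int main()
{
    GemmKSpec dynamic_spec = GemmKSpec::NHWC_GemmKLoopOverC; // comes from dynamic_tunable
    AThreadwiseCopySketch a_copy(dynamic_spec);              // the new trailing ctor argument
    a_copy.Run();
}
```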
@@ -210,6 +210,69 @@ void memcpy32_avx2_with_extra_2src(void* dst,
     }
 }
 
+template <typename ElementwiseOp>
+void memcpy32_avx2_with_extra_1src(void* dst,
+                                   const void* src,
+                                   const void* src_aux,
+                                   const ck::index_t n,
+                                   const ElementwiseOp& element_op)
+{
+    // 16-8-4-2-1 pattern
+    ck::index_t i_n        = n;
+    float* p_dst           = reinterpret_cast<float*>(dst);
+    const float* p_src     = reinterpret_cast<const float*>(src);
+    const float* p_src_aux = reinterpret_cast<const float*>(src_aux);
+    while(i_n >= 16)
+    {
+        _mm256_storeu_ps(
+            p_dst + 0,
+            element_op.Apply(_mm256_loadu_ps(p_src + 0), _mm256_loadu_ps(p_src_aux + 0)));
+        _mm256_storeu_ps(
+            p_dst + 8,
+            element_op.Apply(_mm256_loadu_ps(p_src + 8), _mm256_loadu_ps(p_src_aux + 8)));
+        p_dst += 16;
+        p_src += 16;
+        p_src_aux += 16;
+        i_n -= 16;
+    }
+    if(i_n & 8)
+    {
+        _mm256_storeu_ps(p_dst,
+                         element_op.Apply(_mm256_loadu_ps(p_src), _mm256_loadu_ps(p_src_aux)));
+        p_dst += 8;
+        p_src += 8;
+        p_src_aux += 8;
+    }
+    if(i_n & 4)
+    {
+        _mm_storeu_ps(p_dst, element_op.Apply(_mm_loadu_ps(p_src), _mm_loadu_ps(p_src_aux)));
+        p_dst += 4;
+        p_src += 4;
+        p_src_aux += 4;
+    }
+    if(i_n & 2)
+    {
+#if defined(__GNUC__) && !defined(__clang__) && !defined(__llvm__)
+        __m128i s  = _mm_loadu_si64(p_src);
+        __m128i s1 = _mm_loadu_si64(p_src_aux);
+        __m128 v =
+            element_op.Apply(*reinterpret_cast<__m128*>(&s), *reinterpret_cast<__m128*>(&s1));
+        _mm_storeu_si64(p_dst, *reinterpret_cast<__m128i*>(&v));
+#else
+        _mm_storeu_si64(p_dst, element_op.Apply(_mm_loadu_si64(p_src), _mm_loadu_si64(p_src_aux)));
+#endif
+        p_dst += 2;
+        p_src += 2;
+        p_src_aux += 2;
+    }
+    if(i_n & 1)
+    {
+        *p_dst = element_op.Apply(*p_src, *p_src_aux);
+    }
+}
+
 inline void memset32_avx2(void* dst, const int32_t value, const ck::index_t n)
 {
     // 16-8-4-2-1 pattern
@@ -324,8 +387,7 @@ template <typename SrcData,
           typename DstDesc,
           typename ElementwiseOperation,
           bool BypassTransfer,
-          ConvolutionForwardSpecialization_t ConvForwardSpecialization,
-          ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
+          ConvolutionForwardSpecialization_t ConvForwardSpecialization>
 struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
 {
     static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
@@ -336,8 +398,9 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
         const Index&,
         const DstDesc&,
         const Index&,
-        const ElementwiseOperation& element_op)
-        : element_op_(element_op)
+        const ElementwiseOperation& element_op,
+        const ConvolutionForwardGemmKSpecialization_t& gemm_k_spec)
+        : element_op_(element_op), gemm_k_spec_(gemm_k_spec)
     {
         if constexpr(ConvForwardSpecialization ==
                      ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
@@ -630,8 +693,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
         {
             // ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
             // iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
-            if constexpr(GemmKSpecialization ==
-                         ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
+            if(gemm_k_spec_ == ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
            {
                // c % k_per_block == 0, so every time k_per_block here is the same
                ck::index_t i_m_itr = m_per_block;
@@ -782,8 +844,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
         }
         else
         {
-            if constexpr(GemmKSpecialization ==
-                         ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
+            if(gemm_k_spec_ == ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
            {
                // TODO: branch seems weird
 
@@ -827,6 +888,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
     private:
     const ElementwiseOperation element_op_;
+    const ConvolutionForwardGemmKSpecialization_t gemm_k_spec_;
 
     ck::index_t i_n;
     ck::index_t i_c;
@@ -875,8 +937,7 @@ template <typename SrcData,
           typename DstDesc,
           typename ElementwiseOperation,
           bool BypassTransfer,
-          ConvolutionForwardSpecialization_t ConvForwardSpecialization,
-          ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
+          ConvolutionForwardSpecialization_t ConvForwardSpecialization>
 struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXC
 {
     static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
@@ -1096,8 +1157,7 @@ template <typename SrcData,
           typename DstDesc,
           typename ElementwiseOperation,
           bool BypassTransfer,
-          ConvolutionForwardSpecialization_t ConvForwardSpecialization,
-          ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
+          ConvolutionForwardSpecialization_t ConvForwardSpecialization>
 struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8
 {
     static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
@@ -1283,8 +1343,7 @@ template <typename SrcData,
           typename DstDesc,
           typename ElementwiseOperation,
           bool BypassTransfer,
-          ConvolutionForwardSpecialization_t ConvForwardSpecialization,
-          ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
+          ConvolutionForwardSpecialization_t ConvForwardSpecialization>
struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_YXCK
 {
     static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
@@ -1415,8 +1474,7 @@ template <typename SrcData,
           typename DstDesc,
           typename ElementwiseOperation,
           bool BypassTransfer,
-          ConvolutionForwardSpecialization_t ConvForwardSpecialization,
-          ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
+          ConvolutionForwardSpecialization_t ConvForwardSpecialization>
 struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
 {
     static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
...
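
The file above also adds memcpy32_avx2_with_extra_1src, which fuses a 32-bit copy with a binary element op applied against one auxiliary source, peeling the length in the same 16-8-4-2-1 pattern as its 2-src sibling. A hypothetical usage sketch follows; the AddOp below is illustrative and not a CK type, and it assumes the utility header above is in scope and the code is built for an AVX2 target (e.g. -mavx2):

```cpp
#include <immintrin.h>
#include <cstdio>
#include <vector>
// Assumes the CK CPU AVX2 util header defining memcpy32_avx2_with_extra_1src
// (and ck::index_t) is included.

// Illustrative element op: dst[i] = src[i] + src_aux[i]. The helper needs an
// Apply overload for every width its tail paths touch.
struct AddOp
{
    __m256 Apply(__m256 a, __m256 b) const { return _mm256_add_ps(a, b); }
    __m128 Apply(__m128 a, __m128 b) const { return _mm_add_ps(a, b); }
    // The 2-element tail loads via _mm_loadu_si64 (__m128i) on compilers that
    // take the #else branch, so bit-cast to float lanes, add, and cast back.
    __m128i Apply(__m128i a, __m128i b) const
    {
        return _mm_castps_si128(_mm_add_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b)));
    }
    float Apply(float a, float b) const { return a + b; }
};

int main()
{
    std::vector<float> src(19, 1.0f), aux(19, 2.0f), dst(19, 0.0f);
    // n = 19 exercises the 16-wide main loop plus the 2- and 1-element tails.
    memcpy32_avx2_with_extra_1src(dst.data(), src.data(), aux.data(), 19, AddOp{});
    std::printf("%.1f %.1f\n", dst[0], dst[18]); // 3.0 3.0
}
```

Note that the op must cover each vector width separately; the __m128i overload only exists to satisfy the two-element tail on the non-GCC code path.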
@@ -49,17 +49,17 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
 // clang-format off
 #define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf) \
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
     \
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN})
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN}), \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN})
 // clang-format on
 void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
...
@@ -42,17 +42,17 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
 // clang-format off
 #define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf) \
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
     \
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN})
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN}), \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN})
 // clang-format on
 void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk(
...
@@ -41,17 +41,17 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
 // clang-format off
 #define DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf) \
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
     \
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
-    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN})
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN}), \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
+    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN})
 // clang-format on
 void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk(std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
...
@@ -42,17 +42,17 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
 // clang-format off
 #define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
-    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
-    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
-    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
-    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
-    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
+    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
+    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
+    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
+    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
+    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
     \
-    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
-    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
-    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
-    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
-    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN})
+    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
+    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
+    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN}), \
+    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
+    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN})
// clang-format on // clang-format on
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk( void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk(
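Because gemm_k_spec now travels in the tunable struct, kernel-side code that previously specialized on it at compile time can pick the loop flavor with an ordinary branch. The helper below is illustrative only (run_gemm_k and its functor parameters are not part of the library), sketching what such runtime dispatch might look like.

enum class GemmKSpec { DefaultGemmKLoop, GemmKLoopOverC };

// Illustrative dispatch on the runtime gemm_k_spec value, replacing what a
// compile-time template specialization would otherwise select.
template <typename DefaultLoopFn, typename LoopOverCFn>
void run_gemm_k(GemmKSpec spec, DefaultLoopFn&& default_loop, LoopOverCFn&& loop_over_c)
{
    if(spec == GemmKSpec::GemmKLoopOverC)
        loop_over_c();   // presumably iterates gemm-K over the channel dim C, per the name
    else
        default_loop();  // generic gemm-K loop
}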
......
...@@ -42,17 +42,17 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver ...@@ -42,17 +42,17 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
// clang-format off // clang-format off
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \ #define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
\ \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}) DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN})
// clang-format on // clang-format on
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk( void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk(
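Each of these macros expands to ten instances: five combinations of forward specialization, local-buffer flags, and gemm-k loop choice, listed once with LoopOver_MNK and once with LoopOver_MKN. On the old (left-hand) side, the GemmKLoopOverC and DefaultGemmKLoop entries were distinct template instantiations; on the new (right-hand) side, entries that differ only in the gemm-k loop choice map to the same instantiation and differ only in their runtime initializer, which should reduce compile time and binary size for the instance libraries.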
......
...@@ -41,17 +41,17 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver ...@@ -41,17 +41,17 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
// clang-format off // clang-format off
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \ #define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
\ \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}) DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN})
// clang-format on // clang-format on
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk( void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk(
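For completeness, here is a sketch of how such macro-generated instances might be collected by the add_device_conv2d_fwd_* registration functions; the container and base types are hypothetical stand-ins for the library's DeviceConvFwdPtr machinery, not its actual API.

#include <memory>
#include <vector>

enum class GemmKSpec { DefaultGemmKLoop, GemmKLoopOverC };
enum class LoopOver  { MNK, MKN };

struct Tunable
{
    int m_per_block, n_per_block, k_per_block;
    GemmKSpec gemm_k_spec;
    LoopOver  loop_over_spec;
};

struct InstanceBase { virtual ~InstanceBase() = default; };

template <int MPerThread, int NPerThread>
struct Instance final : InstanceBase
{
    explicit Instance(Tunable t) : tunable(t) {}
    Tunable tunable;
};

// Both gemm-k variants reuse Instance<4, 24>; only the tunable differs.
void add_instances(std::vector<std::unique_ptr<InstanceBase>>& out)
{
    out.push_back(std::make_unique<Instance<4, 24>>(Tunable{256, 128, 64, GemmKSpec::GemmKLoopOverC,   LoopOver::MNK}));
    out.push_back(std::make_unique<Instance<4, 24>>(Tunable{256, 128, 64, GemmKSpec::DefaultGemmKLoop, LoopOver::MKN}));
}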
......