Commit bfa4c686 authored by carlushuang's avatar carlushuang
Browse files

refactor by remove gemm_k_spec to dynamic

parent 9cefc261
...@@ -92,7 +92,7 @@ struct DeviceConvFwdDynamicTunable ...@@ -92,7 +92,7 @@ struct DeviceConvFwdDynamicTunable
// bool use_c_local_buffer; // bool use_c_local_buffer;
// ConvolutionForwardSpecialization_t forward_spec; // ConvolutionForwardSpecialization_t forward_spec;
// ConvolutionForwardGemmKSpecialization_t gemm_k_spec; ConvolutionForwardGemmKSpecialization_t gemm_k_spec;
ConvolutionForwardBlockLoopOverSpecialization_t loop_over_spec; ConvolutionForwardBlockLoopOverSpecialization_t loop_over_spec;
}; };
......
...@@ -29,7 +29,6 @@ template <typename InDataType, ...@@ -29,7 +29,6 @@ template <typename InDataType,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization, ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
ck::index_t NumDimSpatial, ck::index_t NumDimSpatial,
ck::index_t MPerThread, ck::index_t MPerThread,
ck::index_t NPerThread, ck::index_t NPerThread,
...@@ -580,8 +579,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -580,8 +579,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
decltype(GetInputBlockDescriptor()), decltype(GetInputBlockDescriptor()),
InElementwiseOperation, InElementwiseOperation,
!UseALocalBuffer, !UseALocalBuffer,
ConvForwardSpecialization, ConvForwardSpecialization>;
GemmKSpecialization>;
using BThreadwiseCopy = using BThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXC< ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXC<
...@@ -591,8 +589,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -591,8 +589,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
decltype(GetWeightBlockDescriptor()), decltype(GetWeightBlockDescriptor()),
WeiElementwiseOperation, WeiElementwiseOperation,
!UseBLocalBuffer, !UseBLocalBuffer,
ConvForwardSpecialization, ConvForwardSpecialization>;
GemmKSpecialization>;
using CThreadwiseCopy = ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN< using CThreadwiseCopy = ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN<
OutDataType, OutDataType,
...@@ -601,8 +598,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -601,8 +598,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
decltype(GetOutputBlockDescriptor()), decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation, OutElementwiseOperation,
!UseCLocalBuffer, !UseCLocalBuffer,
ConvForwardSpecialization, ConvForwardSpecialization>;
GemmKSpecialization>;
using GridwiseGemm = using GridwiseGemm =
ck::cpu::GridwiseGemmAvx2_MxN<InDataType, // InDataType, ck::cpu::GridwiseGemmAvx2_MxN<InDataType, // InDataType,
...@@ -804,10 +800,9 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -804,10 +800,9 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
} }
} }
if constexpr(GemmKSpecialization == if(gridwise_gemm.dynamic_tunable.gemm_k_spec ==
ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC && ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
ConvForwardSpecialization != ConvForwardSpecialization != ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{ {
if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0)) if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0))
return false; return false;
...@@ -922,7 +917,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -922,7 +917,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
str << "DeviceConv" << std::to_string(NumDimSpatial) str << "DeviceConv" << std::to_string(NumDimSpatial)
<< "DFwdAvx2_NHWC_KYXC" << "DFwdAvx2_NHWC_KYXC"
<<"_FS"<< static_cast<int>(ConvForwardSpecialization) <<"_FS"<< static_cast<int>(ConvForwardSpecialization)
<<"_KS"<< static_cast<int>(GemmKSpecialization) <<"_KS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.gemm_k_spec)
<<"_BS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.loop_over_spec) <<"_BS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.loop_over_spec)
<< "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block << "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block
<< "_TT" << MPerThread << "x" << NPerThread << "_TT" << MPerThread << "x" << NPerThread
......
...@@ -29,8 +29,6 @@ template <typename InDataType, ...@@ -29,8 +29,6 @@ template <typename InDataType,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization, ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
// ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
ck::index_t NumDimSpatial, ck::index_t NumDimSpatial,
ck::index_t MPerThread, ck::index_t MPerThread,
ck::index_t NPerThread, ck::index_t NPerThread,
...@@ -558,8 +556,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K ...@@ -558,8 +556,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
decltype(GetInputBlockDescriptor()), decltype(GetInputBlockDescriptor()),
InElementwiseOperation, InElementwiseOperation,
!UseALocalBuffer, !UseALocalBuffer,
ConvForwardSpecialization, ConvForwardSpecialization>;
GemmKSpecialization>;
using BThreadwiseCopy = using BThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8< ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8<
...@@ -569,8 +566,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K ...@@ -569,8 +566,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
decltype(GetWeightBlockDescriptor()), decltype(GetWeightBlockDescriptor()),
WeiElementwiseOperation, WeiElementwiseOperation,
!UseBLocalBuffer, !UseBLocalBuffer,
ConvForwardSpecialization, ConvForwardSpecialization>;
GemmKSpecialization>;
using CThreadwiseCopy = ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN< using CThreadwiseCopy = ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN<
OutDataType, OutDataType,
...@@ -579,8 +575,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K ...@@ -579,8 +575,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
decltype(GetOutputBlockDescriptor()), decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation, OutElementwiseOperation,
!UseCLocalBuffer, !UseCLocalBuffer,
ConvForwardSpecialization, ConvForwardSpecialization>;
GemmKSpecialization>;
using GridwiseGemm = using GridwiseGemm =
ck::cpu::GridwiseGemmAvx2_MxN<InDataType, // InDataType, ck::cpu::GridwiseGemmAvx2_MxN<InDataType, // InDataType,
...@@ -781,10 +776,9 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K ...@@ -781,10 +776,9 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
} }
} }
if constexpr(GemmKSpecialization == if(gridwise_gemm.dynamic_tunable.gemm_k_spec ==
ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC && ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
ConvForwardSpecialization != ConvForwardSpecialization != ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{ {
if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0)) if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0))
return false; return false;
...@@ -902,7 +896,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K ...@@ -902,7 +896,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
str << "DeviceConv" << std::to_string(NumDimSpatial) str << "DeviceConv" << std::to_string(NumDimSpatial)
<< "DFwdAvx2_NHWC_KYXCK8" << "DFwdAvx2_NHWC_KYXCK8"
<<"_FS"<< static_cast<int>(ConvForwardSpecialization) <<"_FS"<< static_cast<int>(ConvForwardSpecialization)
<<"_KS"<< static_cast<int>(GemmKSpecialization) <<"_KS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.gemm_k_spec)
<<"_BS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.loop_over_spec) <<"_BS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.loop_over_spec)
<< "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block << "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block
<< "_TT" << MPerThread << "x" << NPerThread << "_TT" << MPerThread << "x" << NPerThread
......
#ifndef DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_HPP #ifndef DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_HPP
#define DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_HPP #define DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_HPP
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <numeric> #include <numeric>
#include "device.hpp" #include "device.hpp"
#include "device_base_cpu.hpp" #include "device_base_cpu.hpp"
#include "device_conv_fwd_cpu.hpp" #include "device_conv_fwd_cpu.hpp"
#include "convolution_forward_specialization_cpu.hpp" #include "convolution_forward_specialization_cpu.hpp"
#include "common_header.hpp" #include "common_header.hpp"
#include "../../gpu/device/tensor_layout.hpp" #include "../../gpu/device/tensor_layout.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_avx2.hpp" #include "gridwise_gemm_avx2.hpp"
#include "threadwise_gemm_avx2.hpp" #include "threadwise_gemm_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp" #include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace cpu { namespace cpu {
namespace device { namespace device {
template <typename InDataType, template <typename InDataType,
typename WeiDataType, typename WeiDataType,
typename OutDataType, typename OutDataType,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization, ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization, ck::index_t NumDimSpatial,
ck::index_t NumDimSpatial, ck::index_t MPerThread,
ck::index_t MPerThread, ck::index_t NPerThread,
ck::index_t NPerThread, bool UseALocalBuffer,
bool UseALocalBuffer, bool UseBLocalBuffer,
bool UseBLocalBuffer, bool UseCLocalBuffer>
bool UseCLocalBuffer> struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K : public DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
: public DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation> {
{ using DeviceOp = DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K;
using DeviceOp = DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K;
using ADataType = InDataType;
using ADataType = InDataType; using BDataType = WeiDataType;
using BDataType = WeiDataType; using CDataType = OutDataType;
using CDataType = OutDataType;
using AElementwiseOperation = InElementwiseOperation;
using AElementwiseOperation = InElementwiseOperation; using BElementwiseOperation = WeiElementwiseOperation;
using BElementwiseOperation = WeiElementwiseOperation; using CElementwiseOperation = OutElementwiseOperation;
using CElementwiseOperation = OutElementwiseOperation;
// TODO make A/B datatype different
// TODO make A/B datatype different using ABDataType = InDataType;
using ABDataType = InDataType;
static constexpr index_t NDimSpatial = NumDimSpatial;
static constexpr index_t NDimSpatial = NumDimSpatial;
static constexpr auto I0 = Number<0>{};
static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{};
static constexpr auto I3 = Number<3>{};
static constexpr bool NonTemporalStore = false;
static constexpr bool NonTemporalStore = false;
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K(
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K( const DeviceConvFwdDynamicTunable& dtune)
const DeviceConvFwdDynamicTunable& dtune) : gridwise_gemm(dtune)
: gridwise_gemm(dtune) {
{ }
}
static constexpr auto GetThreadwiseGemm_Dispatch()
static constexpr auto GetThreadwiseGemm_Dispatch() {
{ if constexpr(MPerThread == 4 && NPerThread == 24)
if constexpr(MPerThread == 4 && NPerThread == 24) {
{ return ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<InDataType,
return ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<InDataType, WeiDataType,
WeiDataType, OutDataType,
OutDataType, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor, NonTemporalStore>{};
NonTemporalStore>{}; }
} else if constexpr(MPerThread == 6 && NPerThread == 16)
else if constexpr(MPerThread == 6 && NPerThread == 16) {
{ return ck::cpu::ThreadwiseGemmAvx2_MxN_6x16_Dispatch<InDataType,
return ck::cpu::ThreadwiseGemmAvx2_MxN_6x16_Dispatch<InDataType, WeiDataType,
WeiDataType, OutDataType,
OutDataType, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor, NonTemporalStore>{};
NonTemporalStore>{}; }
} else
else {
{ // static_assert(false, "invalid Mr/Nr");
// static_assert(false, "invalid Mr/Nr"); }
} }
}
using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());
using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());
static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n) {
{ return make_naive_tensor_descriptor_packed(make_tuple(gemm_k, gemm_n));
return make_naive_tensor_descriptor_packed(make_tuple(gemm_k, gemm_n)); }
}
static auto GetOutputTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n)
static auto GetOutputTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n) {
{ const auto out_gemm_m_n_grid_desc =
const auto out_gemm_m_n_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_n));
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_n));
return out_gemm_m_n_grid_desc;
return out_gemm_m_n_grid_desc; }
}
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false> static auto GetInputTensorDescriptor(ck::index_t N,
static auto GetInputTensorDescriptor(ck::index_t N, ck::index_t C,
ck::index_t C, ck::index_t gemm_m,
ck::index_t gemm_m, ck::index_t gemm_k,
ck::index_t gemm_k, const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& input_spatial_lengths, const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths, const std::vector<ck::index_t>& output_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths, const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_strides, const std::vector<ck::index_t>& conv_filter_dilations,
const std::vector<ck::index_t>& conv_filter_dilations, const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_left_pads, const std::vector<ck::index_t>& input_right_pads)
const std::vector<ck::index_t>& input_right_pads) {
{ const index_t Wi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[0]; const index_t Wo = output_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[0]; const index_t ConvStrideW = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[0];
if constexpr(ConvForwardSpecialization ==
if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) {
{ const auto in_gemm_m_k_grid_desc =
const auto in_gemm_m_k_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
return in_gemm_m_k_grid_desc;
return in_gemm_m_k_grid_desc; }
} else if constexpr(ConvForwardSpecialization ==
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::Filter1x1Pad0)
ConvolutionForwardSpecialization_t::Filter1x1Pad0) {
{ const auto in_n_wi_c_grid_desc =
const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto in_n_wo_c_grid_desc = transform_tensor_descriptor(
const auto in_n_wo_c_grid_desc = transform_tensor_descriptor( in_n_wi_c_grid_desc,
in_n_wi_c_grid_desc, make_tuple(make_pass_through_transform(N),
make_tuple(make_pass_through_transform(N), make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), make_pass_through_transform(C)),
make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor( in_n_wo_c_grid_desc,
in_n_wo_c_grid_desc, make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)),
make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)), make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
return in_gemm_m_k_grid_desc; }
} else
else {
{ const index_t X = filter_spatial_lengths[0];
const index_t X = filter_spatial_lengths[0]; const index_t ConvDilationW = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[0]; const index_t InLeftPadW = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[0]; const index_t InRightPadW = input_right_pads[0];
const index_t InRightPadW = input_right_pads[0];
const auto in_n_wi_c_grid_desc =
const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( in_n_wi_c_grid_desc,
in_n_wi_c_grid_desc, make_tuple(make_pass_through_transform(N),
make_tuple(make_pass_through_transform(N), make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pad_transform(Wi, InLeftPadW, InRightPadW), make_pass_through_transform(C)),
make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor( in_n_wip_c_grid_desc,
in_n_wip_c_grid_desc, make_tuple(
make_tuple( make_pass_through_transform(N),
make_pass_through_transform(N), make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), make_pass_through_transform(C)),
make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto in_gemm_m_k_grid_desc =
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(in_n_x_wo_c_grid_desc,
transform_tensor_descriptor(in_n_x_wo_c_grid_desc, make_tuple(make_merge_transform(make_tuple(N, Wo)),
make_tuple(make_merge_transform(make_tuple(N, Wo)), make_merge_transform(make_tuple(X, C))),
make_merge_transform(make_tuple(X, C))), make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
return in_gemm_m_k_grid_desc; }
} }
}
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false> static auto GetInputTensorDescriptor(ck::index_t N,
static auto GetInputTensorDescriptor(ck::index_t N, ck::index_t C,
ck::index_t C, ck::index_t gemm_m,
ck::index_t gemm_m, ck::index_t gemm_k,
ck::index_t gemm_k, const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& input_spatial_lengths, const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths, const std::vector<ck::index_t>& output_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths, const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_strides, const std::vector<ck::index_t>& conv_filter_dilations,
const std::vector<ck::index_t>& conv_filter_dilations, const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_left_pads, const std::vector<ck::index_t>& input_right_pads)
const std::vector<ck::index_t>& input_right_pads) {
{ const index_t Hi = input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[0]; const index_t Wi = input_spatial_lengths[1];
const index_t Wi = input_spatial_lengths[1];
const index_t Ho = output_spatial_lengths[0];
const index_t Ho = output_spatial_lengths[0]; const index_t Wo = output_spatial_lengths[1];
const index_t Wo = output_spatial_lengths[1];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideH = conv_filter_strides[0]; const index_t ConvStrideW = conv_filter_strides[1];
const index_t ConvStrideW = conv_filter_strides[1];
if constexpr(ConvForwardSpecialization ==
if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) {
{ const auto in_gemm_m_k_grid_desc =
const auto in_gemm_m_k_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
return in_gemm_m_k_grid_desc;
return in_gemm_m_k_grid_desc; }
} else if constexpr(ConvForwardSpecialization ==
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::Filter1x1Pad0)
ConvolutionForwardSpecialization_t::Filter1x1Pad0) {
{ const auto in_n_hi_wi_c_grid_desc =
const auto in_n_hi_wi_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( in_n_hi_wi_c_grid_desc,
in_n_hi_wi_c_grid_desc, make_tuple(make_pass_through_transform(N),
make_tuple(make_pass_through_transform(N), make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), make_pass_through_transform(C)),
make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemm_m_k_grid_desc =
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(in_n_ho_wo_c_grid_desc,
transform_tensor_descriptor(in_n_ho_wo_c_grid_desc, make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), make_pass_through_transform(C)),
make_pass_through_transform(C)), make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
return in_gemm_m_k_grid_desc; }
} else
else {
{ const index_t Y = filter_spatial_lengths[0];
const index_t Y = filter_spatial_lengths[0]; const index_t X = filter_spatial_lengths[1];
const index_t X = filter_spatial_lengths[1];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationH = conv_filter_dilations[0]; const index_t ConvDilationW = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[0]; const index_t InLeftPadW = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadH = input_right_pads[0]; const index_t InRightPadW = input_right_pads[1];
const index_t InRightPadW = input_right_pads[1];
const auto in_n_hi_wi_c_grid_desc =
const auto in_n_hi_wi_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( in_n_hi_wi_c_grid_desc,
in_n_hi_wi_c_grid_desc, make_tuple(make_pass_through_transform(N),
make_tuple(make_pass_through_transform(N), make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Hi, InLeftPadH, InRightPadH), make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pad_transform(Wi, InLeftPadW, InRightPadW), make_pass_through_transform(C)),
make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( in_n_hip_wip_c_grid_desc,
in_n_hip_wip_c_grid_desc, make_tuple(
make_tuple( make_pass_through_transform(N),
make_pass_through_transform(N), make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), make_pass_through_transform(C)),
make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemm_m_k_grid_desc =
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), make_merge_transform(make_tuple(Y, X, C))),
make_merge_transform(make_tuple(Y, X, C))), make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), make_tuple(Sequence<0>{}, Sequence<1>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
return in_gemm_m_k_grid_desc; }
} }
}
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false> static auto GetInputTensorDescriptor(ck::index_t N,
static auto GetInputTensorDescriptor(ck::index_t N, ck::index_t C,
ck::index_t C, ck::index_t gemm_m,
ck::index_t gemm_m, ck::index_t gemm_k,
ck::index_t gemm_k, ck::index_t gemm_m_pad,
ck::index_t gemm_m_pad, const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& input_spatial_lengths, const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths, const std::vector<ck::index_t>& output_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths, const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_strides, const std::vector<ck::index_t>& conv_filter_dilations,
const std::vector<ck::index_t>& conv_filter_dilations, const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_left_pads, const std::vector<ck::index_t>& input_right_pads)
const std::vector<ck::index_t>& input_right_pads) {
{ const index_t Di = input_spatial_lengths[0];
const index_t Di = input_spatial_lengths[0]; const index_t Hi = input_spatial_lengths[1];
const index_t Hi = input_spatial_lengths[1]; const index_t Wi = input_spatial_lengths[2];
const index_t Wi = input_spatial_lengths[2];
const index_t Do = output_spatial_lengths[0];
const index_t Do = output_spatial_lengths[0]; const index_t Ho = output_spatial_lengths[1];
const index_t Ho = output_spatial_lengths[1]; const index_t Wo = output_spatial_lengths[2];
const index_t Wo = output_spatial_lengths[2];
const index_t ConvStrideD = conv_filter_strides[0];
const index_t ConvStrideD = conv_filter_strides[0]; const index_t ConvStrideH = conv_filter_strides[1];
const index_t ConvStrideH = conv_filter_strides[1]; const index_t ConvStrideW = conv_filter_strides[2];
const index_t ConvStrideW = conv_filter_strides[2];
if constexpr(ConvForwardSpecialization ==
if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) {
{ const auto in_gemm_m_k_grid_desc =
const auto in_gemm_m_k_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
return in_gemm_m_k_grid_desc;
return in_gemm_m_k_grid_desc; }
} else if constexpr(ConvForwardSpecialization ==
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::Filter1x1Pad0)
ConvolutionForwardSpecialization_t::Filter1x1Pad0) {
{ const auto in_n_di_hi_wi_c_grid_desc =
const auto in_n_di_hi_wi_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor(
const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor( in_n_di_hi_wi_c_grid_desc,
in_n_di_hi_wi_c_grid_desc, make_tuple(make_pass_through_transform(N),
make_tuple(make_pass_through_transform(N), make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)), make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), make_pass_through_transform(C)),
make_pass_through_transform(C)), make_tuple(
make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), make_tuple(
make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor( in_n_do_ho_wo_c_grid_desc,
in_n_do_ho_wo_c_grid_desc, make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), make_pass_through_transform(C)),
make_pass_through_transform(C)), make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), make_tuple(Sequence<0>{}, Sequence<1>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
return in_gemm_m_k_grid_desc; }
} else
else {
{ const index_t Z = filter_spatial_lengths[0];
const index_t Z = filter_spatial_lengths[0]; const index_t Y = filter_spatial_lengths[1];
const index_t Y = filter_spatial_lengths[1]; const index_t X = filter_spatial_lengths[2];
const index_t X = filter_spatial_lengths[2];
const index_t ConvDilationD = conv_filter_dilations[0];
const index_t ConvDilationD = conv_filter_dilations[0]; const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationH = conv_filter_dilations[1]; const index_t ConvDilationW = conv_filter_dilations[2];
const index_t ConvDilationW = conv_filter_dilations[2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadD = input_left_pads[0]; const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadH = input_left_pads[1]; const index_t InLeftPadW = input_left_pads[2];
const index_t InLeftPadW = input_left_pads[2];
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadD = input_right_pads[0]; const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadH = input_right_pads[1]; const index_t InRightPadW = input_right_pads[2];
const index_t InRightPadW = input_right_pads[2];
const auto in_n_di_hi_wi_c_grid_desc =
const auto in_n_di_hi_wi_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( in_n_di_hi_wi_c_grid_desc,
in_n_di_hi_wi_c_grid_desc, make_tuple(make_pass_through_transform(N),
make_tuple(make_pass_through_transform(N), make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Di, InLeftPadD, InRightPadD), make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Hi, InLeftPadH, InRightPadH), make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pad_transform(Wi, InLeftPadW, InRightPadW), make_pass_through_transform(C)),
make_pass_through_transform(C)), make_tuple(
make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), make_tuple(
make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( in_n_hip_wip_c_grid_desc,
in_n_hip_wip_c_grid_desc, make_tuple(
make_tuple( make_pass_through_transform(N),
make_pass_through_transform(N), make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), make_pass_through_transform(C)),
make_pass_through_transform(C)), make_tuple(
make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), make_tuple(Sequence<0>{},
make_tuple(Sequence<0>{}, Sequence<1, 2>{},
Sequence<1, 2>{}, Sequence<3, 4>{},
Sequence<3, 4>{}, Sequence<5, 6>{},
Sequence<5, 6>{}, Sequence<7>{}));
Sequence<7>{}));
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor( in_n_z_do_y_ho_x_wo_c_grid_desc,
in_n_z_do_y_ho_x_wo_c_grid_desc, make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), make_merge_transform(make_tuple(Z, Y, X, C))),
make_merge_transform(make_tuple(Z, Y, X, C))), make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}), make_tuple(Sequence<0>{}, Sequence<1>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
return in_gemm_m_k_grid_desc; }
} }
}
static index_t GetGemmM(ck::index_t N, const std::vector<ck::index_t>& output_spatial_lengths)
static index_t GetGemmM(ck::index_t N, const std::vector<ck::index_t>& output_spatial_lengths) {
{ return N * std::accumulate(std::begin(output_spatial_lengths),
return N * std::accumulate(std::begin(output_spatial_lengths), std::end(output_spatial_lengths),
std::end(output_spatial_lengths), 1,
1, std::multiplies<ck::index_t>());
std::multiplies<ck::index_t>()); }
}
static index_t GetGemmK(ck::index_t C, const std::vector<ck::index_t>& filter_spatial_lengths)
static index_t GetGemmK(ck::index_t C, const std::vector<ck::index_t>& filter_spatial_lengths) {
{ return C * std::accumulate(std::begin(filter_spatial_lengths),
return C * std::accumulate(std::begin(filter_spatial_lengths), std::end(filter_spatial_lengths),
std::end(filter_spatial_lengths), 1,
1, std::multiplies<ck::index_t>());
std::multiplies<ck::index_t>()); }
}
static index_t GetGemmN(ck::index_t K)
static index_t GetGemmN(ck::index_t K) {
{ // return ck::math::integer_least_multiple(K,
// return ck::math::integer_least_multiple(K, // ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
// ThreadwiseGemm_Dispatch::MatrixBMinVectorSize); return K;
return K; }
}
static auto MakeABCGridDescriptor(ck::index_t N,
static auto MakeABCGridDescriptor(ck::index_t N, ck::index_t K,
ck::index_t K, ck::index_t C,
ck::index_t C, std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> input_spatial_lengths, std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths, std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths, std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_right_pads)
std::vector<ck::index_t> input_right_pads) {
{ using namespace ck;
using namespace ck;
const index_t GemmM = GetGemmM(N, output_spatial_lengths);
const index_t GemmM = GetGemmM(N, output_spatial_lengths); const index_t GemmN = GetGemmN(K);
const index_t GemmN = GetGemmN(K); const index_t GemmK = GetGemmK(C, filter_spatial_lengths);
const index_t GemmK = GetGemmK(C, filter_spatial_lengths);
// A:
// A: const auto in_gemm_m_k_grid_desc =
const auto in_gemm_m_k_grid_desc = GetInputTensorDescriptor<NumDimSpatial>(N,
GetInputTensorDescriptor<NumDimSpatial>(N, C,
C, GemmM,
GemmM, GemmK,
GemmK, input_spatial_lengths,
input_spatial_lengths, filter_spatial_lengths,
filter_spatial_lengths, output_spatial_lengths,
output_spatial_lengths, conv_filter_strides,
conv_filter_strides, conv_filter_dilations,
conv_filter_dilations, input_left_pads,
input_left_pads, input_right_pads);
input_right_pads); // B:
// B: const auto wei_gemm_k_n_grid_desc = GetWeightTensorDescriptor(GemmK, GemmN);
const auto wei_gemm_k_n_grid_desc = GetWeightTensorDescriptor(GemmK, GemmN); // C:
// C: const auto out_gemm_m_n_grid_desc = GetOutputTensorDescriptor(GemmM, GemmN);
const auto out_gemm_m_n_grid_desc = GetOutputTensorDescriptor(GemmM, GemmN);
return make_tuple(in_gemm_m_k_grid_desc, wei_gemm_k_n_grid_desc, out_gemm_m_n_grid_desc);
return make_tuple(in_gemm_m_k_grid_desc, wei_gemm_k_n_grid_desc, out_gemm_m_n_grid_desc); }
}
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false> static auto GetABCGridDesc()
static auto GetABCGridDesc() {
{ return MakeABCGridDescriptor(1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1});
return MakeABCGridDescriptor(1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}); }
}
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false> static auto GetABCGridDesc()
static auto GetABCGridDesc() {
{ return MakeABCGridDescriptor(
return MakeABCGridDescriptor( 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}); }
}
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false> static auto GetABCGridDesc()
static auto GetABCGridDesc() {
{ return MakeABCGridDescriptor(
return MakeABCGridDescriptor( 1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}); }
}
using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>());
using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>());
using AGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
using AGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I0])>; using BGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using BGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I1])>; using CGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
using CGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
static constexpr auto GetInputBlockDescriptor()
static constexpr auto GetInputBlockDescriptor() {
{ if constexpr(UseALocalBuffer)
if constexpr(UseALocalBuffer) {
{ // return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
// return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock)); return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
return make_naive_tensor_descriptor_packed(make_tuple(0, 0)); }
} else
else {
{ return AGridDesc{};
return AGridDesc{}; }
} }
}
static constexpr auto GetWeightBlockDescriptor()
static constexpr auto GetWeightBlockDescriptor() {
{ if constexpr(UseBLocalBuffer)
if constexpr(UseBLocalBuffer) {
{ // return make_naive_tensor_descriptor_packed(make_tuple(KPerBlock, NPerBlock));
// return make_naive_tensor_descriptor_packed(make_tuple(KPerBlock, NPerBlock)); return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
return make_naive_tensor_descriptor_packed(make_tuple(0, 0)); }
} else
else {
{ return BGridDesc{};
return BGridDesc{}; }
} }
}
static constexpr auto GetOutputBlockDescriptor()
static constexpr auto GetOutputBlockDescriptor() {
{ if constexpr(UseCLocalBuffer)
if constexpr(UseCLocalBuffer) {
{ // return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
// return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock)); return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
return make_naive_tensor_descriptor_packed(make_tuple(0, 0)); }
} else
else {
{ return CGridDesc{};
return CGridDesc{}; }
} }
}
// static constexpr bool UseCLocalBuffer = false;
// static constexpr bool UseCLocalBuffer = false;
using AThreadwiseCopy =
using AThreadwiseCopy = ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC<
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC< InDataType,
InDataType, InDataType,
InDataType, AGridDesc,
AGridDesc, decltype(GetInputBlockDescriptor()),
decltype(GetInputBlockDescriptor()), InElementwiseOperation,
InElementwiseOperation, !UseALocalBuffer,
!UseALocalBuffer, ConvForwardSpecialization>;
ConvForwardSpecialization,
GemmKSpecialization>; using BThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_YXCK<
using BThreadwiseCopy = WeiDataType,
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_YXCK< WeiDataType,
WeiDataType, BGridDesc,
WeiDataType, decltype(GetWeightBlockDescriptor()),
BGridDesc, WeiElementwiseOperation,
decltype(GetWeightBlockDescriptor()), !UseBLocalBuffer,
WeiElementwiseOperation, ConvForwardSpecialization>;
!UseBLocalBuffer,
ConvForwardSpecialization, using CThreadwiseCopy = ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN<
GemmKSpecialization>; OutDataType,
OutDataType,
using CThreadwiseCopy = ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN< CGridDesc,
OutDataType, decltype(GetOutputBlockDescriptor()),
OutDataType, OutElementwiseOperation,
CGridDesc, !UseCLocalBuffer,
decltype(GetOutputBlockDescriptor()), ConvForwardSpecialization>;
OutElementwiseOperation,
!UseCLocalBuffer, using GridwiseGemm =
ConvForwardSpecialization, ck::cpu::GridwiseGemmAvx2_MxN<InDataType, // InDataType,
GemmKSpecialization>; WeiDataType, // WeiDataType,
OutDataType, // OutDataType,
using GridwiseGemm = AGridDesc, // AGridDesc,
ck::cpu::GridwiseGemmAvx2_MxN<InDataType, // InDataType, BGridDesc, // BGridDesc,
WeiDataType, // WeiDataType, CGridDesc, // CGridDesc,
OutDataType, // OutDataType, AElementwiseOperation, // AElementwiseOperation,
AGridDesc, // AGridDesc, BElementwiseOperation, // BElementwiseOperation,
BGridDesc, // BGridDesc, CElementwiseOperation, // CElementwiseOperation,
CGridDesc, // CGridDesc, ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
AElementwiseOperation, // AElementwiseOperation, AThreadwiseCopy, // AThreadwiseCopy
BElementwiseOperation, // BElementwiseOperation, BThreadwiseCopy, // BThreadwiseCopy
CElementwiseOperation, // CElementwiseOperation, CThreadwiseCopy, // CThreadwiseCopy
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch, ck::Sequence<0, 1>, // ThreadMNAccessOrder
AThreadwiseCopy, // AThreadwiseCopy UseALocalBuffer, // UseALocalBuffer
BThreadwiseCopy, // BThreadwiseCopy UseBLocalBuffer, // UseBLocalBuffer
CThreadwiseCopy, // CThreadwiseCopy UseCLocalBuffer // UseCLocalBuffer
ck::Sequence<0, 1>, // ThreadMNAccessOrder >;
UseALocalBuffer, // UseALocalBuffer
UseBLocalBuffer, // UseBLocalBuffer GridwiseGemm gridwise_gemm;
UseCLocalBuffer // UseCLocalBuffer
>; // Argument
struct Argument : public BaseArgument
GridwiseGemm gridwise_gemm; {
Argument(const InDataType* p_in_grid,
// Argument const WeiDataType* p_wei_grid,
struct Argument : public BaseArgument OutDataType* p_out_grid,
{ ck::index_t N,
Argument(const InDataType* p_in_grid, ck::index_t K,
const WeiDataType* p_wei_grid, ck::index_t C,
OutDataType* p_out_grid, std::vector<ck::index_t> input_spatial_lengths,
ck::index_t N, std::vector<ck::index_t> filter_spatial_lengths,
ck::index_t K, std::vector<ck::index_t> output_spatial_lengths,
ck::index_t C, std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> input_spatial_lengths, std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> filter_spatial_lengths, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> output_spatial_lengths, std::vector<ck::index_t> input_right_pads,
std::vector<ck::index_t> conv_filter_strides, InElementwiseOperation in_element_op,
std::vector<ck::index_t> conv_filter_dilations, WeiElementwiseOperation wei_element_op,
std::vector<ck::index_t> input_left_pads, OutElementwiseOperation out_element_op)
std::vector<ck::index_t> input_right_pads, : p_a_grid_{p_in_grid},
InElementwiseOperation in_element_op, p_b_grid_{p_wei_grid},
WeiElementwiseOperation wei_element_op, p_c_grid_{p_out_grid},
OutElementwiseOperation out_element_op) a_grid_desc_{},
: p_a_grid_{p_in_grid}, b_grid_desc_{},
p_b_grid_{p_wei_grid}, c_grid_desc_{},
p_c_grid_{p_out_grid}, a_element_op_{in_element_op},
a_grid_desc_{}, b_element_op_{wei_element_op},
b_grid_desc_{}, c_element_op_{out_element_op},
c_grid_desc_{}, Conv_N_{N},
a_element_op_{in_element_op}, Conv_K_{K},
b_element_op_{wei_element_op}, Conv_C_{C},
c_element_op_{out_element_op}, filter_spatial_lengths_{filter_spatial_lengths},
Conv_N_{N}, conv_filter_strides_{conv_filter_strides},
Conv_K_{K}, input_left_pads_{input_left_pads},
Conv_C_{C}, input_right_pads_{input_right_pads}
filter_spatial_lengths_{filter_spatial_lengths}, {
conv_filter_strides_{conv_filter_strides}, const auto descs = DeviceOp::MakeABCGridDescriptor(N,
input_left_pads_{input_left_pads}, K,
input_right_pads_{input_right_pads} C,
{ input_spatial_lengths,
const auto descs = DeviceOp::MakeABCGridDescriptor(N, filter_spatial_lengths,
K, output_spatial_lengths,
C, conv_filter_strides,
input_spatial_lengths, conv_filter_dilations,
filter_spatial_lengths, input_left_pads,
output_spatial_lengths, input_right_pads);
conv_filter_strides, a_grid_desc_ = descs[I0];
conv_filter_dilations, b_grid_desc_ = descs[I1];
input_left_pads, c_grid_desc_ = descs[I2];
input_right_pads); }
a_grid_desc_ = descs[I0];
b_grid_desc_ = descs[I1]; // private:
c_grid_desc_ = descs[I2]; const ADataType* p_a_grid_;
} const BDataType* p_b_grid_;
CDataType* p_c_grid_;
// private: AGridDesc a_grid_desc_;
const ADataType* p_a_grid_; BGridDesc b_grid_desc_;
const BDataType* p_b_grid_; CGridDesc c_grid_desc_;
CDataType* p_c_grid_;
AGridDesc a_grid_desc_; AElementwiseOperation a_element_op_;
BGridDesc b_grid_desc_; BElementwiseOperation b_element_op_;
CGridDesc c_grid_desc_; CElementwiseOperation c_element_op_;
// for checking IsSupportedArgument()
AElementwiseOperation a_element_op_; index_t Conv_N_;
BElementwiseOperation b_element_op_; index_t Conv_K_;
CElementwiseOperation c_element_op_; index_t Conv_C_;
// for checking IsSupportedArgument() std::vector<index_t> filter_spatial_lengths_;
index_t Conv_N_; std::vector<index_t> conv_filter_strides_;
index_t Conv_K_; std::vector<index_t> input_left_pads_;
index_t Conv_C_; std::vector<index_t> input_right_pads_;
std::vector<index_t> filter_spatial_lengths_; };
std::vector<index_t> conv_filter_strides_;
std::vector<index_t> input_left_pads_; // Invoker
std::vector<index_t> input_right_pads_; struct Invoker : public BaseInvoker
}; {
using Argument = DeviceOp::Argument;
// Invoker GridwiseGemm gridwise_gemm;
struct Invoker : public BaseInvoker
{ Invoker(const GridwiseGemm& gridwise_gemm_) : gridwise_gemm(gridwise_gemm_) {}
using Argument = DeviceOp::Argument;
GridwiseGemm gridwise_gemm; float Run(const Argument& arg,
const StreamConfig& stream_config = StreamConfig{},
Invoker(const GridwiseGemm& gridwise_gemm_) : gridwise_gemm(gridwise_gemm_) {} int nrepeat = 1)
{
float Run(const Argument& arg, if(!gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
const StreamConfig& stream_config = StreamConfig{}, {
int nrepeat = 1) throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
{ }
if(!gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
{ memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
} const auto kernel = ck::cpu::kernel_gemm_avx_mxn<GridwiseGemm,
InDataType,
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize()); WeiDataType,
OutDataType,
const auto kernel = ck::cpu::kernel_gemm_avx_mxn<GridwiseGemm, AGridDesc,
InDataType, BGridDesc,
WeiDataType, CGridDesc,
OutDataType, AElementwiseOperation,
AGridDesc, BElementwiseOperation,
BGridDesc, CElementwiseOperation>;
CGridDesc,
AElementwiseOperation, float ave_time = 0;
BElementwiseOperation,
CElementwiseOperation>; if(nrepeat != 1)
ave_time = launch_and_time_cpu_kernel(kernel,
float ave_time = 0; nrepeat,
gridwise_gemm,
if(nrepeat != 1) arg.p_a_grid_,
ave_time = launch_and_time_cpu_kernel(kernel, arg.p_b_grid_,
nrepeat, arg.p_c_grid_,
gridwise_gemm, arg.a_grid_desc_,
arg.p_a_grid_, arg.b_grid_desc_,
arg.p_b_grid_, arg.c_grid_desc_,
arg.p_c_grid_, arg.a_element_op_,
arg.a_grid_desc_, arg.b_element_op_,
arg.b_grid_desc_, arg.c_element_op_);
arg.c_grid_desc_,
arg.a_element_op_, // TODO: this is for benchmark purpose, so last time we clear c buffer and calculate the
arg.b_element_op_, // result
arg.c_element_op_); memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
// TODO: this is for benchmark purpose, so last time we clear c buffer and calculate the launch_cpu_kernel(kernel,
// result gridwise_gemm,
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize()); arg.p_a_grid_,
arg.p_b_grid_,
launch_cpu_kernel(kernel, arg.p_c_grid_,
gridwise_gemm, arg.a_grid_desc_,
arg.p_a_grid_, arg.b_grid_desc_,
arg.p_b_grid_, arg.c_grid_desc_,
arg.p_c_grid_, arg.a_element_op_,
arg.a_grid_desc_, arg.b_element_op_,
arg.b_grid_desc_, arg.c_element_op_);
arg.c_grid_desc_,
arg.a_element_op_, return ave_time;
arg.b_element_op_, }
arg.c_element_op_);
float Run(const BaseArgument* p_arg,
return ave_time; const StreamConfig& stream_config = StreamConfig{},
} int nrepeat = 1) override
{
float Run(const BaseArgument* p_arg, return Run(*dynamic_cast<const Argument*>(p_arg), stream_config, nrepeat);
const StreamConfig& stream_config = StreamConfig{}, }
int nrepeat = 1) override };
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config, nrepeat); static constexpr bool IsValidCompilationParameter()
} {
}; // TODO: properly implement this check
return true;
static constexpr bool IsValidCompilationParameter() }
{
// TODO: properly implement this check bool IsSupportedArgument(const Argument& arg)
return true; {
} if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
bool IsSupportedArgument(const Argument& arg) {
{ // check if it's 1x1, stride=1 conv
if constexpr(ConvForwardSpecialization == if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
{ arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
// check if it's 1x1, stride=1 conv arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && {
arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 && return false;
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && }
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) }
{ else if constexpr(ConvForwardSpecialization ==
return false; ConvolutionForwardSpecialization_t::Filter1x1Pad0)
} {
} // check if it's 1x1 conv
else if constexpr(ConvForwardSpecialization == if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
ConvolutionForwardSpecialization_t::Filter1x1Pad0) arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
{ arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
// check if it's 1x1 conv {
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && return false;
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && }
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) }
{
return false; if(gridwise_gemm.dynamic_tunable.gemm_k_spec ==
} ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
} ConvForwardSpecialization != ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
if constexpr(GemmKSpecialization == if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0))
ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC && return false;
ConvForwardSpecialization != }
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{ if constexpr(!UseALocalBuffer &&
if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0)) ConvForwardSpecialization !=
return false; ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
} {
// TODO: We can support this in the future, as long as figure out how to express tensor
if constexpr(!UseALocalBuffer && // transform
ConvForwardSpecialization != return false;
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) }
{
// TODO: We can support this in the future, as long as figure out how to express tensor if constexpr(!UseBLocalBuffer)
// transform {
return false; if(!(arg.Conv_K_ % 8 == 0))
} return false;
}
if constexpr(!UseBLocalBuffer)
{ // Gridwise GEMM size
if(!(arg.Conv_K_ % 8 == 0)) return gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
return false; }
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
// Gridwise GEMM size {
return gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
} }
bool IsSupportedArgument(const BaseArgument* p_arg) override static auto MakeArgument(const InDataType* p_in_grid,
{ const WeiDataType* p_wei_grid,
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); OutDataType* p_out_grid,
} ck::index_t N,
ck::index_t K,
static auto MakeArgument(const InDataType* p_in_grid, ck::index_t C,
const WeiDataType* p_wei_grid, std::vector<ck::index_t> input_spatial_lengths,
OutDataType* p_out_grid, std::vector<ck::index_t> filter_spatial_lengths,
ck::index_t N, std::vector<ck::index_t> output_spatial_lengths,
ck::index_t K, std::vector<ck::index_t> conv_filter_strides,
ck::index_t C, std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_spatial_lengths, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> filter_spatial_lengths, std::vector<ck::index_t> input_right_pads,
std::vector<ck::index_t> output_spatial_lengths, InElementwiseOperation in_element_op,
std::vector<ck::index_t> conv_filter_strides, WeiElementwiseOperation wei_element_op,
std::vector<ck::index_t> conv_filter_dilations, OutElementwiseOperation out_element_op)
std::vector<ck::index_t> input_left_pads, {
std::vector<ck::index_t> input_right_pads, return Argument{p_in_grid,
InElementwiseOperation in_element_op, p_wei_grid,
WeiElementwiseOperation wei_element_op, p_out_grid,
OutElementwiseOperation out_element_op) N,
{ K,
return Argument{p_in_grid, C,
p_wei_grid, input_spatial_lengths,
p_out_grid, filter_spatial_lengths,
N, output_spatial_lengths,
K, conv_filter_strides,
C, conv_filter_dilations,
input_spatial_lengths, input_left_pads,
filter_spatial_lengths, input_right_pads,
output_spatial_lengths, in_element_op,
conv_filter_strides, wei_element_op,
conv_filter_dilations, out_element_op};
input_left_pads, }
input_right_pads,
in_element_op, auto MakeInvoker() { return Invoker{gridwise_gemm}; }
wei_element_op,
out_element_op}; std::unique_ptr<BaseArgument>
} MakeArgumentPointer(const void* p_in_grid,
const void* p_wei_grid,
auto MakeInvoker() { return Invoker{gridwise_gemm}; } void* p_out_grid,
ck::index_t N,
std::unique_ptr<BaseArgument> ck::index_t K,
MakeArgumentPointer(const void* p_in_grid, ck::index_t C,
const void* p_wei_grid, std::vector<ck::index_t> input_spatial_lengths,
void* p_out_grid, std::vector<ck::index_t> filter_spatial_lengths,
ck::index_t N, std::vector<ck::index_t> output_spatial_lengths,
ck::index_t K, std::vector<ck::index_t> conv_filter_strides,
ck::index_t C, std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_spatial_lengths, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> filter_spatial_lengths, std::vector<ck::index_t> input_right_pads,
std::vector<ck::index_t> output_spatial_lengths, InElementwiseOperation in_element_op,
std::vector<ck::index_t> conv_filter_strides, WeiElementwiseOperation wei_element_op,
std::vector<ck::index_t> conv_filter_dilations, OutElementwiseOperation out_element_op) override
std::vector<ck::index_t> input_left_pads, {
std::vector<ck::index_t> input_right_pads, return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
InElementwiseOperation in_element_op, static_cast<const WeiDataType*>(p_wei_grid),
WeiElementwiseOperation wei_element_op, static_cast<OutDataType*>(p_out_grid),
OutElementwiseOperation out_element_op) override N,
{ K,
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid), C,
static_cast<const WeiDataType*>(p_wei_grid), input_spatial_lengths,
static_cast<OutDataType*>(p_out_grid), filter_spatial_lengths,
N, output_spatial_lengths,
K, conv_filter_strides,
C, conv_filter_dilations,
input_spatial_lengths, input_left_pads,
filter_spatial_lengths, input_right_pads,
output_spatial_lengths, in_element_op,
conv_filter_strides, wei_element_op,
conv_filter_dilations, out_element_op);
input_left_pads, }
input_right_pads,
in_element_op, std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
wei_element_op, {
out_element_op); return std::make_unique<Invoker>(Invoker{gridwise_gemm});
} }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::string GetTypeString() const override
{ {
return std::make_unique<Invoker>(Invoker{gridwise_gemm}); auto str = std::stringstream();
} auto string_local_buffer = [](bool is_local_buffer) {
if(is_local_buffer)
std::string GetTypeString() const override return "L";
{ else
auto str = std::stringstream(); return "G";
auto string_local_buffer = [](bool is_local_buffer) { };
if(is_local_buffer) // clang-format off
return "L"; str << "DeviceConv" << std::to_string(NumDimSpatial)
else << "DFwdAvx2_NHWC_YXCK"
return "G"; <<"_FS"<< static_cast<int>(ConvForwardSpecialization)
}; <<"_KS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.gemm_k_spec)
// clang-format off <<"_BS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.loop_over_spec)
str << "DeviceConv" << std::to_string(NumDimSpatial) << "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block
<< "DFwdAvx2_NHWC_YXCK" << "_A" << string_local_buffer(UseALocalBuffer)
<<"_FS"<< static_cast<int>(ConvForwardSpecialization) << "_B" << string_local_buffer(UseBLocalBuffer)
<<"_KS"<< static_cast<int>(GemmKSpecialization) << "_C" << string_local_buffer(UseCLocalBuffer)
<<"_BS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.loop_over_spec) ;
<< "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block if constexpr (!std::is_same<OutElementwiseOperation,
<< "_A" << string_local_buffer(UseALocalBuffer) ck::tensor_operation::cpu::element_wise::PassThrough>::value)
<< "_B" << string_local_buffer(UseBLocalBuffer) {
<< "_C" << string_local_buffer(UseCLocalBuffer) str << "_" << OutElementwiseOperation::Name();
; }
if constexpr (!std::is_same<OutElementwiseOperation, // clang-format on
ck::tensor_operation::cpu::element_wise::PassThrough>::value)
{ return str.str();
str << "_" << OutElementwiseOperation::Name(); }
} };
// clang-format on
} // namespace device
return str.str(); } // namespace cpu
} } // namespace tensor_operation
}; } // namespace ck
} // namespace device #endif
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
#endif
...@@ -31,7 +31,6 @@ template <typename InDataType, ...@@ -31,7 +31,6 @@ template <typename InDataType,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization, ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
ck::index_t NumDimSpatial, ck::index_t NumDimSpatial,
ck::index_t MPerThread, ck::index_t MPerThread,
ck::index_t NPerThread, ck::index_t NPerThread,
...@@ -596,8 +595,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu ...@@ -596,8 +595,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
decltype(GetInputBlockDescriptor()), decltype(GetInputBlockDescriptor()),
InElementwiseOperation, InElementwiseOperation,
!UseALocalBuffer, !UseALocalBuffer,
ConvForwardSpecialization, ConvForwardSpecialization>;
GemmKSpecialization>;
using BThreadwiseCopy = using BThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXC< ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXC<
...@@ -607,8 +605,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu ...@@ -607,8 +605,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
decltype(GetWeightBlockDescriptor()), decltype(GetWeightBlockDescriptor()),
WeiElementwiseOperation, WeiElementwiseOperation,
!UseBLocalBuffer, !UseBLocalBuffer,
ConvForwardSpecialization, ConvForwardSpecialization>;
GemmKSpecialization>;
using CThreadwiseCopy = using CThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN< ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN<
...@@ -855,10 +852,9 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu ...@@ -855,10 +852,9 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
} }
} }
if constexpr(GemmKSpecialization == if(gridwise_gemm.dynamic_tunable.gemm_k_spec ==
ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC && ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
ConvForwardSpecialization != ConvForwardSpecialization != ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{ {
if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0)) if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0))
return false; return false;
...@@ -981,7 +977,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu ...@@ -981,7 +977,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
str << "DeviceConv" << std::to_string(NumDimSpatial) str << "DeviceConv" << std::to_string(NumDimSpatial)
<< "DFwd_BAA_Avx2_NHWC_KYXC" << "DFwd_BAA_Avx2_NHWC_KYXC"
<<"_FS"<< static_cast<int>(ConvForwardSpecialization) <<"_FS"<< static_cast<int>(ConvForwardSpecialization)
<<"_KS"<< static_cast<int>(GemmKSpecialization) <<"_KS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.gemm_k_spec)
<<"_BS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.loop_over_spec) <<"_BS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.loop_over_spec)
<< "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block << "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block
<< "_TT" << MPerThread << "x" << NPerThread << "_TT" << MPerThread << "x" << NPerThread
......
...@@ -31,7 +31,6 @@ template <typename InDataType, ...@@ -31,7 +31,6 @@ template <typename InDataType,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization, ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
ck::index_t NumDimSpatial, ck::index_t NumDimSpatial,
ck::index_t MPerThread, ck::index_t MPerThread,
ck::index_t NPerThread, ck::index_t NPerThread,
...@@ -573,8 +572,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou ...@@ -573,8 +572,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
decltype(GetInputBlockDescriptor()), decltype(GetInputBlockDescriptor()),
InElementwiseOperation, InElementwiseOperation,
!UseALocalBuffer, !UseALocalBuffer,
ConvForwardSpecialization, ConvForwardSpecialization>;
GemmKSpecialization>;
using BThreadwiseCopy = using BThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8< ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8<
...@@ -584,8 +582,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou ...@@ -584,8 +582,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
decltype(GetWeightBlockDescriptor()), decltype(GetWeightBlockDescriptor()),
WeiElementwiseOperation, WeiElementwiseOperation,
!UseBLocalBuffer, !UseBLocalBuffer,
ConvForwardSpecialization, ConvForwardSpecialization>;
GemmKSpecialization>;
using CThreadwiseCopy = using CThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN< ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN<
...@@ -832,10 +829,9 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou ...@@ -832,10 +829,9 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
} }
} }
if constexpr(GemmKSpecialization == if(gridwise_gemm.dynamic_tunable.gemm_k_spec ==
ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC && ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
ConvForwardSpecialization != ConvForwardSpecialization != ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{ {
if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0)) if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0))
return false; return false;
...@@ -961,7 +957,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou ...@@ -961,7 +957,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
str << "DeviceConv" << std::to_string(NumDimSpatial) str << "DeviceConv" << std::to_string(NumDimSpatial)
<< "DFwd_BAA_Avx2_NHWC_KYXCK8" << "DFwd_BAA_Avx2_NHWC_KYXCK8"
<<"_FS"<< static_cast<int>(ConvForwardSpecialization) <<"_FS"<< static_cast<int>(ConvForwardSpecialization)
<<"_KS"<< static_cast<int>(GemmKSpecialization) <<"_KS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.gemm_k_spec)
<<"_BS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.loop_over_spec) <<"_BS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.loop_over_spec)
<< "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block << "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block
<< "_TT" << MPerThread << "x" << NPerThread << "_TT" << MPerThread << "x" << NPerThread
......
#ifndef DEVICE_CONV2D_FWD_BIAS_ACTIVATION_ADD_AVX2_NHWC_YXCK_NHWK_HPP #ifndef DEVICE_CONV2D_FWD_BIAS_ACTIVATION_ADD_AVX2_NHWC_YXCK_NHWK_HPP
#define DEVICE_CONV2D_FWD_BIAS_ACTIVATION_ADD_AVX2_NHWC_YXCK_NHWK_HPP #define DEVICE_CONV2D_FWD_BIAS_ACTIVATION_ADD_AVX2_NHWC_YXCK_NHWK_HPP
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <numeric> #include <numeric>
#include "device.hpp" #include "device.hpp"
#include "device_base_cpu.hpp" #include "device_base_cpu.hpp"
#include "device_conv_fwd_cpu.hpp" #include "device_conv_fwd_cpu.hpp"
#include "convolution_forward_specialization_cpu.hpp" #include "convolution_forward_specialization_cpu.hpp"
#include "common_header.hpp" #include "common_header.hpp"
#include "../../gpu/device/tensor_layout.hpp" #include "../../gpu/device/tensor_layout.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_bias_activation_add_avx2.hpp" #include "gridwise_gemm_bias_activation_add_avx2.hpp"
#include "threadwise_gemm_avx2.hpp" #include "threadwise_gemm_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp" #include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace cpu { namespace cpu {
namespace device { namespace device {
template <typename InDataType, template <typename InDataType,
typename WeiDataType, typename WeiDataType,
typename OutDataType, typename OutDataType,
typename BiasDataType, typename BiasDataType,
typename AddDataType, typename AddDataType,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization, ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization, ck::index_t NumDimSpatial,
ck::index_t NumDimSpatial, ck::index_t MPerThread,
ck::index_t MPerThread, ck::index_t NPerThread,
ck::index_t NPerThread, bool UseALocalBuffer,
bool UseALocalBuffer, bool UseBLocalBuffer,
bool UseBLocalBuffer, bool UseCLocalBuffer,
bool UseCLocalBuffer, bool BiasAlongGemmM>
bool BiasAlongGemmM> struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K : public DeviceConvFwdBiasActivationAdd<InElementwiseOperation,
: public DeviceConvFwdBiasActivationAdd<InElementwiseOperation, WeiElementwiseOperation,
WeiElementwiseOperation, OutElementwiseOperation>
OutElementwiseOperation> {
{ using DeviceOp =
using DeviceOp = DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K;
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K;
using ADataType = InDataType;
using ADataType = InDataType; using BDataType = WeiDataType;
using BDataType = WeiDataType; using CDataType = OutDataType;
using CDataType = OutDataType; using C0DataType = BiasDataType;
using C0DataType = BiasDataType; using C1DataType = AddDataType;
using C1DataType = AddDataType;
using AElementwiseOperation = InElementwiseOperation;
using AElementwiseOperation = InElementwiseOperation; using BElementwiseOperation = WeiElementwiseOperation;
using BElementwiseOperation = WeiElementwiseOperation; using CElementwiseOperation = OutElementwiseOperation;
using CElementwiseOperation = OutElementwiseOperation;
// TODO make A/B datatype different
// TODO make A/B datatype different using ABDataType = InDataType;
using ABDataType = InDataType;
static constexpr index_t NDimSpatial = NumDimSpatial;
static constexpr index_t NDimSpatial = NumDimSpatial;
static constexpr auto I0 = Number<0>{};
static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{};
static constexpr auto I3 = Number<3>{};
static constexpr bool NonTemporalStore = false;
static constexpr bool NonTemporalStore = false;
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K(
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K( const DeviceConvFwdDynamicTunable& dtune)
const DeviceConvFwdDynamicTunable& dtune) : gridwise_gemm(dtune)
: gridwise_gemm(dtune) {
{ }
}
static constexpr auto GetThreadwiseGemm_Dispatch()
static constexpr auto GetThreadwiseGemm_Dispatch() {
{ if constexpr(MPerThread == 4 && NPerThread == 24)
if constexpr(MPerThread == 4 && NPerThread == 24) {
{ return ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<InDataType,
return ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<InDataType, WeiDataType,
WeiDataType, OutDataType,
OutDataType, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor, NonTemporalStore>{};
NonTemporalStore>{}; }
} else if constexpr(MPerThread == 6 && NPerThread == 16)
else if constexpr(MPerThread == 6 && NPerThread == 16) {
{ return ck::cpu::ThreadwiseGemmAvx2_MxN_6x16_Dispatch<InDataType,
return ck::cpu::ThreadwiseGemmAvx2_MxN_6x16_Dispatch<InDataType, WeiDataType,
WeiDataType, OutDataType,
OutDataType, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor, NonTemporalStore>{};
NonTemporalStore>{}; }
} else
else {
{ // static_assert(false, "invalid Mr/Nr");
// static_assert(false, "invalid Mr/Nr"); }
} }
}
using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());
using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());
static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n) {
{ return make_naive_tensor_descriptor_packed(make_tuple(gemm_k, gemm_n));
return make_naive_tensor_descriptor_packed(make_tuple(gemm_k, gemm_n)); }
}
static auto GetOutputTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n)
static auto GetOutputTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n) {
{ const auto out_gemm_m_n_grid_desc =
const auto out_gemm_m_n_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_n));
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_n));
return out_gemm_m_n_grid_desc;
return out_gemm_m_n_grid_desc; }
}
static auto MakeBiasTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n)
static auto MakeBiasTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n) {
{ if constexpr(BiasAlongGemmM)
if constexpr(BiasAlongGemmM) {
{ return make_naive_tensor_descriptor_packed(make_tuple(gemm_m));
return make_naive_tensor_descriptor_packed(make_tuple(gemm_m)); }
} else
else {
{ return make_naive_tensor_descriptor_packed(make_tuple(gemm_n));
return make_naive_tensor_descriptor_packed(make_tuple(gemm_n)); }
} }
}
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false> static auto GetInputTensorDescriptor(ck::index_t N,
static auto GetInputTensorDescriptor(ck::index_t N, ck::index_t C,
ck::index_t C, ck::index_t gemm_m,
ck::index_t gemm_m, ck::index_t gemm_k,
ck::index_t gemm_k, const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& input_spatial_lengths, const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths, const std::vector<ck::index_t>& output_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths, const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_strides, const std::vector<ck::index_t>& conv_filter_dilations,
const std::vector<ck::index_t>& conv_filter_dilations, const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_left_pads, const std::vector<ck::index_t>& input_right_pads)
const std::vector<ck::index_t>& input_right_pads) {
{ const index_t Wi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[0]; const index_t Wo = output_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[0]; const index_t ConvStrideW = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[0];
if constexpr(ConvForwardSpecialization ==
if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) {
{ const auto in_gemm_m_k_grid_desc =
const auto in_gemm_m_k_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
return in_gemm_m_k_grid_desc;
return in_gemm_m_k_grid_desc; }
} else if constexpr(ConvForwardSpecialization ==
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::Filter1x1Pad0)
ConvolutionForwardSpecialization_t::Filter1x1Pad0) {
{ const auto in_n_wi_c_grid_desc =
const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto in_n_wo_c_grid_desc = transform_tensor_descriptor(
const auto in_n_wo_c_grid_desc = transform_tensor_descriptor( in_n_wi_c_grid_desc,
in_n_wi_c_grid_desc, make_tuple(make_pass_through_transform(N),
make_tuple(make_pass_through_transform(N), make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), make_pass_through_transform(C)),
make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor( in_n_wo_c_grid_desc,
in_n_wo_c_grid_desc, make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)),
make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)), make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
return in_gemm_m_k_grid_desc; }
} else
else {
{ const index_t X = filter_spatial_lengths[0];
const index_t X = filter_spatial_lengths[0]; const index_t ConvDilationW = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[0]; const index_t InLeftPadW = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[0]; const index_t InRightPadW = input_right_pads[0];
const index_t InRightPadW = input_right_pads[0];
const auto in_n_wi_c_grid_desc =
const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( in_n_wi_c_grid_desc,
in_n_wi_c_grid_desc, make_tuple(make_pass_through_transform(N),
make_tuple(make_pass_through_transform(N), make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pad_transform(Wi, InLeftPadW, InRightPadW), make_pass_through_transform(C)),
make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor( in_n_wip_c_grid_desc,
in_n_wip_c_grid_desc, make_tuple(
make_tuple( make_pass_through_transform(N),
make_pass_through_transform(N), make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), make_pass_through_transform(C)),
make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto in_gemm_m_k_grid_desc =
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(in_n_x_wo_c_grid_desc,
transform_tensor_descriptor(in_n_x_wo_c_grid_desc, make_tuple(make_merge_transform(make_tuple(N, Wo)),
make_tuple(make_merge_transform(make_tuple(N, Wo)), make_merge_transform(make_tuple(X, C))),
make_merge_transform(make_tuple(X, C))), make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
return in_gemm_m_k_grid_desc; }
} }
}
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false> static auto GetInputTensorDescriptor(ck::index_t N,
static auto GetInputTensorDescriptor(ck::index_t N, ck::index_t C,
ck::index_t C, ck::index_t gemm_m,
ck::index_t gemm_m, ck::index_t gemm_k,
ck::index_t gemm_k, const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& input_spatial_lengths, const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths, const std::vector<ck::index_t>& output_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths, const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_strides, const std::vector<ck::index_t>& conv_filter_dilations,
const std::vector<ck::index_t>& conv_filter_dilations, const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_left_pads, const std::vector<ck::index_t>& input_right_pads)
const std::vector<ck::index_t>& input_right_pads) {
{ const index_t Hi = input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[0]; const index_t Wi = input_spatial_lengths[1];
const index_t Wi = input_spatial_lengths[1];
const index_t Ho = output_spatial_lengths[0];
const index_t Ho = output_spatial_lengths[0]; const index_t Wo = output_spatial_lengths[1];
const index_t Wo = output_spatial_lengths[1];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideH = conv_filter_strides[0]; const index_t ConvStrideW = conv_filter_strides[1];
const index_t ConvStrideW = conv_filter_strides[1];
if constexpr(ConvForwardSpecialization ==
if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) {
{ const auto in_gemm_m_k_grid_desc =
const auto in_gemm_m_k_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
return in_gemm_m_k_grid_desc;
return in_gemm_m_k_grid_desc; }
} else if constexpr(ConvForwardSpecialization ==
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::Filter1x1Pad0)
ConvolutionForwardSpecialization_t::Filter1x1Pad0) {
{ const auto in_n_hi_wi_c_grid_desc =
const auto in_n_hi_wi_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( in_n_hi_wi_c_grid_desc,
in_n_hi_wi_c_grid_desc, make_tuple(make_pass_through_transform(N),
make_tuple(make_pass_through_transform(N), make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), make_pass_through_transform(C)),
make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemm_m_k_grid_desc =
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(in_n_ho_wo_c_grid_desc,
transform_tensor_descriptor(in_n_ho_wo_c_grid_desc, make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), make_pass_through_transform(C)),
make_pass_through_transform(C)), make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
return in_gemm_m_k_grid_desc; }
} else
else {
{ const index_t Y = filter_spatial_lengths[0];
const index_t Y = filter_spatial_lengths[0]; const index_t X = filter_spatial_lengths[1];
const index_t X = filter_spatial_lengths[1];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationH = conv_filter_dilations[0]; const index_t ConvDilationW = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[0]; const index_t InLeftPadW = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadH = input_right_pads[0]; const index_t InRightPadW = input_right_pads[1];
const index_t InRightPadW = input_right_pads[1];
const auto in_n_hi_wi_c_grid_desc =
const auto in_n_hi_wi_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( in_n_hi_wi_c_grid_desc,
in_n_hi_wi_c_grid_desc, make_tuple(make_pass_through_transform(N),
make_tuple(make_pass_through_transform(N), make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Hi, InLeftPadH, InRightPadH), make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pad_transform(Wi, InLeftPadW, InRightPadW), make_pass_through_transform(C)),
make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( in_n_hip_wip_c_grid_desc,
in_n_hip_wip_c_grid_desc, make_tuple(
make_tuple( make_pass_through_transform(N),
make_pass_through_transform(N), make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), make_pass_through_transform(C)),
make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemm_m_k_grid_desc =
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), make_merge_transform(make_tuple(Y, X, C))),
make_merge_transform(make_tuple(Y, X, C))), make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), make_tuple(Sequence<0>{}, Sequence<1>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
return in_gemm_m_k_grid_desc; }
} }
}
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false> static auto GetInputTensorDescriptor(ck::index_t N,
static auto GetInputTensorDescriptor(ck::index_t N, ck::index_t C,
ck::index_t C, ck::index_t gemm_m,
ck::index_t gemm_m, ck::index_t gemm_k,
ck::index_t gemm_k, ck::index_t gemm_m_pad,
ck::index_t gemm_m_pad, const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& input_spatial_lengths, const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths, const std::vector<ck::index_t>& output_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths, const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_strides, const std::vector<ck::index_t>& conv_filter_dilations,
const std::vector<ck::index_t>& conv_filter_dilations, const std::vector<ck::index_t>& input_left_pads,
const std::vector<ck::index_t>& input_left_pads, const std::vector<ck::index_t>& input_right_pads)
const std::vector<ck::index_t>& input_right_pads) {
{ const index_t Di = input_spatial_lengths[0];
const index_t Di = input_spatial_lengths[0]; const index_t Hi = input_spatial_lengths[1];
const index_t Hi = input_spatial_lengths[1]; const index_t Wi = input_spatial_lengths[2];
const index_t Wi = input_spatial_lengths[2];
const index_t Do = output_spatial_lengths[0];
const index_t Do = output_spatial_lengths[0]; const index_t Ho = output_spatial_lengths[1];
const index_t Ho = output_spatial_lengths[1]; const index_t Wo = output_spatial_lengths[2];
const index_t Wo = output_spatial_lengths[2];
const index_t ConvStrideD = conv_filter_strides[0];
const index_t ConvStrideD = conv_filter_strides[0]; const index_t ConvStrideH = conv_filter_strides[1];
const index_t ConvStrideH = conv_filter_strides[1]; const index_t ConvStrideW = conv_filter_strides[2];
const index_t ConvStrideW = conv_filter_strides[2];
if constexpr(ConvForwardSpecialization ==
if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) {
{ const auto in_gemm_m_k_grid_desc =
const auto in_gemm_m_k_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
return in_gemm_m_k_grid_desc;
return in_gemm_m_k_grid_desc; }
} else if constexpr(ConvForwardSpecialization ==
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::Filter1x1Pad0)
ConvolutionForwardSpecialization_t::Filter1x1Pad0) {
{ const auto in_n_di_hi_wi_c_grid_desc =
const auto in_n_di_hi_wi_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor(
const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor( in_n_di_hi_wi_c_grid_desc,
in_n_di_hi_wi_c_grid_desc, make_tuple(make_pass_through_transform(N),
make_tuple(make_pass_through_transform(N), make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)), make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), make_pass_through_transform(C)),
make_pass_through_transform(C)), make_tuple(
make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), make_tuple(
make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor( in_n_do_ho_wo_c_grid_desc,
in_n_do_ho_wo_c_grid_desc, make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), make_pass_through_transform(C)),
make_pass_through_transform(C)), make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), make_tuple(Sequence<0>{}, Sequence<1>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
return in_gemm_m_k_grid_desc; }
} else
else {
{ const index_t Z = filter_spatial_lengths[0];
const index_t Z = filter_spatial_lengths[0]; const index_t Y = filter_spatial_lengths[1];
const index_t Y = filter_spatial_lengths[1]; const index_t X = filter_spatial_lengths[2];
const index_t X = filter_spatial_lengths[2];
const index_t ConvDilationD = conv_filter_dilations[0];
const index_t ConvDilationD = conv_filter_dilations[0]; const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationH = conv_filter_dilations[1]; const index_t ConvDilationW = conv_filter_dilations[2];
const index_t ConvDilationW = conv_filter_dilations[2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadD = input_left_pads[0]; const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadH = input_left_pads[1]; const index_t InLeftPadW = input_left_pads[2];
const index_t InLeftPadW = input_left_pads[2];
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadD = input_right_pads[0]; const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadH = input_right_pads[1]; const index_t InRightPadW = input_right_pads[2];
const index_t InRightPadW = input_right_pads[2];
const auto in_n_di_hi_wi_c_grid_desc =
const auto in_n_di_hi_wi_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( in_n_di_hi_wi_c_grid_desc,
in_n_di_hi_wi_c_grid_desc, make_tuple(make_pass_through_transform(N),
make_tuple(make_pass_through_transform(N), make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Di, InLeftPadD, InRightPadD), make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Hi, InLeftPadH, InRightPadH), make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pad_transform(Wi, InLeftPadW, InRightPadW), make_pass_through_transform(C)),
make_pass_through_transform(C)), make_tuple(
make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), make_tuple(
make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( in_n_hip_wip_c_grid_desc,
in_n_hip_wip_c_grid_desc, make_tuple(
make_tuple( make_pass_through_transform(N),
make_pass_through_transform(N), make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), make_pass_through_transform(C)),
make_pass_through_transform(C)), make_tuple(
make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), make_tuple(Sequence<0>{},
make_tuple(Sequence<0>{}, Sequence<1, 2>{},
Sequence<1, 2>{}, Sequence<3, 4>{},
Sequence<3, 4>{}, Sequence<5, 6>{},
Sequence<5, 6>{}, Sequence<7>{}));
Sequence<7>{}));
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor( in_n_z_do_y_ho_x_wo_c_grid_desc,
in_n_z_do_y_ho_x_wo_c_grid_desc, make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), make_merge_transform(make_tuple(Z, Y, X, C))),
make_merge_transform(make_tuple(Z, Y, X, C))), make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}), make_tuple(Sequence<0>{}, Sequence<1>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
return in_gemm_m_k_grid_desc; }
} }
}
// GEMM M dimension: batch size N times the product of all output spatial lengths.
// An empty length list yields N (the product starts at 1).
static index_t GetGemmM(ck::index_t N, const std::vector<ck::index_t>& output_spatial_lengths)
{
    // Accumulator is a plain int, matching the literal-1 seed of the std::accumulate form.
    int spatial_volume = 1;
    for(const ck::index_t length : output_spatial_lengths)
    {
        spatial_volume = std::multiplies<ck::index_t>()(spatial_volume, length);
    }
    return N * spatial_volume;
}
// GEMM K dimension: channel count C times the product of all filter spatial lengths.
// An empty length list yields C (the product starts at 1).
static index_t GetGemmK(ck::index_t C, const std::vector<ck::index_t>& filter_spatial_lengths)
{
    // Accumulator is a plain int, matching the literal-1 seed of the std::accumulate form.
    int filter_volume = 1;
    for(const ck::index_t length : filter_spatial_lengths)
    {
        filter_volume = std::multiplies<ck::index_t>()(filter_volume, length);
    }
    return C * filter_volume;
}
// GEMM N dimension: currently exactly the output channel count K.
static index_t GetGemmN(ck::index_t K)
{
    // Rounding K up to a multiple of the B-matrix minimum vector size is disabled:
    // return ck::math::integer_least_multiple(K,
    //                                         ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
    const index_t gemm_n = K;
    return gemm_n;
}
static auto MakeABCGridDescriptor(ck::index_t N,
static auto MakeABCGridDescriptor(ck::index_t N, ck::index_t K,
ck::index_t K, ck::index_t C,
ck::index_t C, std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> input_spatial_lengths, std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths, std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths, std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_right_pads)
std::vector<ck::index_t> input_right_pads) {
{ using namespace ck;
using namespace ck;
const index_t GemmM = GetGemmM(N, output_spatial_lengths);
const index_t GemmM = GetGemmM(N, output_spatial_lengths); const index_t GemmN = GetGemmN(K);
const index_t GemmN = GetGemmN(K); const index_t GemmK = GetGemmK(C, filter_spatial_lengths);
const index_t GemmK = GetGemmK(C, filter_spatial_lengths);
// A:
// A: const auto in_gemm_m_k_grid_desc =
const auto in_gemm_m_k_grid_desc = GetInputTensorDescriptor<NumDimSpatial>(N,
GetInputTensorDescriptor<NumDimSpatial>(N, C,
C, GemmM,
GemmM, GemmK,
GemmK, input_spatial_lengths,
input_spatial_lengths, filter_spatial_lengths,
filter_spatial_lengths, output_spatial_lengths,
output_spatial_lengths, conv_filter_strides,
conv_filter_strides, conv_filter_dilations,
conv_filter_dilations, input_left_pads,
input_left_pads, input_right_pads);
input_right_pads); // B:
// B: const auto wei_gemm_k_n_grid_desc = GetWeightTensorDescriptor(GemmK, GemmN);
const auto wei_gemm_k_n_grid_desc = GetWeightTensorDescriptor(GemmK, GemmN); // C:
// C: const auto out_gemm_m_n_grid_desc = GetOutputTensorDescriptor(GemmM, GemmN);
const auto out_gemm_m_n_grid_desc = GetOutputTensorDescriptor(GemmM, GemmN);
return make_tuple(in_gemm_m_k_grid_desc, wei_gemm_k_n_grid_desc, out_gemm_m_n_grid_desc);
return make_tuple(in_gemm_m_k_grid_desc, wei_gemm_k_n_grid_desc, out_gemm_m_n_grid_desc); }
}
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false> static auto GetABCGridDesc()
static auto GetABCGridDesc() {
{ return MakeABCGridDescriptor(1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1});
return MakeABCGridDescriptor(1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}); }
}
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false> static auto GetABCGridDesc()
static auto GetABCGridDesc() {
{ return MakeABCGridDescriptor(
return MakeABCGridDescriptor( 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}); }
}
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false> static auto GetABCGridDesc()
static auto GetABCGridDesc() {
{ return MakeABCGridDescriptor(
return MakeABCGridDescriptor( 1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}); }
}
// Tuple type returned by MakeABCGridDescriptor for this NumDimSpatial.
using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>());

// Element types of that tuple: A (input), B (weight), C (output) grid descriptors.
using AGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
using BGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;

// C0 (bias) descriptor type comes from MakeBiasTensorDescriptor.
using C0GridDesc = remove_cvref_t<decltype(MakeBiasTensorDescriptor(1, 1))>;
// C1 (residual add) shares the output descriptor type.
using C1GridDesc = CGridDesc;
static constexpr auto GetInputBlockDescriptor()
static constexpr auto GetInputBlockDescriptor() {
{ if constexpr(UseALocalBuffer)
if constexpr(UseALocalBuffer) {
{ return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
return make_naive_tensor_descriptor_packed(make_tuple(0, 0)); }
} else
else {
{ return AGridDesc{};
return AGridDesc{}; }
} }
}
static constexpr auto GetWeightBlockDescriptor()
static constexpr auto GetWeightBlockDescriptor() {
{ if constexpr(UseBLocalBuffer)
if constexpr(UseBLocalBuffer) {
{ return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
return make_naive_tensor_descriptor_packed(make_tuple(0, 0)); }
} else
else {
{ return BGridDesc{};
return BGridDesc{}; }
} }
}
static constexpr auto GetOutputBlockDescriptor()
static constexpr auto GetOutputBlockDescriptor() {
{ if constexpr(UseCLocalBuffer)
if constexpr(UseCLocalBuffer) {
{ return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
return make_naive_tensor_descriptor_packed(make_tuple(0, 0)); }
} else
else {
{ return CGridDesc{};
return CGridDesc{}; }
} }
}
// static constexpr bool UseCLocalBuffer = false;
// static constexpr bool UseCLocalBuffer = false;
using AThreadwiseCopy =
using AThreadwiseCopy = ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC<
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC< ADataType,
ADataType, ADataType,
ADataType, AGridDesc,
AGridDesc, decltype(GetInputBlockDescriptor()),
decltype(GetInputBlockDescriptor()), InElementwiseOperation,
InElementwiseOperation, !UseALocalBuffer,
!UseALocalBuffer, ConvForwardSpecialization>;
ConvForwardSpecialization,
// Thread-level copy for matrix B (convolution weight, YXCK).
// Source and destination element types are both BDataType. The boolean argument is
// the negation of UseBLocalBuffer (NOTE(review): its exact meaning is defined by the
// transfer template, not visible here).
using BThreadwiseCopy =
    ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_YXCK<
        BDataType,
        BDataType,
        BGridDesc,
        decltype(GetWeightBlockDescriptor()),
        WeiElementwiseOperation,
        !UseBLocalBuffer,
        ConvForwardSpecialization>;
// Thread-level store for matrix C that also takes the bias (C0) and residual (C1)
// operands, as the template name indicates. The trailing arguments are the negation
// of UseCLocalBuffer and BiasAlongGemmM (NOTE(review): semantics are defined by the
// transfer template, not visible here).
using CThreadwiseCopy =
    ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN<
        CDataType,
        C0DataType,
        C1DataType,
        CDataType,
        CGridDesc,
        C0GridDesc,
        C1GridDesc,
        decltype(GetOutputBlockDescriptor()),
        OutElementwiseOperation,
        !UseCLocalBuffer,
        BiasAlongGemmM>;
!UseCLocalBuffer, using GridwiseGemm = ck::cpu::GridwiseGemmBiasActivationAddAvx2_MxN<
BiasAlongGemmM>; ADataType, // InDataType,
BDataType, // WeiDataType,
using GridwiseGemm = ck::cpu::GridwiseGemmBiasActivationAddAvx2_MxN< CDataType, // OutDataType,
ADataType, // InDataType, C0DataType, // C0DataType
BDataType, // WeiDataType, C1DataType, // C1DataType
CDataType, // OutDataType, AGridDesc, // AGridDesc,
C0DataType, // C0DataType BGridDesc, // BGridDesc,
C1DataType, // C1DataType CGridDesc, // CGridDesc,
AGridDesc, // AGridDesc, C0GridDesc, // C0GridDesc,
BGridDesc, // BGridDesc, C1GridDesc, // C1GridDesc,
CGridDesc, // CGridDesc, AElementwiseOperation, // AElementwiseOperation,
C0GridDesc, // C0GridDesc, BElementwiseOperation, // BElementwiseOperation,
C1GridDesc, // C1GridDesc, CElementwiseOperation, // CElementwiseOperation,
AElementwiseOperation, // AElementwiseOperation, ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
BElementwiseOperation, // BElementwiseOperation, AThreadwiseCopy, // AThreadwiseCopy
CElementwiseOperation, // CElementwiseOperation, BThreadwiseCopy, // BThreadwiseCopy
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch, CThreadwiseCopy, // CThreadwiseCopy
AThreadwiseCopy, // AThreadwiseCopy ck::Sequence<0, 1>, // ThreadMNAccessOrder
BThreadwiseCopy, // BThreadwiseCopy UseALocalBuffer, // UseALocalBuffer
CThreadwiseCopy, // CThreadwiseCopy UseBLocalBuffer, // UseBLocalBuffer
ck::Sequence<0, 1>, // ThreadMNAccessOrder UseCLocalBuffer // UseCLocalBuffer
UseALocalBuffer, // UseALocalBuffer >;
UseBLocalBuffer, // UseBLocalBuffer
UseCLocalBuffer // UseCLocalBuffer GridwiseGemm gridwise_gemm;
>;
// Argument
GridwiseGemm gridwise_gemm; struct Argument : public BaseArgument
{
// Argument Argument(const InDataType* p_in_grid,
struct Argument : public BaseArgument const WeiDataType* p_wei_grid,
{ OutDataType* p_out_grid,
Argument(const InDataType* p_in_grid, const BiasDataType* p_bias_grid,
const WeiDataType* p_wei_grid, const AddDataType* p_add_grid,
OutDataType* p_out_grid, ck::index_t N,
const BiasDataType* p_bias_grid, ck::index_t K,
const AddDataType* p_add_grid, ck::index_t C,
ck::index_t N, std::vector<ck::index_t> input_spatial_lengths,
ck::index_t K, std::vector<ck::index_t> filter_spatial_lengths,
ck::index_t C, std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> input_spatial_lengths, std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> filter_spatial_lengths, std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> output_spatial_lengths, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::index_t> input_right_pads,
std::vector<ck::index_t> conv_filter_dilations, InElementwiseOperation in_element_op,
std::vector<ck::index_t> input_left_pads, WeiElementwiseOperation wei_element_op,
std::vector<ck::index_t> input_right_pads, OutElementwiseOperation out_element_op)
InElementwiseOperation in_element_op, : p_a_grid_{p_in_grid},
WeiElementwiseOperation wei_element_op, p_b_grid_{p_wei_grid},
OutElementwiseOperation out_element_op) p_c_grid_{p_out_grid},
: p_a_grid_{p_in_grid}, p_c0_grid_{p_bias_grid},
p_b_grid_{p_wei_grid}, p_c1_grid_{p_add_grid},
p_c_grid_{p_out_grid}, a_grid_desc_{},
p_c0_grid_{p_bias_grid}, b_grid_desc_{},
p_c1_grid_{p_add_grid}, c_grid_desc_{},
a_grid_desc_{}, c0_grid_desc_{},
b_grid_desc_{}, c1_grid_desc_{},
c_grid_desc_{}, a_element_op_{in_element_op},
c0_grid_desc_{}, b_element_op_{wei_element_op},
c1_grid_desc_{}, c_element_op_{out_element_op},
a_element_op_{in_element_op}, Conv_N_{N},
b_element_op_{wei_element_op}, Conv_K_{K},
c_element_op_{out_element_op}, Conv_C_{C},
Conv_N_{N}, filter_spatial_lengths_{filter_spatial_lengths},
Conv_K_{K}, conv_filter_strides_{conv_filter_strides},
Conv_C_{C}, input_left_pads_{input_left_pads},
filter_spatial_lengths_{filter_spatial_lengths}, input_right_pads_{input_right_pads}
conv_filter_strides_{conv_filter_strides}, {
input_left_pads_{input_left_pads}, const auto descs = DeviceOp::MakeABCGridDescriptor(N,
input_right_pads_{input_right_pads} K,
{ C,
const auto descs = DeviceOp::MakeABCGridDescriptor(N, input_spatial_lengths,
K, filter_spatial_lengths,
C, output_spatial_lengths,
input_spatial_lengths, conv_filter_strides,
filter_spatial_lengths, conv_filter_dilations,
output_spatial_lengths, input_left_pads,
conv_filter_strides, input_right_pads);
conv_filter_dilations, a_grid_desc_ = descs[I0];
input_left_pads, b_grid_desc_ = descs[I1];
input_right_pads); c_grid_desc_ = descs[I2];
a_grid_desc_ = descs[I0];
b_grid_desc_ = descs[I1]; c0_grid_desc_ = DeviceOp::MakeBiasTensorDescriptor(GetGemmM(N, output_spatial_lengths),
c_grid_desc_ = descs[I2]; GetGemmN(K));
c1_grid_desc_ = descs[I2];
c0_grid_desc_ = DeviceOp::MakeBiasTensorDescriptor(GetGemmM(N, output_spatial_lengths), }
GetGemmN(K));
c1_grid_desc_ = descs[I2]; // private:
} const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
// private: CDataType* p_c_grid_;
const ADataType* p_a_grid_; const C0DataType* p_c0_grid_;
const BDataType* p_b_grid_; const C1DataType* p_c1_grid_;
CDataType* p_c_grid_; AGridDesc a_grid_desc_;
const C0DataType* p_c0_grid_; BGridDesc b_grid_desc_;
const C1DataType* p_c1_grid_; CGridDesc c_grid_desc_;
AGridDesc a_grid_desc_; C0GridDesc c0_grid_desc_;
BGridDesc b_grid_desc_; C1GridDesc c1_grid_desc_;
CGridDesc c_grid_desc_;
C0GridDesc c0_grid_desc_; AElementwiseOperation a_element_op_;
C1GridDesc c1_grid_desc_; BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
AElementwiseOperation a_element_op_; // for checking IsSupportedArgument()
BElementwiseOperation b_element_op_; index_t Conv_N_;
CElementwiseOperation c_element_op_; index_t Conv_K_;
// for checking IsSupportedArgument() index_t Conv_C_;
index_t Conv_N_; std::vector<index_t> filter_spatial_lengths_;
index_t Conv_K_; std::vector<index_t> conv_filter_strides_;
index_t Conv_C_; std::vector<index_t> input_left_pads_;
std::vector<index_t> filter_spatial_lengths_; std::vector<index_t> input_right_pads_;
std::vector<index_t> conv_filter_strides_; };
std::vector<index_t> input_left_pads_;
std::vector<index_t> input_right_pads_; // Invoker
}; struct Invoker : public BaseInvoker
{
// Invoker using Argument = DeviceOp::Argument;
struct Invoker : public BaseInvoker
{ GridwiseGemm gridwise_gemm;
using Argument = DeviceOp::Argument;
Invoker(const GridwiseGemm& gridwise_gemm_) : gridwise_gemm(gridwise_gemm_) {}
GridwiseGemm gridwise_gemm;
float Run(const Argument& arg,
Invoker(const GridwiseGemm& gridwise_gemm_) : gridwise_gemm(gridwise_gemm_) {} const StreamConfig& stream_config = StreamConfig{},
int nrepeat = 1)
float Run(const Argument& arg, {
const StreamConfig& stream_config = StreamConfig{}, if(!gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
int nrepeat = 1) {
{ throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
if(!gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_)) }
{
throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting"); memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
}
const auto kernel =
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize()); ck::cpu::kernel_gemm_bias_activation_add_avx_mxn<GridwiseGemm,
ADataType,
const auto kernel = BDataType,
ck::cpu::kernel_gemm_bias_activation_add_avx_mxn<GridwiseGemm, CDataType,
ADataType, C0DataType,
BDataType, C1DataType,
CDataType, AGridDesc,
C0DataType, BGridDesc,
C1DataType, CGridDesc,
AGridDesc, C0GridDesc,
BGridDesc, C1GridDesc,
CGridDesc, AElementwiseOperation,
C0GridDesc, BElementwiseOperation,
C1GridDesc, CElementwiseOperation>;
AElementwiseOperation,
BElementwiseOperation, float ave_time = 0;
CElementwiseOperation>;
if(nrepeat != 1)
float ave_time = 0; ave_time = launch_and_time_cpu_kernel(kernel,
nrepeat,
if(nrepeat != 1) gridwise_gemm,
ave_time = launch_and_time_cpu_kernel(kernel, arg.p_a_grid_,
nrepeat, arg.p_b_grid_,
gridwise_gemm, arg.p_c_grid_,
arg.p_a_grid_, arg.p_c0_grid_,
arg.p_b_grid_, arg.p_c1_grid_,
arg.p_c_grid_, arg.a_grid_desc_,
arg.p_c0_grid_, arg.b_grid_desc_,
arg.p_c1_grid_, arg.c_grid_desc_,
arg.a_grid_desc_, arg.c0_grid_desc_,
arg.b_grid_desc_, arg.c1_grid_desc_,
arg.c_grid_desc_, arg.a_element_op_,
arg.c0_grid_desc_, arg.b_element_op_,
arg.c1_grid_desc_, arg.c_element_op_);
arg.a_element_op_,
arg.b_element_op_, // TODO: this is for benchmark purpose, so last time we clear c buffer and calculate the
arg.c_element_op_); // result
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
// TODO: this is for benchmark purpose, so last time we clear c buffer and calculate the
// result launch_cpu_kernel(kernel,
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize()); gridwise_gemm,
arg.p_a_grid_,
launch_cpu_kernel(kernel, arg.p_b_grid_,
gridwise_gemm, arg.p_c_grid_,
arg.p_a_grid_, arg.p_c0_grid_,
arg.p_b_grid_, arg.p_c1_grid_,
arg.p_c_grid_, arg.a_grid_desc_,
arg.p_c0_grid_, arg.b_grid_desc_,
arg.p_c1_grid_, arg.c_grid_desc_,
arg.a_grid_desc_, arg.c0_grid_desc_,
arg.b_grid_desc_, arg.c1_grid_desc_,
arg.c_grid_desc_, arg.a_element_op_,
arg.c0_grid_desc_, arg.b_element_op_,
arg.c1_grid_desc_, arg.c_element_op_);
arg.a_element_op_,
arg.b_element_op_, return ave_time;
arg.c_element_op_); }
return ave_time; float Run(const BaseArgument* p_arg,
} const StreamConfig& stream_config = StreamConfig{},
int nrepeat = 1) override
float Run(const BaseArgument* p_arg, {
const StreamConfig& stream_config = StreamConfig{}, return Run(*dynamic_cast<const Argument*>(p_arg), stream_config, nrepeat);
int nrepeat = 1) override }
{ };
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config, nrepeat);
// Compile-time sanity check of the template parameters. Currently a stub that
// accepts every configuration.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool accept_all = true;
    return accept_all;
}
return true; bool IsSupportedArgument(const Argument& arg)
} {
if constexpr(ConvForwardSpecialization ==
bool IsSupportedArgument(const Argument& arg) ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{ {
if constexpr(ConvForwardSpecialization == // check if it's 1x1, stride=1 conv
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
{ arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
// check if it's 1x1, stride=1 conv arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 && {
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && return false;
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) }
{ }
return false; else if constexpr(ConvForwardSpecialization ==
} ConvolutionForwardSpecialization_t::Filter1x1Pad0)
} {
else if constexpr(ConvForwardSpecialization == // check if it's 1x1 conv
ConvolutionForwardSpecialization_t::Filter1x1Pad0) if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
{ arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
// check if it's 1x1 conv arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && {
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && return false;
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) }
{ }
return false;
} if(gridwise_gemm.dynamic_tunable.gemm_k_spec ==
} ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
ConvForwardSpecialization != ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
if constexpr(GemmKSpecialization == {
ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC && if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0))
ConvForwardSpecialization != return false;
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) }
{
if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0)) if constexpr(!UseALocalBuffer &&
return false; ConvForwardSpecialization !=
} ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
if constexpr(!UseALocalBuffer && // TODO: We can support this in the future, as long as figure out how to express tensor
ConvForwardSpecialization != // transform
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) return false;
{ }
// TODO: We can support this in the future, as long as figure out how to express tensor
// transform if constexpr(!UseBLocalBuffer)
return false; {
} if(!(arg.Conv_K_ % 8 == 0))
return false;
if constexpr(!UseBLocalBuffer) }
{
if(!(arg.Conv_K_ % 8 == 0)) // Gridwise GEMM size
return false; return gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
} }
// Gridwise GEMM size bool IsSupportedArgument(const BaseArgument* p_arg) override
return gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_); {
} return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{ static auto MakeArgument(const InDataType* p_in_grid,
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); const WeiDataType* p_wei_grid,
} OutDataType* p_out_grid,
const BiasDataType* p_bias_grid,
static auto MakeArgument(const InDataType* p_in_grid, const AddDataType* p_add_grid,
const WeiDataType* p_wei_grid, ck::index_t N,
OutDataType* p_out_grid, ck::index_t K,
const BiasDataType* p_bias_grid, ck::index_t C,
const AddDataType* p_add_grid, std::vector<ck::index_t> input_spatial_lengths,
ck::index_t N, std::vector<ck::index_t> filter_spatial_lengths,
ck::index_t K, std::vector<ck::index_t> output_spatial_lengths,
ck::index_t C, std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> input_spatial_lengths, std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> filter_spatial_lengths, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> output_spatial_lengths, std::vector<ck::index_t> input_right_pads,
std::vector<ck::index_t> conv_filter_strides, InElementwiseOperation in_element_op,
std::vector<ck::index_t> conv_filter_dilations, WeiElementwiseOperation wei_element_op,
std::vector<ck::index_t> input_left_pads, OutElementwiseOperation out_element_op)
std::vector<ck::index_t> input_right_pads, {
InElementwiseOperation in_element_op, return Argument{p_in_grid,
WeiElementwiseOperation wei_element_op, p_wei_grid,
OutElementwiseOperation out_element_op) p_out_grid,
{ p_bias_grid,
return Argument{p_in_grid, p_add_grid,
p_wei_grid, N,
p_out_grid, K,
p_bias_grid, C,
p_add_grid, input_spatial_lengths,
N, filter_spatial_lengths,
K, output_spatial_lengths,
C, conv_filter_strides,
input_spatial_lengths, conv_filter_dilations,
filter_spatial_lengths, input_left_pads,
output_spatial_lengths, input_right_pads,
conv_filter_strides, in_element_op,
conv_filter_dilations, wei_element_op,
input_left_pads, out_element_op};
input_right_pads, }
in_element_op,
wei_element_op, auto MakeInvoker() { return Invoker{gridwise_gemm}; }
out_element_op};
} std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_grid,
auto MakeInvoker() { return Invoker{gridwise_gemm}; } const void* p_wei_grid,
void* p_out_grid,
std::unique_ptr<BaseArgument> const void* p_bias_grid,
MakeArgumentPointer(const void* p_in_grid, const void* p_add_grid,
const void* p_wei_grid, ck::index_t N,
void* p_out_grid, ck::index_t K,
const void* p_bias_grid, ck::index_t C,
const void* p_add_grid, std::vector<ck::index_t> input_spatial_lengths,
ck::index_t N, std::vector<ck::index_t> filter_spatial_lengths,
ck::index_t K, std::vector<ck::index_t> output_spatial_lengths,
ck::index_t C, std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> input_spatial_lengths, std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> filter_spatial_lengths, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> output_spatial_lengths, std::vector<ck::index_t> input_right_pads,
std::vector<ck::index_t> conv_filter_strides, InElementwiseOperation in_element_op,
std::vector<ck::index_t> conv_filter_dilations, WeiElementwiseOperation wei_element_op,
std::vector<ck::index_t> input_left_pads, OutElementwiseOperation out_element_op) override
std::vector<ck::index_t> input_right_pads, {
InElementwiseOperation in_element_op, return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
WeiElementwiseOperation wei_element_op, static_cast<const WeiDataType*>(p_wei_grid),
OutElementwiseOperation out_element_op) override static_cast<OutDataType*>(p_out_grid),
{ static_cast<const BiasDataType*>(p_bias_grid),
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid), static_cast<const AddDataType*>(p_add_grid),
static_cast<const WeiDataType*>(p_wei_grid), N,
static_cast<OutDataType*>(p_out_grid), K,
static_cast<const BiasDataType*>(p_bias_grid), C,
static_cast<const AddDataType*>(p_add_grid), input_spatial_lengths,
N, filter_spatial_lengths,
K, output_spatial_lengths,
C, conv_filter_strides,
input_spatial_lengths, conv_filter_dilations,
filter_spatial_lengths, input_left_pads,
output_spatial_lengths, input_right_pads,
conv_filter_strides, in_element_op,
conv_filter_dilations, wei_element_op,
input_left_pads, out_element_op);
input_right_pads, }
in_element_op,
wei_element_op, std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
out_element_op); {
} return std::make_unique<Invoker>(Invoker{gridwise_gemm});
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{ std::string GetTypeString() const override
return std::make_unique<Invoker>(Invoker{gridwise_gemm}); {
} auto str = std::stringstream();
auto string_local_buffer = [](bool is_local_buffer) {
std::string GetTypeString() const override if(is_local_buffer)
{ return "L";
auto str = std::stringstream(); else
auto string_local_buffer = [](bool is_local_buffer) { return "G";
if(is_local_buffer) };
return "L"; // clang-format off
else str << "DeviceConv" << std::to_string(NumDimSpatial)
return "G"; << "DFwd_BAA_Avx2_NHWC_YXCK"
}; <<"_FS"<< static_cast<int>(ConvForwardSpecialization)
// clang-format off <<"_KS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.gemm_k_spec)
str << "DeviceConv" << std::to_string(NumDimSpatial) <<"_BS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.loop_over_spec)
<< "DFwd_BAA_Avx2_NHWC_YXCK" << "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block
<<"_FS"<< static_cast<int>(ConvForwardSpecialization) << "_TT" << MPerThread << "x" << NPerThread
<<"_KS"<< static_cast<int>(GemmKSpecialization) << "_A" << string_local_buffer(UseALocalBuffer)
<<"_BS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.loop_over_spec) << "_B" << string_local_buffer(UseBLocalBuffer)
<< "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block << "_C" << string_local_buffer(UseCLocalBuffer)
<< "_TT" << MPerThread << "x" << NPerThread ;
<< "_A" << string_local_buffer(UseALocalBuffer) if constexpr (!std::is_same<OutElementwiseOperation,
<< "_B" << string_local_buffer(UseBLocalBuffer) ck::tensor_operation::cpu::element_wise::PassThrough>::value)
<< "_C" << string_local_buffer(UseCLocalBuffer) {
; str << "_" << OutElementwiseOperation::Name();
if constexpr (!std::is_same<OutElementwiseOperation, }
ck::tensor_operation::cpu::element_wise::PassThrough>::value) // clang-format on
{
str << "_" << OutElementwiseOperation::Name(); return str.str();
} }
// clang-format on };
return str.str(); } // namespace device
} } // namespace cpu
}; } // namespace tensor_operation
} // namespace ck
} // namespace device
} // namespace cpu #endif
} // namespace tensor_operation
} // namespace ck
#endif
...@@ -352,7 +352,8 @@ struct GridwiseGemmAvx2_MxN ...@@ -352,7 +352,8 @@ struct GridwiseGemmAvx2_MxN
ck::make_zero_multi_index<a_block_copy_dim>(), ck::make_zero_multi_index<a_block_copy_dim>(),
GetABlockDescriptor(m_per_block, k_per_block, a_grid_desc), GetABlockDescriptor(m_per_block, k_per_block, a_grid_desc),
ck::make_zero_multi_index<a_block_copy_dim>(), ck::make_zero_multi_index<a_block_copy_dim>(),
AElementwiseOperation{}); AElementwiseOperation{},
dynamic_tunable.gemm_k_spec);
auto b_threadwise_copy = auto b_threadwise_copy =
BThreadwiseCopy(b_grid_desc, BThreadwiseCopy(b_grid_desc,
...@@ -495,7 +496,8 @@ struct GridwiseGemmAvx2_MxN ...@@ -495,7 +496,8 @@ struct GridwiseGemmAvx2_MxN
ck::make_zero_multi_index<a_block_copy_dim>(), ck::make_zero_multi_index<a_block_copy_dim>(),
GetABlockDescriptor(m_per_block, k_per_block, a_grid_desc), GetABlockDescriptor(m_per_block, k_per_block, a_grid_desc),
ck::make_zero_multi_index<a_block_copy_dim>(), ck::make_zero_multi_index<a_block_copy_dim>(),
AElementwiseOperation{}); AElementwiseOperation{},
dynamic_tunable.gemm_k_spec);
auto b_threadwise_copy = auto b_threadwise_copy =
BThreadwiseCopy(b_grid_desc, BThreadwiseCopy(b_grid_desc,
......
...@@ -378,7 +378,8 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN ...@@ -378,7 +378,8 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
ck::make_zero_multi_index<a_block_copy_dim>(), ck::make_zero_multi_index<a_block_copy_dim>(),
GetABlockDescriptor(m_per_block, k_per_block, a_grid_desc), GetABlockDescriptor(m_per_block, k_per_block, a_grid_desc),
ck::make_zero_multi_index<a_block_copy_dim>(), ck::make_zero_multi_index<a_block_copy_dim>(),
AElementwiseOperation{}); AElementwiseOperation{},
dynamic_tunable.gemm_k_spec);
auto b_threadwise_copy = auto b_threadwise_copy =
BThreadwiseCopy(b_grid_desc, BThreadwiseCopy(b_grid_desc,
...@@ -533,7 +534,8 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN ...@@ -533,7 +534,8 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
ck::make_zero_multi_index<a_block_copy_dim>(), ck::make_zero_multi_index<a_block_copy_dim>(),
GetABlockDescriptor(m_per_block, k_per_block, a_grid_desc), GetABlockDescriptor(m_per_block, k_per_block, a_grid_desc),
ck::make_zero_multi_index<a_block_copy_dim>(), ck::make_zero_multi_index<a_block_copy_dim>(),
AElementwiseOperation{}); AElementwiseOperation{},
dynamic_tunable.gemm_k_spec);
auto b_threadwise_copy = auto b_threadwise_copy =
BThreadwiseCopy(b_grid_desc, BThreadwiseCopy(b_grid_desc,
......
...@@ -210,6 +210,69 @@ void memcpy32_avx2_with_extra_2src(void* dst, ...@@ -210,6 +210,69 @@ void memcpy32_avx2_with_extra_2src(void* dst,
} }
} }
// Element-wise combination of two float streams:
//     dst[i] = element_op.Apply(src[i], src_aux[i])   for i in [0, n)
// All three buffers are accessed as 32-bit floats using unaligned loads/stores.
//
//   dst        - output buffer, receives n floats
//   src        - first operand buffer
//   src_aux    - second operand buffer
//   n          - number of 32-bit elements to process
//   element_op - functor whose Apply() is invoked with __m256, __m128 and
//                scalar float operand pairs (one overload per vector width)
template <typename ElementwiseOp>
void memcpy32_avx2_with_extra_1src(void* dst,
                                   const void* src,
                                   const void* src_aux,
                                   const ck::index_t n,
                                   const ElementwiseOp& element_op)
{
    // 16-8-4-2-1 pattern
    ck::index_t i_n        = n;
    float* p_dst           = reinterpret_cast<float*>(dst);
    const float* p_src     = reinterpret_cast<const float*>(src);
    const float* p_src_aux = reinterpret_cast<const float*>(src_aux);

    // Main loop: two 8-float vectors per iteration.
    while(i_n >= 16)
    {
        _mm256_storeu_ps(
            p_dst + 0,
            element_op.Apply(_mm256_loadu_ps(p_src + 0), _mm256_loadu_ps(p_src_aux + 0)));
        _mm256_storeu_ps(
            p_dst + 8,
            element_op.Apply(_mm256_loadu_ps(p_src + 8), _mm256_loadu_ps(p_src_aux + 8)));
        p_dst += 16;
        p_src += 16;
        p_src_aux += 16;
        i_n -= 16;
    }

    // Fewer than 16 elements remain; each set bit of i_n selects one tail step.
    if(i_n & 8)
    {
        _mm256_storeu_ps(p_dst,
                         element_op.Apply(_mm256_loadu_ps(p_src), _mm256_loadu_ps(p_src_aux)));
        p_dst += 8;
        p_src += 8;
        p_src_aux += 8;
    }
    if(i_n & 4)
    {
        _mm_storeu_ps(p_dst, element_op.Apply(_mm_loadu_ps(p_src), _mm_loadu_ps(p_src_aux)));
        p_dst += 4;
        p_src += 4;
        p_src_aux += 4;
    }
    if(i_n & 2)
    {
        // 64-bit load/store moves exactly two floats through the low half of an
        // XMM register. The cast intrinsics only reinterpret the register bits
        // (no instruction is generated), so a single code path serves every
        // compiler without pointer punning or implicit vector conversions.
        const __m128 v_src = _mm_castsi128_ps(_mm_loadu_si64(p_src));
        const __m128 v_aux = _mm_castsi128_ps(_mm_loadu_si64(p_src_aux));
        _mm_storeu_si64(p_dst, _mm_castps_si128(element_op.Apply(v_src, v_aux)));
        p_dst += 2;
        p_src += 2;
        p_src_aux += 2;
    }
    if(i_n & 1)
    {
        *p_dst = element_op.Apply(*p_src, *p_src_aux);
    }
}
inline void memset32_avx2(void* dst, const int32_t value, const ck::index_t n) inline void memset32_avx2(void* dst, const int32_t value, const ck::index_t n)
{ {
// 16-8-4-2-1 pattern // 16-8-4-2-1 pattern
...@@ -324,8 +387,7 @@ template <typename SrcData, ...@@ -324,8 +387,7 @@ template <typename SrcData,
typename DstDesc, typename DstDesc,
typename ElementwiseOperation, typename ElementwiseOperation,
bool BypassTransfer, bool BypassTransfer,
ConvolutionForwardSpecialization_t ConvForwardSpecialization, ConvolutionForwardSpecialization_t ConvForwardSpecialization>
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
{ {
static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension(); static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
...@@ -336,8 +398,9 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC ...@@ -336,8 +398,9 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
const Index&, const Index&,
const DstDesc&, const DstDesc&,
const Index&, const Index&,
const ElementwiseOperation& element_op) const ElementwiseOperation& element_op,
: element_op_(element_op) const ConvolutionForwardGemmKSpecialization_t& gemm_k_spec)
: element_op_(element_op), gemm_k_spec_(gemm_k_spec)
{ {
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
...@@ -630,8 +693,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC ...@@ -630,8 +693,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
{ {
// ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h // ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
// iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w // iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
if constexpr(GemmKSpecialization == if(gemm_k_spec_ == ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
{ {
// c % k_per_block == 0, so every time k_per_block here is the same // c % k_per_block == 0, so every time k_per_block here is the same
ck::index_t i_m_itr = m_per_block; ck::index_t i_m_itr = m_per_block;
...@@ -782,8 +844,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC ...@@ -782,8 +844,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
} }
else else
{ {
if constexpr(GemmKSpecialization == if(gemm_k_spec_ == ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
{ {
// TODO: branch seems weird // TODO: branch seems weird
...@@ -827,6 +888,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC ...@@ -827,6 +888,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
private: private:
const ElementwiseOperation element_op_; const ElementwiseOperation element_op_;
const ConvolutionForwardGemmKSpecialization_t gemm_k_spec_;
ck::index_t i_n; ck::index_t i_n;
ck::index_t i_c; ck::index_t i_c;
...@@ -875,8 +937,7 @@ template <typename SrcData, ...@@ -875,8 +937,7 @@ template <typename SrcData,
typename DstDesc, typename DstDesc,
typename ElementwiseOperation, typename ElementwiseOperation,
bool BypassTransfer, bool BypassTransfer,
ConvolutionForwardSpecialization_t ConvForwardSpecialization, ConvolutionForwardSpecialization_t ConvForwardSpecialization>
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXC struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXC
{ {
static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension(); static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
...@@ -1096,8 +1157,7 @@ template <typename SrcData, ...@@ -1096,8 +1157,7 @@ template <typename SrcData,
typename DstDesc, typename DstDesc,
typename ElementwiseOperation, typename ElementwiseOperation,
bool BypassTransfer, bool BypassTransfer,
ConvolutionForwardSpecialization_t ConvForwardSpecialization, ConvolutionForwardSpecialization_t ConvForwardSpecialization>
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8 struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8
{ {
static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension(); static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
...@@ -1283,8 +1343,7 @@ template <typename SrcData, ...@@ -1283,8 +1343,7 @@ template <typename SrcData,
typename DstDesc, typename DstDesc,
typename ElementwiseOperation, typename ElementwiseOperation,
bool BypassTransfer, bool BypassTransfer,
ConvolutionForwardSpecialization_t ConvForwardSpecialization, ConvolutionForwardSpecialization_t ConvForwardSpecialization>
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_YXCK struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_YXCK
{ {
static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension(); static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
...@@ -1415,8 +1474,7 @@ template <typename SrcData, ...@@ -1415,8 +1474,7 @@ template <typename SrcData,
typename DstDesc, typename DstDesc,
typename ElementwiseOperation, typename ElementwiseOperation,
bool BypassTransfer, bool BypassTransfer,
ConvolutionForwardSpecialization_t ConvForwardSpecialization, ConvolutionForwardSpecialization_t ConvForwardSpecialization>
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
{ {
static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension(); static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
......
...@@ -49,17 +49,17 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver ...@@ -49,17 +49,17 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
// clang-format off // clang-format off
#define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf) \ #define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf) \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
\ \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}) DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN})
// clang-format on // clang-format on
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances) void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
......
...@@ -42,17 +42,17 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver ...@@ -42,17 +42,17 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
// clang-format off // clang-format off
#define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf) \ #define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf) \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
\ \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}) DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN})
// clang-format on // clang-format on
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk( void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk(
......
#include <stdlib.h> #include <stdlib.h>
#include <utility> #include <utility>
#include "config.hpp" #include "config.hpp"
#include "convolution_forward_specialization_cpu.hpp" #include "convolution_forward_specialization_cpu.hpp"
#include "device_convnd_fwd_avx2_nhwc_yxck_nhwk.hpp" #include "device_convnd_fwd_avx2_nhwc_yxck_nhwk.hpp"
#include "element_wise_operation_cpu.hpp" #include "element_wise_operation_cpu.hpp"
#include "device_operation_instance.hpp" #include "device_operation_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace cpu { namespace cpu {
namespace device { namespace device {
namespace device_conv2d_fwd_avx2_instance { namespace device_conv2d_fwd_avx2_instance {
using InType = float; using InType = float;
using WeiType = float; using WeiType = float;
using OutType = float; using OutType = float;
using AccType = float; using AccType = float;
static constexpr bool NonTemporalStore = false; static constexpr bool NonTemporalStore = false;
using PT = ck::tensor_operation::cpu::element_wise::PassThrough; using PT = ck::tensor_operation::cpu::element_wise::PassThrough;
using Relu = ck::tensor_operation::cpu::element_wise::Relu; using Relu = ck::tensor_operation::cpu::element_wise::Relu;
static constexpr auto ConvFwdDefault = static constexpr auto ConvFwdDefault =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Default; ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Default;
static constexpr auto ConvFwd1x1P0 = static constexpr auto ConvFwd1x1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0; ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 = static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0; ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;
static constexpr auto DefaultGemmKLoop = static constexpr auto DefaultGemmKLoop =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop; ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop;
static constexpr auto GemmKLoopOverC = static constexpr auto GemmKLoopOverC =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC; ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC;
static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK; static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK;
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN; static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
// clang-format off // clang-format off
#define DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf) \ #define DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf) \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
\ \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}) DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN})
// clang-format on // clang-format on
// Appends the PT/PT/PT instance set to `instances`.
// The tuple holds the DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32 expansion for
// eight tile configurations; every row passes `false` as the macro's last
// argument.
// NOTE(review): the last argument is presumably the macro's c_local_buf
// parameter (the #define line is not visible from here) -- confirm.
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk(std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances) void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk(std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, instances,
std::make_tuple( std::make_tuple(
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, false) DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, false)
// clang-format on // clang-format on
)); ));
} }
// Appends the PT/PT/PT "local_c" instance set to `instances`.
// Same eight tile configurations as add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk,
// but every row passes `true` as the macro's last argument.
// NOTE(review): the last argument is presumably the macro's c_local_buf
// parameter (the #define line is not visible from here) -- confirm.
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c( void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c(
std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances) std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, instances,
std::make_tuple( std::make_tuple(
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true) DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true)
// clang-format on // clang-format on
)); ));
} }
// Appends the PT/PT/PT "mt" instance set to `instances`.
// The tuple has two groups: fourteen small-tile rows (m_per_block 24..120)
// that pass `false` as the macro's last argument, followed by seven
// large-tile rows that pass `true`. The 256x128x64 `true` row is commented
// out, so it is not registered.
// NOTE(review): "mt" presumably stands for multi-thread, and the last macro
// argument is presumably c_local_buf; neither is established by the visible
// code -- confirm.
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt( void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt(
std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances) std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, instances,
std::make_tuple( std::make_tuple(
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 24, 24, 256, 4, 24, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 24, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 32, 24, 256, 4, 24, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 32, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 40, 24, 256, 4, 24, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 40, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 48, 24, 256, 4, 24, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 48, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 48, 48, 256, 4, 24, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 48, 48, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 56, 24, 256, 4, 24, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 56, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 72, 16, 128, 6, 16, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 72, 16, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 72, 16, 256, 6, 16, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 72, 16, 256, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 72, 32, 128, 6, 16, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 72, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 72, 32, 256, 6, 16, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 72, 32, 256, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 96, 32, 128, 6, 16, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 96, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 96, 64, 128, 6, 16, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 96, 64, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 120, 32, 128, 6, 16, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 120, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 120, 64, 128, 6, 16, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 120, 64, 128, 6, 16, false),
// DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true), // DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true) DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true)
// clang-format on // clang-format on
)); ));
} }
// Appends the PT/PT/Relu instance set to `instances`.
// Same eight tile configurations as the PT/PT/PT variant, with `Relu` as the
// third elementwise-op argument and `false` as the macro's last argument.
// NOTE(review): the last argument is presumably the macro's c_local_buf
// parameter (the #define line is not visible from here) -- confirm.
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_relu( void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances) std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, instances,
std::make_tuple( std::make_tuple(
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, false) DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, false)
// clang-format on // clang-format on
)); ));
} }
// Appends the PT/PT/Relu "local_c" instance set to `instances`.
// Same eight tile configurations as the PT/PT/Relu variant, but every row
// passes `true` as the macro's last argument.
// NOTE(review): the last argument is presumably the macro's c_local_buf
// parameter (the #define line is not visible from here) -- confirm.
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c_relu( void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances) std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, instances,
std::make_tuple( std::make_tuple(
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, true), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true) DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)
// clang-format on // clang-format on
)); ));
} }
// Appends the PT/PT/Relu "mt" instance set to `instances`.
// The tuple has two groups: fourteen small-tile rows (m_per_block 24..120)
// that pass `false` as the macro's last argument, followed by seven
// large-tile rows that pass `true`. The row list matches the PT/PT/PT
// add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt variant one for one; each
// configuration is listed exactly once so no instance is registered twice.
// NOTE(review): "mt" presumably stands for multi-thread, and the last macro
// argument is presumably c_local_buf; neither is established by the visible
// code -- confirm.
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt_relu( void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances) std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, instances,
std::make_tuple( std::make_tuple(
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 24, 24, 256, 4, 24, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 24, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 32, 24, 256, 4, 24, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 32, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 40, 24, 256, 4, 24, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 40, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 48, 24, 256, 4, 24, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 48, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 48, 48, 256, 4, 24, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 48, 48, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 56, 24, 256, 4, 24, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 56, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 72, 16, 128, 6, 16, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 72, 16, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 72, 16, 256, 6, 16, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 72, 16, 256, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 72, 32, 128, 6, 16, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 72, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 72, 32, 256, 6, 16, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 72, 32, 256, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 96, 32, 128, 6, 16, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 96, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 96, 64, 128, 6, 16, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 96, 64, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 120, 32, 128, 6, 16, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 120, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 120, 64, 128, 6, 16, false), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 120, 64, 128, 6, 16, false),
// Disabled row: spelled with Relu so it matches this function's element type if it is ever re-enabled.
// DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, true), // DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true), DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true) DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)
// clang-format on // clang-format on
)); ));
} }
} // namespace device_conv2d_fwd_avx2_instance } // namespace device_conv2d_fwd_avx2_instance
} // namespace device } // namespace device
} // namespace cpu } // namespace cpu
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -42,17 +42,17 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver ...@@ -42,17 +42,17 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
// clang-format off // clang-format off
// Expands to a comma-separated list of ten
// DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
// instances for one tile configuration: five rows built with LoopOver_MNK,
// then the same five rows built with LoopOver_MKN. The five rows vary the
// convolution specialization (ConvFwdDefault / ConvFwd1x1S1P0), the GemmK
// specialization (GemmKLoopOverC / DefaultGemmKLoop) and the two boolean
// template arguments that precede c_local_buf.
// Each line carries two spellings of its row: first with the GemmK
// specialization as a template argument, then with it moved into the
// brace-initialized constructor argument ahead of the loop-over value.
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \ #define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
\ \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}) DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN})
// clang-format on // clang-format on
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk( void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk(
......
...@@ -42,17 +42,17 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver ...@@ -42,17 +42,17 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
// clang-format off // clang-format off
// Expands to a comma-separated list of ten
// DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
// instances for one tile configuration: five rows built with LoopOver_MNK,
// then the same five rows built with LoopOver_MKN. The five rows vary the
// convolution specialization (ConvFwdDefault / ConvFwd1x1S1P0), the GemmK
// specialization (GemmKLoopOverC / DefaultGemmKLoop) and the two boolean
// template arguments that precede c_local_buf.
// Each line carries two spellings of its row: first with the GemmK
// specialization as a template argument, then with it moved into the
// brace-initialized constructor argument ahead of the loop-over value.
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \ #define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
\ \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}) DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN})
// clang-format on // clang-format on
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk( void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk(
......
#include <stdlib.h> #include <stdlib.h>
#include <utility> #include <utility>
#include "config.hpp" #include "config.hpp"
#include "convolution_forward_specialization_cpu.hpp" #include "convolution_forward_specialization_cpu.hpp"
#include "device_convnd_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk.hpp" #include "device_convnd_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk.hpp"
#include "element_wise_operation_cpu.hpp" #include "element_wise_operation_cpu.hpp"
#include "device_operation_instance.hpp" #include "device_operation_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace cpu { namespace cpu {
namespace device { namespace device {
namespace device_conv2d_fwd_bias_activation_add_avx2_instance { namespace device_conv2d_fwd_bias_activation_add_avx2_instance {
// File-level shorthand for the instance definitions in this namespace.

// Scalar type aliases; all four are float.
using InType = float; using InType = float;
using WeiType = float; using WeiType = float;
using OutType = float; using OutType = float;
using AccType = float; using AccType = float;
// Compile-time flag, fixed to false here.
static constexpr bool NonTemporalStore = false; static constexpr bool NonTemporalStore = false;
// Short names for the element-wise operation types.
using PT = ck::tensor_operation::cpu::element_wise::PassThrough; using PT = ck::tensor_operation::cpu::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::cpu::element_wise::AddReluAdd; using AddReluAdd = ck::tensor_operation::cpu::element_wise::AddReluAdd;
// Short names for the ConvolutionForwardSpecialization_t enumerators.
static constexpr auto ConvFwdDefault = static constexpr auto ConvFwdDefault =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Default; ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Default;
static constexpr auto ConvFwd1x1P0 = static constexpr auto ConvFwd1x1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0; ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 = static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0; ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;
// Short names for the ConvolutionForwardGemmKSpecialization_t enumerators.
static constexpr auto DefaultGemmKLoop = static constexpr auto DefaultGemmKLoop =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop; ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop;
static constexpr auto GemmKLoopOverC = static constexpr auto GemmKLoopOverC =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC; ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC;
// Local copies of the two block loop-over values from the device namespace.
static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK; static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK;
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN; static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
// clang-format off // clang-format off
// DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32 expands to a comma-separated
// list of ten constructed
// DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
// objects, meant to be dropped into a std::make_tuple(...) argument list.
//
// Parameters:
//   a_elem_op, b_elem_op, c_elem_op       - element-wise operator types, forwarded as template arguments.
//   m_per_block, n_per_block, k_per_block - forwarded into the brace-initialised constructor argument.
//   m_per_thread, n_per_thread            - forwarded as template arguments.
//   c_local_buf, bias_along_m             - boolean template arguments, forwarded unchanged.
//
// The ten entries are the same five variants emitted twice, first with
// LoopOver_MNK and then with LoopOver_MKN:
//   1. ConvFwdDefault  + GemmKLoopOverC,   flags true,  true
//   2. ConvFwd1x1S1P0  + GemmKLoopOverC,   flags true,  true
//   3. ConvFwdDefault  + DefaultGemmKLoop, flags true,  true
//   4. ConvFwd1x1S1P0  + GemmKLoopOverC,   flags false, false
//   5. ConvFwdDefault  + DefaultGemmKLoop, flags true,  false
// The two literal flags follow n_per_thread; presumably they select local
// buffers for the A and B operands -- TODO confirm against the device template.
// The literal `2` is passed in the template-argument list in every entry.
//
// NOTE(review): each line below holds two renderings of the same entry. In the
// first, the GemmK specialization is a template argument (after the conv
// specialization); in the second, it is removed from the template arguments
// and passed inside the brace-initialised argument, before the loop order.
//
// No comments may be placed inside the definition: every line ends in a
// backslash continuation.
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \ #define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
\ \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}) DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN})
// clang-format on // clang-format on
// Passes `instances` together with a tuple of conv2d-forward +
// bias/activation/add device operations to add_device_operation_instances
// (presumably one vector entry is appended per tuple element -- confirm in
// device_operation_instance.hpp).
//
// The tuple is built from eight invocations of
// DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32, each of which expands to ten
// instances. All eight use PT / PT / AddReluAdd as element-wise operators and
// pass c_local_buf = false, bias_along_m = false.
//
// Macro arguments after the operators are, in order:
//   m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread.
// Every configuration uses a 6x16 per-thread tile except 512x240x128, which
// uses 4x24.
//
// NOTE(review): each line below carries its text twice (side-by-side diff
// rendering); the content is kept verbatim.
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk( void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances) std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, instances,
std::make_tuple( std::make_tuple(
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, false, false) DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, false, false)
// clang-format on // clang-format on
)); ));
} }
// "local_c" variant of the instance factory: passes `instances` and a tuple of
// device operations to add_device_operation_instances.
//
// It lists the same eight block/thread tile configurations as
// add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk; the only
// difference is that the macro's c_local_buf argument is `true` here.
// bias_along_m stays `false`. Each macro invocation expands to ten instances.
//
// Macro arguments after the operators are, in order:
//   m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread.
//
// NOTE(review): each line below carries its text twice (side-by-side diff
// rendering); the content is kept verbatim.
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c( void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances) std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, instances,
std::make_tuple( std::make_tuple(
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false) DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)
// clang-format on // clang-format on
)); ));
} }
// "mt" variant of the instance factory (presumably "multi-thread" -- confirm
// with the callers): passes `instances` and a tuple of device operations to
// add_device_operation_instances.
//
// The tuple holds 21 macro invocations, each expanding to ten instances:
//   * 14 small block tiles (m_per_block 24..120, n_per_block 16..64) with
//     c_local_buf = false. Those with k_per_block = 256 and a 4x24 per-thread
//     tile come first, followed by 6x16 per-thread tiles.
//   * 7 large block tiles with c_local_buf = true. These match the list in the
//     "_local_c" factory except that 256x128x64 is not included.
// bias_along_m is `false` for every entry.
//
// Macro arguments after the operators are, in order:
//   m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread.
//
// NOTE(review): each line below carries its text twice (side-by-side diff
// rendering); the content is kept verbatim.
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_mt( void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances) std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, instances,
std::make_tuple( std::make_tuple(
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 24, 24, 256, 4, 24, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 24, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 32, 24, 256, 4, 24, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 32, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 40, 24, 256, 4, 24, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 40, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 48, 24, 256, 4, 24, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 48, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 48, 48, 256, 4, 24, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 48, 48, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 56, 24, 256, 4, 24, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 56, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 256, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 256, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 256, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 256, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 96, 32, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 96, 32, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 96, 64, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 96, 64, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 120, 32, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 120, 32, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 120, 64, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 120, 64, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false) DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)
// clang-format on // clang-format on
)); ));
} }
} // namespace device_conv2d_fwd_bias_activation_add_avx2_instance } // namespace device_conv2d_fwd_bias_activation_add_avx2_instance
} // namespace device } // namespace device
} // namespace cpu } // namespace cpu
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment