Commit 71254ddd authored by carlushuang

optimize the multi-thread case by supporting not using LocalA/LocalB buffers

parent dc536427
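The idea behind the change, shown as a minimal standalone sketch (the `GridDesc`/`BlockDesc` structs, `get_a_block_descriptor`, and `run_block` below are illustrative stand-ins, not the library's actual types or API): when `UseALocalBuffer`/`UseBLocalBuffer` are disabled, the block descriptor falls back to the grid descriptor and the GEMM micro-kernel reads the operand directly from the grid buffer, instead of first packing a tile into a thread-local buffer.

```cpp
#include <cstdio>

// Minimal sketch, assuming simplified row-major descriptors that only carry a leading dimension.
struct GridDesc  { int ld; }; // describes the full (grid) tensor
struct BlockDesc { int ld; }; // describes a packed, block-local tile

template <bool UseALocalBuffer>
constexpr auto get_a_block_descriptor(GridDesc a_grid_desc, int k_per_block)
{
    if constexpr(UseALocalBuffer)
        return BlockDesc{k_per_block}; // packed MPerBlock x KPerBlock tile
    else
        return a_grid_desc;            // no packing: reuse the grid descriptor directly
}

template <bool UseALocalBuffer>
void run_block(const float* a_grid,
               GridDesc a_grid_desc,
               float* a_local,
               int m_per_block,
               int k_per_block)
{
    const auto a_block_desc = get_a_block_descriptor<UseALocalBuffer>(a_grid_desc, k_per_block);

    const float* a_src = a_grid;
    if constexpr(UseALocalBuffer)
    {
        // Pack the tile into the thread-local buffer before calling the micro-kernel.
        for(int m = 0; m < m_per_block; ++m)
            for(int k = 0; k < k_per_block; ++k)
                a_local[m * a_block_desc.ld + k] = a_grid[m * a_grid_desc.ld + k];
        a_src = a_local;
    }
    // ... a real micro-kernel would now stride through a_src using a_block_desc.ld ...
    std::printf("lda = %d, reading from %s buffer\n",
                a_block_desc.ld,
                UseALocalBuffer ? "local" : "grid");
    (void)a_src;
}

int main()
{
    float a_grid[4 * 16] = {};
    float a_local[4 * 8] = {};
    GridDesc a_grid_desc{16};                              // full tensor: 16 elements per row
    run_block<true>(a_grid, a_grid_desc, a_local, 4, 8);   // packs first, lda = 8
    run_block<false>(a_grid, a_grid_desc, a_local, 4, 8);  // direct read, lda = 16
    return 0;
}
```

When several threads split one GEMM, skipping the per-thread packing copy is presumably what makes this path cheaper, which matches the commit's stated multi-thread motivation.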
@@ -41,6 +41,10 @@ struct BlockwiseGemmAvx2_MxN
     using IndexB = MultiIndex<nDimB>;
     using IndexC = MultiIndex<nDimC>;
+    using ASliceLengths = MultiIndex<nDimA>;
+    using BSliceLengths = MultiIndex<nDimB>;
+    using CSliceLengths = MultiIndex<nDimC>;
     using ACoord = decltype(make_tensor_coordinate(ABlockDesc{}, IndexA{}));
     using BCoord = decltype(make_tensor_coordinate(BBlockDesc{}, IndexB{}));
     using CCoord = decltype(make_tensor_coordinate(CDesc{}, IndexC{}));
@@ -89,6 +93,7 @@ struct BlockwiseGemmAvx2_MxN
         return c_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
     }
+#if 0
     static ck::index_t GetMPerBlock(const ABlockDesc& a_block_desc)
     {
         if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
@@ -134,6 +139,7 @@ struct BlockwiseGemmAvx2_MxN
                 b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
         }
     }
+#endif
     static ck::index_t
     GetABlockStartOffset(const ABlockDesc& a_block_desc, const index_t i_m, const index_t)
@@ -175,14 +181,17 @@ struct BlockwiseGemmAvx2_MxN
     static void Run(const ABlockDesc& a_block_desc,
                     const ABlockBuffer& a_block_buf,
                     const IndexA& /* a_origin */,
+                    const ASliceLengths& a_slice_length,
                     const BBlockDesc& b_block_desc,
                     const BBlockBuffer& b_block_buf,
                     const IndexB& /* b_origin */,
+                    const BSliceLengths& b_slice_length,
                     const CDesc& c_desc,
                     CBuffer& c_buf,
                     const IndexC& /* c_origin */,
+                    const CSliceLengths& c_slice_length,
                     bool is_accumulate_c = true)
     {
@@ -192,9 +201,9 @@ struct BlockwiseGemmAvx2_MxN
         // printf("lda:%d, ldb:%d, ldc:%d\n", lda, ldb, ldc);
-        const auto k_per_block = GetKPerBlock(a_block_desc);
-        const auto m_per_block = GetMPerBlock(a_block_desc);
-        const auto n_per_block = GetNPerBlock(b_block_desc);
+        const auto k_per_block = a_slice_length[Number<1>{}];
+        const auto m_per_block = c_slice_length[Number<0>{}];
+        const auto n_per_block = c_slice_length[Number<1>{}];
         const auto m_per_thread = ThreadwiseGemm_Dispatch::ThreadMaxMr;
         const auto n_per_thread = ThreadwiseGemm_Dispatch::ThreadMaxNr;
@@ -206,6 +215,9 @@ struct BlockwiseGemmAvx2_MxN
         param.alpha = 1.0f; // TODO
         param.accmulate_c = is_accumulate_c ? 1 : 0;
+        // printf("xxx lda:%u, ldb:%u, ldc:%u, mpb:%u, npb:%u, kpb:%u\n", lda, ldb, ldc,
+        //        m_per_block, n_per_block, k_per_block);
         if constexpr(std::is_same<ThreadMNAccessOrder, ck::Sequence<0, 1>>::value)
         {
             for(ck::index_t i_m = 0; i_m < m_per_block; i_m += m_per_thread)
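As a side note on the signature change above, the sketch below (hypothetical `Desc` struct and helper names, not library code) shows why `Run` now takes explicit slice lengths: once the A/B "block descriptor" may be the whole grid descriptor, its lengths no longer equal the tile size, so `GetKPerBlock`/`GetMPerBlock`/`GetNPerBlock` can no longer be derived from it and the caller has to pass the per-block extents.

```cpp
#include <cstdio>

// Simplified 2-D descriptor: only the lengths matter for this illustration.
struct Desc { int len0, len1; };

// Old approach: tile size read from the block descriptor, which was always a packed tile.
int k_per_block_from_desc(const Desc& a_block_desc) { return a_block_desc.len1; }

// New approach: tile size passed explicitly, so it stays correct even when the
// "block descriptor" actually describes the full, unpacked grid tensor.
int k_per_block_from_slice(const int a_slice_length[2]) { return a_slice_length[1]; }

int main()
{
    Desc packed_tile{256, 128};  // MPerBlock x KPerBlock local buffer
    Desc whole_grid{1024, 2048}; // full A tensor when UseALocalBuffer == false
    int  slice[2] = {256, 128};  // tile actually processed in this iteration

    std::printf("%d %d %d\n",
                k_per_block_from_desc(packed_tile), // 128: correct
                k_per_block_from_desc(whole_grid),  // 2048: wrong tile size
                k_per_block_from_slice(slice));     // 128: correct in both cases
    return 0;
}
```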
...
@@ -108,20 +108,41 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
     static constexpr auto GetInputBlockDescriptor()
     {
-        return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
+        if constexpr(UseALocalBuffer)
+        {
+            return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
+        }
+        else
+        {
+            return AGridDesc{};
+        }
     }
     static constexpr auto GetWeightBlockDescriptor()
     {
-        return make_naive_tensor_descriptor_packed(make_tuple(
-            math::integer_divide_ceil(NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
-            KPerBlock,
-            ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
+        if constexpr(UseBLocalBuffer)
+        {
+            return make_naive_tensor_descriptor_packed(make_tuple(
+                math::integer_divide_ceil(NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
+                KPerBlock,
+                ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
+        }
+        else
+        {
+            return BGridDesc{};
+        }
     }
     static constexpr auto GetOutputBlockDescriptor()
     {
-        return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
+        if constexpr(UseCLocalBuffer)
+        {
+            return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
+        }
+        else
+        {
+            return CGridDesc{};
+        }
     }
     static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
@@ -564,7 +585,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
             AGridDesc,
             decltype(GetInputBlockDescriptor()),
             InElementwiseOperation,
-            false,
+            !UseALocalBuffer,
             ConvForwardSpecialization,
             GemmKSpecialization>;
@@ -575,7 +596,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
             BGridDesc,
             decltype(GetWeightBlockDescriptor()),
             WeiElementwiseOperation,
-            false,
+            !UseBLocalBuffer,
             ConvForwardSpecialization,
             GemmKSpecialization>;
@@ -786,12 +807,23 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
         }
         if constexpr(GemmKSpecialization ==
-                     ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
+                         ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
+                     ConvForwardSpecialization !=
+                         ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
         {
             if(!(arg.Conv_C_ % KPerBlock == 0))
                 return false;
         }
+        if constexpr((!UseALocalBuffer || !UseBLocalBuffer) &&
+                     ConvForwardSpecialization !=
+                         ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+        {
+            // TODO: We can support this in the future, as long as figure out how to express tensor
+            // transform
+            return false;
+        }
         // Gridwise GEMM size
         return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
     }
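The newly added guard can be read as a compile-time rule; below is a minimal sketch of that rule, with hypothetical names in place of the library's enums and templates:

```cpp
// Minimal sketch of the added validity check, with illustrative names.
enum class ConvSpec { Default, Filter1x1Pad0, Filter1x1Stride1Pad0 };

template <bool UseALocalBuffer, bool UseBLocalBuffer, ConvSpec Spec>
constexpr bool SupportsDirectGridRead()
{
    // Padded / strided / multi-tap inputs need a tensor transform that the direct
    // (no-local-buffer) path cannot express yet, so they must go through packing.
    if constexpr((!UseALocalBuffer || !UseBLocalBuffer) &&
                 Spec != ConvSpec::Filter1x1Stride1Pad0)
        return false;
    return true;
}

static_assert(SupportsDirectGridRead<true, true, ConvSpec::Default>());
static_assert(!SupportsDirectGridRead<false, true, ConvSpec::Default>());
static_assert(SupportsDirectGridRead<false, false, ConvSpec::Filter1x1Stride1Pad0>());
```

In other words, bypassing the local A/B buffers is only accepted for the 1x1, stride-1, pad-0 case, where the implicit-GEMM input is already a plain contiguous matrix; other cases would need a tensor transform on the direct-read path, which the TODO above defers.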
...
#ifndef DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_HPP
#define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_HPP

#include <iostream>
#include <sstream>
#include <numeric>
#include "device.hpp"
#include "device_base_cpu.hpp"
#include "device_conv_fwd_cpu.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "common_header.hpp"
#include "../../gpu/device/tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_avx2.hpp"
#include "threadwise_gemm_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"

namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {

// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
template <typename InDataType,
          typename WeiDataType,
          typename OutDataType,
          typename InElementwiseOperation,
          typename WeiElementwiseOperation,
          typename OutElementwiseOperation,
          ConvolutionForwardSpecialization_t ConvForwardSpecialization,
          ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
          ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
          ck::index_t NumDimSpatial,
          ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
          ck::index_t NPerBlock,
          ck::index_t KPerBlock,
          ck::index_t MPerThread,
          ck::index_t NPerThread,
          bool UseALocalBuffer,
          bool UseBLocalBuffer,
          bool UseCLocalBuffer>
struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
    : public DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
{
    using DeviceOp = DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K;

    using ADataType = InDataType;
    using BDataType = WeiDataType;
    using CDataType = OutDataType;

    using AElementwiseOperation = InElementwiseOperation;
    using BElementwiseOperation = WeiElementwiseOperation;
    using CElementwiseOperation = OutElementwiseOperation;

    // TODO make A/B datatype different
    using ABDataType = InDataType;

    static constexpr index_t NDimSpatial = NumDimSpatial;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    static constexpr bool NonTemporalStore = false;

    static constexpr auto GetBlockMNKAccessOrder()
    {
        if constexpr(BlockLoopOverSpecialization == DefaultBlockLoopOver ||
                     BlockLoopOverSpecialization == LoopOver_MNK)
            return ck::Sequence<0, 1, 2>{};
        else if constexpr(BlockLoopOverSpecialization == LoopOver_MKN)
            return ck::Sequence<0, 2, 1>{};
    }
    using BlockMNKAccessOrder = decltype(GetBlockMNKAccessOrder());

    static constexpr auto GetThreadwiseGemm_Dispatch()
    {
        if constexpr(MPerThread == 4 && NPerThread == 24)
        {
            return ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<
                InDataType,
                WeiDataType,
                OutDataType,
                ck::tensor_layout::gemm::RowMajor,
                ck::tensor_layout::gemm::ColumnMajor,
                NonTemporalStore>{};
        }
        else if constexpr(MPerThread == 6 && NPerThread == 16)
        {
            return ck::cpu::ThreadwiseGemmAvx2_MxN_6x16_Dispatch<
                InDataType,
                WeiDataType,
                OutDataType,
                ck::tensor_layout::gemm::RowMajor,
                ck::tensor_layout::gemm::ColumnMajor,
                NonTemporalStore>{};
        }
        else
        {
            // static_assert(false, "invalid Mr/Nr");
        }
    }
    using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());
static constexpr auto GetInputBlockDescriptor() static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
{ {
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock)); return make_naive_tensor_descriptor_packed(make_tuple(gemm_n / 8, gemm_k, 8));
} }
static constexpr auto GetWeightBlockDescriptor() static auto GetOutputTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n)
{ {
return make_naive_tensor_descriptor_packed(make_tuple( const auto out_gemm_m_n_grid_desc =
math::integer_divide_ceil(NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize), make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_n));
KPerBlock,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize)); return out_gemm_m_n_grid_desc;
} }
static constexpr auto GetOutputBlockDescriptor() template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
{ static auto GetInputTensorDescriptor(ck::index_t N,
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock)); ck::index_t C,
} ck::index_t gemm_m,
ck::index_t gemm_k,
static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n) const std::vector<ck::index_t>& input_spatial_lengths,
{ const std::vector<ck::index_t>& filter_spatial_lengths,
return make_naive_tensor_descriptor_packed(make_tuple(gemm_n / 8, gemm_k, 8)); const std::vector<ck::index_t>& output_spatial_lengths,
} const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_dilations,
static auto GetOutputTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n) const std::vector<ck::index_t>& input_left_pads,
{ const std::vector<ck::index_t>& input_right_pads)
const auto out_gemm_m_n_grid_desc = {
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_n)); const index_t Wi = input_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[0];
return out_gemm_m_n_grid_desc; const index_t ConvStrideW = conv_filter_strides[0];
}
if constexpr(ConvForwardSpecialization ==
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false> ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
static auto GetInputTensorDescriptor(ck::index_t N, {
ck::index_t C, const auto in_gemm_m_k_grid_desc =
ck::index_t gemm_m, make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
ck::index_t gemm_k,
const std::vector<ck::index_t>& input_spatial_lengths, return in_gemm_m_k_grid_desc;
const std::vector<ck::index_t>& filter_spatial_lengths, }
const std::vector<ck::index_t>& output_spatial_lengths, else if constexpr(ConvForwardSpecialization ==
const std::vector<ck::index_t>& conv_filter_strides, ConvolutionForwardSpecialization_t::Filter1x1Pad0)
const std::vector<ck::index_t>& conv_filter_dilations, {
const std::vector<ck::index_t>& input_left_pads, const auto in_n_wi_c_grid_desc =
const std::vector<ck::index_t>& input_right_pads) make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
{
const index_t Wi = input_spatial_lengths[0]; const auto in_n_wo_c_grid_desc = transform_tensor_descriptor(
const index_t Wo = output_spatial_lengths[0]; in_n_wi_c_grid_desc,
const index_t ConvStrideW = conv_filter_strides[0]; make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
if constexpr(ConvForwardSpecialization == make_pass_through_transform(C)),
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
{ make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemm_m_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k)); const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
in_n_wo_c_grid_desc,
return in_gemm_m_k_grid_desc; make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)),
} make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
else if constexpr(ConvForwardSpecialization == make_tuple(Sequence<0>{}, Sequence<1>{}));
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{ return in_gemm_m_k_grid_desc;
const auto in_n_wi_c_grid_desc = }
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); else
{
const auto in_n_wo_c_grid_desc = transform_tensor_descriptor( const index_t X = filter_spatial_lengths[0];
in_n_wi_c_grid_desc, const index_t ConvDilationW = conv_filter_dilations[0];
make_tuple(make_pass_through_transform(N), const index_t InLeftPadW = input_left_pads[0];
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), const index_t InRightPadW = input_right_pads[0];
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), const auto in_n_wi_c_grid_desc =
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor( const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
in_n_wo_c_grid_desc, in_n_wi_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)), make_tuple(make_pass_through_transform(N),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}), make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
return in_gemm_m_k_grid_desc; make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
}
else const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
{ in_n_wip_c_grid_desc,
const index_t X = filter_spatial_lengths[0]; make_tuple(
const index_t ConvDilationW = conv_filter_dilations[0]; make_pass_through_transform(N),
const index_t InLeftPadW = input_left_pads[0]; make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
const index_t InRightPadW = input_right_pads[0]; make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
const auto in_n_wi_c_grid_desc = make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto in_gemm_m_k_grid_desc =
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( transform_tensor_descriptor(in_n_x_wo_c_grid_desc,
in_n_wi_c_grid_desc, make_tuple(make_merge_transform(make_tuple(N, Wo)),
make_tuple(make_pass_through_transform(N), make_merge_transform(make_tuple(X, C))),
make_pad_transform(Wi, InLeftPadW, InRightPadW), make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); return in_gemm_m_k_grid_desc;
}
const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor( }
in_n_wip_c_grid_desc,
make_tuple( template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
make_pass_through_transform(N), static auto GetInputTensorDescriptor(ck::index_t N,
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), ck::index_t C,
make_pass_through_transform(C)), ck::index_t gemm_m,
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), ck::index_t gemm_k,
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths,
const auto in_gemm_m_k_grid_desc = const std::vector<ck::index_t>& output_spatial_lengths,
transform_tensor_descriptor(in_n_x_wo_c_grid_desc, const std::vector<ck::index_t>& conv_filter_strides,
make_tuple(make_merge_transform(make_tuple(N, Wo)), const std::vector<ck::index_t>& conv_filter_dilations,
make_merge_transform(make_tuple(X, C))), const std::vector<ck::index_t>& input_left_pads,
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}), const std::vector<ck::index_t>& input_right_pads)
make_tuple(Sequence<0>{}, Sequence<1>{})); {
const index_t Hi = input_spatial_lengths[0];
return in_gemm_m_k_grid_desc; const index_t Wi = input_spatial_lengths[1];
}
} const index_t Ho = output_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[1];
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
static auto GetInputTensorDescriptor(ck::index_t N, const index_t ConvStrideH = conv_filter_strides[0];
ck::index_t C, const index_t ConvStrideW = conv_filter_strides[1];
ck::index_t gemm_m,
ck::index_t gemm_k, if constexpr(ConvForwardSpecialization ==
const std::vector<ck::index_t>& input_spatial_lengths, ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
const std::vector<ck::index_t>& filter_spatial_lengths, {
const std::vector<ck::index_t>& output_spatial_lengths, const auto in_gemm_m_k_grid_desc =
const std::vector<ck::index_t>& conv_filter_strides, make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
const std::vector<ck::index_t>& conv_filter_dilations,
const std::vector<ck::index_t>& input_left_pads, return in_gemm_m_k_grid_desc;
const std::vector<ck::index_t>& input_right_pads) }
{ else if constexpr(ConvForwardSpecialization ==
const index_t Hi = input_spatial_lengths[0]; ConvolutionForwardSpecialization_t::Filter1x1Pad0)
const index_t Wi = input_spatial_lengths[1]; {
const auto in_n_hi_wi_c_grid_desc =
const index_t Ho = output_spatial_lengths[0]; make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const index_t Wo = output_spatial_lengths[1];
const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
const index_t ConvStrideH = conv_filter_strides[0]; in_n_hi_wi_c_grid_desc,
const index_t ConvStrideW = conv_filter_strides[1]; make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
if constexpr(ConvForwardSpecialization == make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) make_pass_through_transform(C)),
{ make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
const auto in_gemm_m_k_grid_desc = make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
const auto in_gemm_m_k_grid_desc =
return in_gemm_m_k_grid_desc; transform_tensor_descriptor(in_n_ho_wo_c_grid_desc,
} make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
else if constexpr(ConvForwardSpecialization == make_pass_through_transform(C)),
ConvolutionForwardSpecialization_t::Filter1x1Pad0) make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
{ make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); return in_gemm_m_k_grid_desc;
}
const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( else
in_n_hi_wi_c_grid_desc, {
make_tuple(make_pass_through_transform(N), const index_t Y = filter_spatial_lengths[0];
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), const index_t X = filter_spatial_lengths[1];
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)), const index_t ConvDilationH = conv_filter_dilations[0];
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), const index_t ConvDilationW = conv_filter_dilations[1];
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const index_t InLeftPadH = input_left_pads[0];
const auto in_gemm_m_k_grid_desc = const index_t InLeftPadW = input_left_pads[1];
transform_tensor_descriptor(in_n_ho_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), const index_t InRightPadH = input_right_pads[0];
make_pass_through_transform(C)), const index_t InRightPadW = input_right_pads[1];
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
return in_gemm_m_k_grid_desc;
} const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
else in_n_hi_wi_c_grid_desc,
{ make_tuple(make_pass_through_transform(N),
const index_t Y = filter_spatial_lengths[0]; make_pad_transform(Hi, InLeftPadH, InRightPadH),
const index_t X = filter_spatial_lengths[1]; make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
const index_t ConvDilationH = conv_filter_dilations[0]; make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
const index_t ConvDilationW = conv_filter_dilations[1]; make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const index_t InLeftPadH = input_left_pads[0]; const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
const index_t InLeftPadW = input_left_pads[1]; in_n_hip_wip_c_grid_desc,
make_tuple(
const index_t InRightPadH = input_right_pads[0]; make_pass_through_transform(N),
const index_t InRightPadW = input_right_pads[1]; make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
const auto in_n_hi_wi_c_grid_desc = make_pass_through_transform(C)),
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc, const auto in_gemm_m_k_grid_desc =
make_tuple(make_pass_through_transform(N), transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_pad_transform(Hi, InLeftPadH, InRightPadH), make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_pad_transform(Wi, InLeftPadW, InRightPadW), make_merge_transform(make_tuple(Y, X, C))),
make_pass_through_transform(C)), make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
return in_gemm_m_k_grid_desc;
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( }
in_n_hip_wip_c_grid_desc, }
make_tuple(
make_pass_through_transform(N), template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), static auto GetInputTensorDescriptor(ck::index_t N,
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), ck::index_t C,
make_pass_through_transform(C)), ck::index_t gemm_m,
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), ck::index_t gemm_k,
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); ck::index_t gemm_m_pad,
const std::vector<ck::index_t>& input_spatial_lengths,
const auto in_gemm_m_k_grid_desc = const std::vector<ck::index_t>& filter_spatial_lengths,
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, const std::vector<ck::index_t>& output_spatial_lengths,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), const std::vector<ck::index_t>& conv_filter_strides,
make_merge_transform(make_tuple(Y, X, C))), const std::vector<ck::index_t>& conv_filter_dilations,
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), const std::vector<ck::index_t>& input_left_pads,
make_tuple(Sequence<0>{}, Sequence<1>{})); const std::vector<ck::index_t>& input_right_pads)
{
return in_gemm_m_k_grid_desc; const index_t Di = input_spatial_lengths[0];
} const index_t Hi = input_spatial_lengths[1];
} const index_t Wi = input_spatial_lengths[2];
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false> const index_t Do = output_spatial_lengths[0];
static auto GetInputTensorDescriptor(ck::index_t N, const index_t Ho = output_spatial_lengths[1];
ck::index_t C, const index_t Wo = output_spatial_lengths[2];
ck::index_t gemm_m,
ck::index_t gemm_k, const index_t ConvStrideD = conv_filter_strides[0];
ck::index_t gemm_m_pad, const index_t ConvStrideH = conv_filter_strides[1];
const std::vector<ck::index_t>& input_spatial_lengths, const index_t ConvStrideW = conv_filter_strides[2];
const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths, if constexpr(ConvForwardSpecialization ==
const std::vector<ck::index_t>& conv_filter_strides, ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
const std::vector<ck::index_t>& conv_filter_dilations, {
const std::vector<ck::index_t>& input_left_pads, const auto in_gemm_m_k_grid_desc =
const std::vector<ck::index_t>& input_right_pads) make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
{
const index_t Di = input_spatial_lengths[0]; return in_gemm_m_k_grid_desc;
const index_t Hi = input_spatial_lengths[1]; }
const index_t Wi = input_spatial_lengths[2]; else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
const index_t Do = output_spatial_lengths[0]; {
const index_t Ho = output_spatial_lengths[1]; const auto in_n_di_hi_wi_c_grid_desc =
const index_t Wo = output_spatial_lengths[2]; make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const index_t ConvStrideD = conv_filter_strides[0]; const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor(
const index_t ConvStrideH = conv_filter_strides[1]; in_n_di_hi_wi_c_grid_desc,
const index_t ConvStrideW = conv_filter_strides[2]; make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
if constexpr(ConvForwardSpecialization == make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
{ make_pass_through_transform(C)),
const auto in_gemm_m_k_grid_desc = make_tuple(
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k)); Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
return in_gemm_m_k_grid_desc; Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
}
else if constexpr(ConvForwardSpecialization == const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
ConvolutionForwardSpecialization_t::Filter1x1Pad0) in_n_do_ho_wo_c_grid_desc,
{ make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
const auto in_n_di_hi_wi_c_grid_desc = make_pass_through_transform(C)),
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc, return in_gemm_m_k_grid_desc;
make_tuple(make_pass_through_transform(N), }
make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)), else
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), {
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), const index_t Z = filter_spatial_lengths[0];
make_pass_through_transform(C)), const index_t Y = filter_spatial_lengths[1];
make_tuple( const index_t X = filter_spatial_lengths[2];
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple( const index_t ConvDilationD = conv_filter_dilations[0];
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2];
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
in_n_do_ho_wo_c_grid_desc, const index_t InLeftPadD = input_left_pads[0];
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), const index_t InLeftPadH = input_left_pads[1];
make_pass_through_transform(C)), const index_t InLeftPadW = input_left_pads[2];
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadH = input_right_pads[1];
return in_gemm_m_k_grid_desc; const index_t InRightPadW = input_right_pads[2];
}
else const auto in_n_di_hi_wi_c_grid_desc =
{ make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const index_t Z = filter_spatial_lengths[0];
const index_t Y = filter_spatial_lengths[1]; const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
const index_t X = filter_spatial_lengths[2]; in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
const index_t ConvDilationD = conv_filter_dilations[0]; make_pad_transform(Di, InLeftPadD, InRightPadD),
const index_t ConvDilationH = conv_filter_dilations[1]; make_pad_transform(Hi, InLeftPadH, InRightPadH),
const index_t ConvDilationW = conv_filter_dilations[2]; make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
const index_t InLeftPadD = input_left_pads[0]; make_tuple(
const index_t InLeftPadH = input_left_pads[1]; Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
const index_t InLeftPadW = input_left_pads[2]; make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadH = input_right_pads[1]; const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
const index_t InRightPadW = input_right_pads[2]; in_n_hip_wip_c_grid_desc,
make_tuple(
const auto in_n_di_hi_wi_c_grid_desc = make_pass_through_transform(N),
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
in_n_di_hi_wi_c_grid_desc, make_pass_through_transform(C)),
make_tuple(make_pass_through_transform(N), make_tuple(
make_pad_transform(Di, InLeftPadD, InRightPadD), Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_pad_transform(Hi, InLeftPadH, InRightPadH), make_tuple(Sequence<0>{},
make_pad_transform(Wi, InLeftPadW, InRightPadW), Sequence<1, 2>{},
make_pass_through_transform(C)), Sequence<3, 4>{},
make_tuple( Sequence<5, 6>{},
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), Sequence<7>{}));
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_grid_desc,
const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
in_n_hip_wip_c_grid_desc, make_merge_transform(make_tuple(Z, Y, X, C))),
make_tuple( make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
make_pass_through_transform(N), make_tuple(Sequence<0>{}, Sequence<1>{}));
make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), return in_gemm_m_k_grid_desc;
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), }
make_pass_through_transform(C)), }
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), static index_t GetGemmM(ck::index_t N, const std::vector<ck::index_t>& output_spatial_lengths)
make_tuple(Sequence<0>{}, {
Sequence<1, 2>{}, return N * std::accumulate(std::begin(output_spatial_lengths),
Sequence<3, 4>{}, std::end(output_spatial_lengths),
Sequence<5, 6>{}, 1,
Sequence<7>{})); std::multiplies<ck::index_t>());
}
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_grid_desc, static index_t GetGemmK(ck::index_t C, const std::vector<ck::index_t>& filter_spatial_lengths)
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), {
make_merge_transform(make_tuple(Z, Y, X, C))), return C * std::accumulate(std::begin(filter_spatial_lengths),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}), std::end(filter_spatial_lengths),
make_tuple(Sequence<0>{}, Sequence<1>{})); 1,
std::multiplies<ck::index_t>());
return in_gemm_m_k_grid_desc; }
}
} static index_t GetGemmN(ck::index_t K)
{
static index_t GetGemmM(ck::index_t N, const std::vector<ck::index_t>& output_spatial_lengths) // return ck::math::integer_least_multiple(K,
{ // ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
return N * std::accumulate(std::begin(output_spatial_lengths), return K;
std::end(output_spatial_lengths), }
1,
std::multiplies<ck::index_t>()); static auto MakeABCGridDescriptor(ck::index_t N,
} ck::index_t K,
ck::index_t C,
static index_t GetGemmK(ck::index_t C, const std::vector<ck::index_t>& filter_spatial_lengths) std::vector<ck::index_t> input_spatial_lengths,
{ std::vector<ck::index_t> filter_spatial_lengths,
return C * std::accumulate(std::begin(filter_spatial_lengths), std::vector<ck::index_t> output_spatial_lengths,
std::end(filter_spatial_lengths), std::vector<ck::index_t> conv_filter_strides,
1, std::vector<ck::index_t> conv_filter_dilations,
std::multiplies<ck::index_t>()); std::vector<ck::index_t> input_left_pads,
} std::vector<ck::index_t> input_right_pads)
{
static index_t GetGemmN(ck::index_t K) using namespace ck;
{
// return ck::math::integer_least_multiple(K, const index_t GemmM = GetGemmM(N, output_spatial_lengths);
// ThreadwiseGemm_Dispatch::MatrixBMinVectorSize); const index_t GemmN = GetGemmN(K);
return K; const index_t GemmK = GetGemmK(C, filter_spatial_lengths);
}
// A:
static auto MakeABCGridDescriptor(ck::index_t N, const auto in_gemm_m_k_grid_desc =
ck::index_t K, GetInputTensorDescriptor<NumDimSpatial>(N,
ck::index_t C, C,
std::vector<ck::index_t> input_spatial_lengths, GemmM,
std::vector<ck::index_t> filter_spatial_lengths, GemmK,
std::vector<ck::index_t> output_spatial_lengths, input_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides, filter_spatial_lengths,
std::vector<ck::index_t> conv_filter_dilations, output_spatial_lengths,
std::vector<ck::index_t> input_left_pads, conv_filter_strides,
std::vector<ck::index_t> input_right_pads) conv_filter_dilations,
{ input_left_pads,
using namespace ck; input_right_pads);
// B:
const index_t GemmM = GetGemmM(N, output_spatial_lengths); const auto wei_gemm_n0_k_n1_grid_desc = GetWeightTensorDescriptor(GemmK, GemmN);
const index_t GemmN = GetGemmN(K); // C:
const index_t GemmK = GetGemmK(C, filter_spatial_lengths); const auto out_gemm_m_n_grid_desc = GetOutputTensorDescriptor(GemmM, GemmN);
// A: return make_tuple(
const auto in_gemm_m_k_grid_desc = in_gemm_m_k_grid_desc, wei_gemm_n0_k_n1_grid_desc, out_gemm_m_n_grid_desc);
GetInputTensorDescriptor<NumDimSpatial>(N, }
C,
GemmM, template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
GemmK, static auto GetABCGridDesc()
input_spatial_lengths, {
filter_spatial_lengths, return MakeABCGridDescriptor(1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1});
output_spatial_lengths, }
conv_filter_strides,
conv_filter_dilations, template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
input_left_pads, static auto GetABCGridDesc()
input_right_pads); {
// B: return MakeABCGridDescriptor(
const auto wei_gemm_n0_k_n1_grid_desc = GetWeightTensorDescriptor(GemmK, GemmN); 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
// C: }
const auto out_gemm_m_n_grid_desc = GetOutputTensorDescriptor(GemmM, GemmN);
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
return make_tuple( static auto GetABCGridDesc()
in_gemm_m_k_grid_desc, wei_gemm_n0_k_n1_grid_desc, out_gemm_m_n_grid_desc); {
} return MakeABCGridDescriptor(
1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false> }
static auto GetABCGridDesc()
{ using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>());
return MakeABCGridDescriptor(1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1});
} using AGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
using BGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false> using CGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
static auto GetABCGridDesc()
{ static constexpr auto GetInputBlockDescriptor()
return MakeABCGridDescriptor( {
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}); if constexpr(UseALocalBuffer)
} {
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false> }
static auto GetABCGridDesc() else
{ {
return MakeABCGridDescriptor( return AGridDesc{};
1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}); }
} }
using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>()); static constexpr auto GetWeightBlockDescriptor()
{
using AGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I0])>; if constexpr(UseBLocalBuffer)
using BGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I1])>; {
using CGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I2])>; return make_naive_tensor_descriptor_packed(make_tuple(
math::integer_divide_ceil(NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
// static constexpr bool UseCLocalBuffer = false; KPerBlock,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
using AThreadwiseCopy = }
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC< else
InDataType, {
InDataType, return BGridDesc{};
AGridDesc, }
decltype(GetInputBlockDescriptor()), }
InElementwiseOperation,
false, static constexpr auto GetOutputBlockDescriptor()
ConvForwardSpecialization, {
GemmKSpecialization>; if constexpr(UseCLocalBuffer)
{
using BThreadwiseCopy = return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8< }
WeiDataType, else
WeiDataType, {
BGridDesc, return CGridDesc{};
decltype(GetWeightBlockDescriptor()), }
WeiElementwiseOperation, }
false,
ConvForwardSpecialization, // static constexpr bool UseCLocalBuffer = false;
GemmKSpecialization>;
using AThreadwiseCopy =
using CThreadwiseCopy = ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN< ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC<
OutDataType, InDataType,
OutDataType, InDataType,
CGridDesc, AGridDesc,
decltype(GetOutputBlockDescriptor()), decltype(GetInputBlockDescriptor()),
OutElementwiseOperation, InElementwiseOperation,
!UseCLocalBuffer, !UseALocalBuffer,
ConvForwardSpecialization, ConvForwardSpecialization,
GemmKSpecialization>; GemmKSpecialization>;
using GridwiseGemm = using BThreadwiseCopy =
ck::cpu::GridwiseGemmAvx2_MxN<InDataType, // InDataType, ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8<
WeiDataType, // WeiDataType, WeiDataType,
OutDataType, // OutDataType, WeiDataType,
AGridDesc, // AGridDesc, BGridDesc,
BGridDesc, // BGridDesc, decltype(GetWeightBlockDescriptor()),
CGridDesc, // CGridDesc, WeiElementwiseOperation,
AElementwiseOperation, // AElementwiseOperation, !UseBLocalBuffer,
BElementwiseOperation, // BElementwiseOperation, ConvForwardSpecialization,
CElementwiseOperation, // CElementwiseOperation, GemmKSpecialization>;
MPerBlock, // MPerBlock,
NPerBlock, // NPerBlock, using CThreadwiseCopy = ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN<
KPerBlock, // KPerBlock, OutDataType,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch, OutDataType,
AThreadwiseCopy, // AThreadwiseCopy CGridDesc,
BThreadwiseCopy, // BThreadwiseCopy decltype(GetOutputBlockDescriptor()),
CThreadwiseCopy, // CThreadwiseCopy OutElementwiseOperation,
BlockMNKAccessOrder, // BlockMNKAccessOrder, !UseCLocalBuffer,
ck::Sequence<0, 1>, // ThreadMNAccessOrder ConvForwardSpecialization,
UseALocalBuffer, // UseALocalBuffer GemmKSpecialization>;
UseBLocalBuffer, // UseBLocalBuffer
UseCLocalBuffer // UseCLocalBuffer using GridwiseGemm =
>; ck::cpu::GridwiseGemmAvx2_MxN<InDataType, // InDataType,
WeiDataType, // WeiDataType,
// Argument OutDataType, // OutDataType,
struct Argument : public BaseArgument AGridDesc, // AGridDesc,
{ BGridDesc, // BGridDesc,
Argument(const InDataType* p_in_grid, CGridDesc, // CGridDesc,
const WeiDataType* p_wei_grid, AElementwiseOperation, // AElementwiseOperation,
OutDataType* p_out_grid, BElementwiseOperation, // BElementwiseOperation,
ck::index_t N, CElementwiseOperation, // CElementwiseOperation,
ck::index_t K, MPerBlock, // MPerBlock,
ck::index_t C, NPerBlock, // NPerBlock,
std::vector<ck::index_t> input_spatial_lengths, KPerBlock, // KPerBlock,
std::vector<ck::index_t> filter_spatial_lengths, ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
std::vector<ck::index_t> output_spatial_lengths, AThreadwiseCopy, // AThreadwiseCopy
std::vector<ck::index_t> conv_filter_strides, BThreadwiseCopy, // BThreadwiseCopy
std::vector<ck::index_t> conv_filter_dilations, CThreadwiseCopy, // CThreadwiseCopy
std::vector<ck::index_t> input_left_pads, BlockMNKAccessOrder, // BlockMNKAccessOrder,
std::vector<ck::index_t> input_right_pads, ck::Sequence<0, 1>, // ThreadMNAccessOrder
InElementwiseOperation in_element_op, UseALocalBuffer, // UseALocalBuffer
WeiElementwiseOperation wei_element_op, UseBLocalBuffer, // UseBLocalBuffer
OutElementwiseOperation out_element_op) UseCLocalBuffer // UseCLocalBuffer
: p_a_grid_{p_in_grid}, >;
p_b_grid_{p_wei_grid},
p_c_grid_{p_out_grid}, // Argument
a_grid_desc_{}, struct Argument : public BaseArgument
b_grid_desc_{}, {
c_grid_desc_{}, Argument(const InDataType* p_in_grid,
a_element_op_{in_element_op}, const WeiDataType* p_wei_grid,
b_element_op_{wei_element_op}, OutDataType* p_out_grid,
c_element_op_{out_element_op}, ck::index_t N,
Conv_N_{N}, ck::index_t K,
Conv_K_{K}, ck::index_t C,
Conv_C_{C}, std::vector<ck::index_t> input_spatial_lengths,
filter_spatial_lengths_{filter_spatial_lengths}, std::vector<ck::index_t> filter_spatial_lengths,
conv_filter_strides_{conv_filter_strides}, std::vector<ck::index_t> output_spatial_lengths,
input_left_pads_{input_left_pads}, std::vector<ck::index_t> conv_filter_strides,
input_right_pads_{input_right_pads} std::vector<ck::index_t> conv_filter_dilations,
{ std::vector<ck::index_t> input_left_pads,
const auto descs = DeviceOp::MakeABCGridDescriptor(N, std::vector<ck::index_t> input_right_pads,
K, InElementwiseOperation in_element_op,
C, WeiElementwiseOperation wei_element_op,
input_spatial_lengths, OutElementwiseOperation out_element_op)
filter_spatial_lengths, : p_a_grid_{p_in_grid},
output_spatial_lengths, p_b_grid_{p_wei_grid},
conv_filter_strides, p_c_grid_{p_out_grid},
conv_filter_dilations, a_grid_desc_{},
input_left_pads, b_grid_desc_{},
input_right_pads); c_grid_desc_{},
a_grid_desc_ = descs[I0]; a_element_op_{in_element_op},
b_grid_desc_ = descs[I1]; b_element_op_{wei_element_op},
c_grid_desc_ = descs[I2]; c_element_op_{out_element_op},
} Conv_N_{N},
Conv_K_{K},
// private: Conv_C_{C},
const ADataType* p_a_grid_; filter_spatial_lengths_{filter_spatial_lengths},
const BDataType* p_b_grid_; conv_filter_strides_{conv_filter_strides},
CDataType* p_c_grid_; input_left_pads_{input_left_pads},
AGridDesc a_grid_desc_; input_right_pads_{input_right_pads}
BGridDesc b_grid_desc_; {
CGridDesc c_grid_desc_; const auto descs = DeviceOp::MakeABCGridDescriptor(N,
K,
AElementwiseOperation a_element_op_; C,
BElementwiseOperation b_element_op_; input_spatial_lengths,
CElementwiseOperation c_element_op_; filter_spatial_lengths,
// for checking IsSupportedArgument() output_spatial_lengths,
index_t Conv_N_; conv_filter_strides,
index_t Conv_K_; conv_filter_dilations,
index_t Conv_C_; input_left_pads,
std::vector<index_t> filter_spatial_lengths_; input_right_pads);
std::vector<index_t> conv_filter_strides_; a_grid_desc_ = descs[I0];
std::vector<index_t> input_left_pads_; b_grid_desc_ = descs[I1];
std::vector<index_t> input_right_pads_; c_grid_desc_ = descs[I2];
}; }
// Invoker // private:
struct Invoker : public BaseInvoker const ADataType* p_a_grid_;
{ const BDataType* p_b_grid_;
using Argument = DeviceOp::Argument; CDataType* p_c_grid_;
AGridDesc a_grid_desc_;
float Run(const Argument& arg, BGridDesc b_grid_desc_;
const StreamConfig& stream_config = StreamConfig{}, CGridDesc c_grid_desc_;
int nrepeat = 1)
{ AElementwiseOperation a_element_op_;
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_)) BElementwiseOperation b_element_op_;
{ CElementwiseOperation c_element_op_;
throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting"); // for checking IsSupportedArgument()
} index_t Conv_N_;
index_t Conv_K_;
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize()); index_t Conv_C_;
std::vector<index_t> filter_spatial_lengths_;
const auto kernel = ck::cpu::kernel_gemm_avx_mxn<GridwiseGemm, std::vector<index_t> conv_filter_strides_;
InDataType, std::vector<index_t> input_left_pads_;
WeiDataType, std::vector<index_t> input_right_pads_;
OutDataType, };
AGridDesc,
BGridDesc, // Invoker
CGridDesc, struct Invoker : public BaseInvoker
AElementwiseOperation, {
BElementwiseOperation, using Argument = DeviceOp::Argument;
CElementwiseOperation>;
float Run(const Argument& arg,
float ave_time = 0; const StreamConfig& stream_config = StreamConfig{},
int nrepeat = 1)
if(nrepeat != 1) {
ave_time = launch_and_time_cpu_kernel(kernel, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
nrepeat, {
arg.p_a_grid_, throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
arg.p_b_grid_, }
arg.p_c_grid_,
arg.a_grid_desc_, memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
arg.b_grid_desc_,
arg.c_grid_desc_, const auto kernel = ck::cpu::kernel_gemm_avx_mxn<GridwiseGemm,
arg.a_element_op_, InDataType,
arg.b_element_op_, WeiDataType,
arg.c_element_op_); OutDataType,
AGridDesc,
// TODO: this is for benchmark purpose, so last time we clear c buffer and calculate the BGridDesc,
// result CGridDesc,
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize()); AElementwiseOperation,
BElementwiseOperation,
launch_cpu_kernel(kernel, CElementwiseOperation>;
arg.p_a_grid_,
arg.p_b_grid_, float ave_time = 0;
arg.p_c_grid_,
arg.a_grid_desc_, if(nrepeat != 1)
arg.b_grid_desc_, ave_time = launch_and_time_cpu_kernel(kernel,
arg.c_grid_desc_, nrepeat,
arg.a_element_op_, arg.p_a_grid_,
arg.b_element_op_, arg.p_b_grid_,
arg.c_element_op_); arg.p_c_grid_,
arg.a_grid_desc_,
return ave_time; arg.b_grid_desc_,
} arg.c_grid_desc_,
arg.a_element_op_,
float Run(const BaseArgument* p_arg, arg.b_element_op_,
const StreamConfig& stream_config = StreamConfig{}, arg.c_element_op_);
int nrepeat = 1) override
{ // TODO: this is for benchmark purpose, so last time we clear c buffer and calculate the
            // result
            memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());

            launch_cpu_kernel(kernel,
                              arg.p_a_grid_,
                              arg.p_b_grid_,
                              arg.p_c_grid_,
                              arg.a_grid_desc_,
                              arg.b_grid_desc_,
                              arg.c_grid_desc_,
                              arg.a_element_op_,
                              arg.b_element_op_,
                              arg.c_element_op_);

            return ave_time;
        }

        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{},
                  int nrepeat = 1) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config, nrepeat);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if constexpr(ConvForwardSpecialization ==
                     ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            // check if it's 1x1, stride=1 conv
            if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
                 arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
                 arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
                 arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
            {
                return false;
            }
        }
        else if constexpr(ConvForwardSpecialization ==
                          ConvolutionForwardSpecialization_t::Filter1x1Pad0)
        {
            // check if it's 1x1 conv
            if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
                 arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
                 arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
            {
                return false;
            }
        }

        if constexpr(GemmKSpecialization ==
                         ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
                     ConvForwardSpecialization !=
                         ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            if(!(arg.Conv_C_ % KPerBlock == 0))
                return false;
        }

        if(!(arg.Conv_K_ % 8 == 0))
            return false;

        if constexpr(!UseALocalBuffer &&
                     ConvForwardSpecialization !=
                         ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            // TODO: We can support this in the future, as long as figure out how to express tensor
            // transform
            return false;
        }

        // Gridwise GEMM size
        return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
    }

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const InDataType* p_in_grid,
                             const WeiDataType* p_wei_grid,
                             OutDataType* p_out_grid,
                             ck::index_t N,
                             ck::index_t K,
                             ck::index_t C,
                             std::vector<ck::index_t> input_spatial_lengths,
                             std::vector<ck::index_t> filter_spatial_lengths,
                             std::vector<ck::index_t> output_spatial_lengths,
                             std::vector<ck::index_t> conv_filter_strides,
                             std::vector<ck::index_t> conv_filter_dilations,
                             std::vector<ck::index_t> input_left_pads,
                             std::vector<ck::index_t> input_right_pads,
                             InElementwiseOperation in_element_op,
                             WeiElementwiseOperation wei_element_op,
                             OutElementwiseOperation out_element_op)
    {
        return Argument{p_in_grid,
                        p_wei_grid,
                        p_out_grid,
                        N,
                        K,
                        C,
                        input_spatial_lengths,
                        filter_spatial_lengths,
                        output_spatial_lengths,
                        conv_filter_strides,
                        conv_filter_dilations,
                        input_left_pads,
                        input_right_pads,
                        in_element_op,
                        wei_element_op,
                        out_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_in_grid,
                        const void* p_wei_grid,
                        void* p_out_grid,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t C,
                        std::vector<ck::index_t> input_spatial_lengths,
                        std::vector<ck::index_t> filter_spatial_lengths,
                        std::vector<ck::index_t> output_spatial_lengths,
                        std::vector<ck::index_t> conv_filter_strides,
                        std::vector<ck::index_t> conv_filter_dilations,
                        std::vector<ck::index_t> input_left_pads,
                        std::vector<ck::index_t> input_right_pads,
                        InElementwiseOperation in_element_op,
                        WeiElementwiseOperation wei_element_op,
                        OutElementwiseOperation out_element_op) override
    {
        return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
                                          static_cast<const WeiDataType*>(p_wei_grid),
                                          static_cast<OutDataType*>(p_out_grid),
                                          N,
                                          K,
                                          C,
                                          input_spatial_lengths,
                                          filter_spatial_lengths,
                                          output_spatial_lengths,
                                          conv_filter_strides,
                                          conv_filter_dilations,
                                          input_left_pads,
                                          input_right_pads,
                                          in_element_op,
                                          wei_element_op,
                                          out_element_op);
    }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();
        auto string_local_buffer = [](bool is_local_buffer) {
            if(is_local_buffer)
                return "L";
            else
                return "G";
        };
        // clang-format off
        str << "DeviceConv" << std::to_string(NumDimSpatial)
            << "DFwdAvx2_NHWC_KYXCK8"
            << "_FS" << static_cast<int>(ConvForwardSpecialization)
            << "_KS" << static_cast<int>(GemmKSpecialization)
            << "_BS" << static_cast<int>(BlockLoopOverSpecialization)
            << "_BT" << MPerBlock << "x" << NPerBlock << "x" << KPerBlock
            << "_TT" << MPerThread << "x" << NPerThread
            << "_A" << string_local_buffer(UseALocalBuffer)
            << "_B" << string_local_buffer(UseBLocalBuffer)
            << "_C" << string_local_buffer(UseCLocalBuffer)
            ;
        if constexpr (!std::is_same<OutElementwiseOperation,
                                    ck::tensor_operation::cpu::element_wise::PassThrough>::value)
        {
            str << "_" << OutElementwiseOperation::Name();
        }
        // clang-format on

        return str.str();
    }
};
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck

#endif
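For reference, the instance name built by GetTypeString() above encodes the forward/GemmK/loop-over specializations, the block tile (_BT), the thread tile (_TT), and whether each of A/B/C is staged through a local buffer ("L") or, presumably, accessed in place in the grid ("G"). The following standalone sketch is an editor's illustration only, with made-up configuration values, and is not part of this commit; it just reproduces the naming scheme so the strings reported by this operator are easy to read back:

#include <iostream>
#include <sstream>
#include <string>

int main()
{
    // hypothetical configuration, chosen only to illustrate the name format
    const int num_dim_spatial = 2;
    const int conv_fwd_spec = 0, gemm_k_spec = 0, block_loop_spec = 0;
    const int m_per_block = 256, n_per_block = 120, k_per_block = 64;
    const int m_per_thread = 6, n_per_thread = 16;
    const bool use_a_local = true, use_b_local = true, use_c_local = false;

    auto string_local_buffer = [](bool is_local_buffer) { return is_local_buffer ? "L" : "G"; };

    std::stringstream str;
    str << "DeviceConv" << num_dim_spatial << "DFwdAvx2_NHWC_KYXCK8"
        << "_FS" << conv_fwd_spec << "_KS" << gemm_k_spec << "_BS" << block_loop_spec
        << "_BT" << m_per_block << "x" << n_per_block << "x" << k_per_block
        << "_TT" << m_per_thread << "x" << n_per_thread
        << "_A" << string_local_buffer(use_a_local)
        << "_B" << string_local_buffer(use_b_local)
        << "_C" << string_local_buffer(use_c_local);

    // prints: DeviceConv2DFwdAvx2_NHWC_KYXCK8_FS0_KS0_BS0_BT256x120x64_TT6x16_AL_BL_CG
    std::cout << str.str() << std::endl;
    return 0;
}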
#ifndef DEVICE_CONV2D_FWD_BIAS_ACTIVATION_ADD_AVX2_NHWC_KYXC_NHWK_HPP
#define DEVICE_CONV2D_FWD_BIAS_ACTIVATION_ADD_AVX2_NHWC_KYXC_NHWK_HPP

#include <iostream>
#include <sstream>
#include <numeric>
#include "device.hpp"
#include "device_base_cpu.hpp"
#include "device_conv_fwd_cpu.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "common_header.hpp"
#include "../../gpu/device/tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_bias_activation_add_avx2.hpp"
#include "threadwise_gemm_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"

namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {

// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
template <typename InDataType,
          typename WeiDataType,
          typename OutDataType,
          typename BiasDataType,
          typename AddDataType,
          typename InElementwiseOperation,
          typename WeiElementwiseOperation,
          typename OutElementwiseOperation,
          ConvolutionForwardSpecialization_t ConvForwardSpecialization,
          ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
          ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
          ck::index_t NumDimSpatial,
          ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
          ck::index_t NPerBlock,
          ck::index_t KPerBlock,
          ck::index_t MPerThread,
          ck::index_t NPerThread,
          bool UseALocalBuffer,
          bool UseBLocalBuffer,
          bool UseCLocalBuffer,
          bool BiasAlongGemmM>
struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
    : public DeviceConvFwdBiasActivationAdd<InElementwiseOperation,
                                            WeiElementwiseOperation,
                                            OutElementwiseOperation>
{
    using DeviceOp =
        DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K;

    using ADataType  = InDataType;
    using BDataType  = WeiDataType;
    using CDataType  = OutDataType;
    using C0DataType = BiasDataType;
    using C1DataType = AddDataType;

    using AElementwiseOperation = InElementwiseOperation;
    using BElementwiseOperation = WeiElementwiseOperation;
    using CElementwiseOperation = OutElementwiseOperation;

    // TODO make A/B datatype different
    using ABDataType = InDataType;

    static constexpr index_t NDimSpatial = NumDimSpatial;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    static constexpr bool NonTemporalStore = false;

    static constexpr auto GetBlockMNKAccessOrder()
    {
        if constexpr(BlockLoopOverSpecialization == DefaultBlockLoopOver ||
                     BlockLoopOverSpecialization == LoopOver_MNK)
            return ck::Sequence<0, 1, 2>{};
        else if constexpr(BlockLoopOverSpecialization == LoopOver_MKN)
            return ck::Sequence<0, 2, 1>{};
    }

    using BlockMNKAccessOrder = decltype(GetBlockMNKAccessOrder());

    static constexpr auto GetThreadwiseGemm_Dispatch()
    {
        if constexpr(MPerThread == 4 && NPerThread == 24)
        {
            return ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<
                InDataType,
                WeiDataType,
                OutDataType,
                ck::tensor_layout::gemm::RowMajor,
                ck::tensor_layout::gemm::ColumnMajor,
                NonTemporalStore>{};
        }
        else if constexpr(MPerThread == 6 && NPerThread == 16)
        {
            return ck::cpu::ThreadwiseGemmAvx2_MxN_6x16_Dispatch<
                InDataType,
                WeiDataType,
                OutDataType,
                ck::tensor_layout::gemm::RowMajor,
                ck::tensor_layout::gemm::ColumnMajor,
                NonTemporalStore>{};
        }
        else
        {
            // static_assert(false, "invalid Mr/Nr");
        }
    }

    using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());

    static constexpr auto GetInputBlockDescriptor()
    {
        if constexpr(UseALocalBuffer)
        {
            return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
        }
        else
        {
            return AGridDesc{};
        }
    }

    static constexpr auto GetWeightBlockDescriptor()
    {
        if constexpr(UseBLocalBuffer)
        {
            return make_naive_tensor_descriptor_packed(make_tuple(
                math::integer_divide_ceil(NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
                KPerBlock,
                ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
        }
        else
        {
            return BGridDesc{};
        }
    }

    static constexpr auto GetOutputBlockDescriptor()
    {
        if constexpr(UseCLocalBuffer)
        {
            return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
        }
        else
        {
            return CGridDesc{};
        }
    }

    static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
    {
        ck::index_t gemm_n_padded =
            math::integer_least_multiple(gemm_n, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);

        const auto wei_gemm_n_k_grid_desc =
            make_naive_tensor_descriptor_packed(make_tuple(gemm_n, gemm_k));

        const auto wei_gemm_padn_k_grid_desc = transform_tensor_descriptor(
            wei_gemm_n_k_grid_desc,
            make_tuple(make_right_pad_transform(gemm_n, gemm_n_padded - gemm_n),
                       make_pass_through_transform(gemm_k)),
            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));

        const auto wei_gemm_n0_k_n1_grid_desc = transform_tensor_descriptor(
            wei_gemm_padn_k_grid_desc,
            ck::make_tuple(
                ck::make_unmerge_transform(
                    ck::make_tuple(wei_gemm_padn_k_grid_desc.GetLength(I0) /
                                       ThreadwiseGemm_Dispatch::MatrixBMinVectorSize,
                                   ThreadwiseGemm_Dispatch::MatrixBMinVectorSize)),
                ck::make_pass_through_transform(wei_gemm_padn_k_grid_desc.GetLength(I1))),
            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
            ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));

        return wei_gemm_n0_k_n1_grid_desc;
    }

    static auto GetOutputTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n)
    {
        const auto out_gemm_m_n_grid_desc =
            make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_n));

        return out_gemm_m_n_grid_desc;
    }

    static auto MakeBiasTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n)
    {
        if constexpr(BiasAlongGemmM)
        {
            return make_naive_tensor_descriptor_packed(make_tuple(gemm_m));
        }
        else
        {
            return make_naive_tensor_descriptor_packed(make_tuple(gemm_n));
        }
    }
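    // Editor's note -- illustrative sketch only, not part of this commit. The three block
    // descriptor getters above now dispatch on UseALocalBuffer / UseBLocalBuffer /
    // UseCLocalBuffer: with the flag set they still return the small packed "block" descriptor
    // sized for the cache tile, and with the flag cleared they return a default-constructed
    // grid descriptor instead, so the corresponding LocalA/LocalB staging copy can be skipped
    // and that operand read directly through the grid descriptor. A minimal standalone mock of
    // the same dispatch pattern (made-up types, C++17 only):
    struct MockPackedBlockDesc { int m, k; }; // stands in for the packed local-buffer descriptor
    struct MockGridDesc {};                   // stands in for AGridDesc / BGridDesc / CGridDesc

    template <bool UseLocalBuffer>
    static constexpr auto mock_get_block_descriptor(int m_per_block, int k_per_block)
    {
        if constexpr(UseLocalBuffer)
            return MockPackedBlockDesc{m_per_block, k_per_block};
        else
            return MockGridDesc{}; // no local buffer: the descriptor collapses to the grid's type
    }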
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
static auto GetInputTensorDescriptor(ck::index_t N, static auto GetOutputTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n)
ck::index_t C, {
ck::index_t gemm_m, const auto out_gemm_m_n_grid_desc =
ck::index_t gemm_k, make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_n));
const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths, return out_gemm_m_n_grid_desc;
const std::vector<ck::index_t>& output_spatial_lengths, }
const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_dilations, static auto MakeBiasTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n)
const std::vector<ck::index_t>& input_left_pads, {
const std::vector<ck::index_t>& input_right_pads) if constexpr(BiasAlongGemmM)
{ {
const index_t Wi = input_spatial_lengths[0]; return make_naive_tensor_descriptor_packed(make_tuple(gemm_m));
const index_t Wo = output_spatial_lengths[0]; }
const index_t ConvStrideW = conv_filter_strides[0]; else
{
if constexpr(ConvForwardSpecialization == return make_naive_tensor_descriptor_packed(make_tuple(gemm_n));
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) }
{ }
const auto in_gemm_m_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k)); template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
static auto GetInputTensorDescriptor(ck::index_t N,
return in_gemm_m_k_grid_desc; ck::index_t C,
} ck::index_t gemm_m,
else if constexpr(ConvForwardSpecialization == ck::index_t gemm_k,
ConvolutionForwardSpecialization_t::Filter1x1Pad0) const std::vector<ck::index_t>& input_spatial_lengths,
{ const std::vector<ck::index_t>& filter_spatial_lengths,
const auto in_n_wi_c_grid_desc = const std::vector<ck::index_t>& output_spatial_lengths,
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_dilations,
const auto in_n_wo_c_grid_desc = transform_tensor_descriptor( const std::vector<ck::index_t>& input_left_pads,
in_n_wi_c_grid_desc, const std::vector<ck::index_t>& input_right_pads)
make_tuple(make_pass_through_transform(N), {
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), const index_t Wi = input_spatial_lengths[0];
make_pass_through_transform(C)), const index_t Wo = output_spatial_lengths[0];
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), const index_t ConvStrideW = conv_filter_strides[0];
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
if constexpr(ConvForwardSpecialization ==
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor( ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
in_n_wo_c_grid_desc, {
make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)), const auto in_gemm_m_k_grid_desc =
make_tuple(Sequence<0, 1>{}, Sequence<2>{}), make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemm_m_k_grid_desc;
return in_gemm_m_k_grid_desc; }
} else if constexpr(ConvForwardSpecialization ==
else ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{ {
const index_t X = filter_spatial_lengths[0]; const auto in_n_wi_c_grid_desc =
const index_t ConvDilationW = conv_filter_dilations[0]; make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const index_t InLeftPadW = input_left_pads[0];
const index_t InRightPadW = input_right_pads[0]; const auto in_n_wo_c_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc,
const auto in_n_wi_c_grid_desc = make_tuple(make_pass_through_transform(N),
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
in_n_wi_c_grid_desc, make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
make_tuple(make_pass_through_transform(N),
make_pad_transform(Wi, InLeftPadW, InRightPadW), const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
make_pass_through_transform(C)), in_n_wo_c_grid_desc,
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_wip_c_grid_desc, return in_gemm_m_k_grid_desc;
make_tuple( }
make_pass_through_transform(N), else
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), {
make_pass_through_transform(C)), const index_t X = filter_spatial_lengths[0];
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), const index_t ConvDilationW = conv_filter_dilations[0];
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); const index_t InLeftPadW = input_left_pads[0];
const index_t InRightPadW = input_right_pads[0];
const auto in_gemm_m_k_grid_desc =
transform_tensor_descriptor(in_n_x_wo_c_grid_desc, const auto in_n_wi_c_grid_desc =
make_tuple(make_merge_transform(make_tuple(N, Wo)), make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
make_merge_transform(make_tuple(X, C))),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}), const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
make_tuple(Sequence<0>{}, Sequence<1>{})); in_n_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
return in_gemm_m_k_grid_desc; make_pad_transform(Wi, InLeftPadW, InRightPadW),
} make_pass_through_transform(C)),
} make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
static auto GetInputTensorDescriptor(ck::index_t N, const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
ck::index_t C, in_n_wip_c_grid_desc,
ck::index_t gemm_m, make_tuple(
ck::index_t gemm_k, make_pass_through_transform(N),
const std::vector<ck::index_t>& input_spatial_lengths, make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
const std::vector<ck::index_t>& filter_spatial_lengths, make_pass_through_transform(C)),
const std::vector<ck::index_t>& output_spatial_lengths, make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
const std::vector<ck::index_t>& conv_filter_strides, make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const std::vector<ck::index_t>& conv_filter_dilations,
const std::vector<ck::index_t>& input_left_pads, const auto in_gemm_m_k_grid_desc =
const std::vector<ck::index_t>& input_right_pads) transform_tensor_descriptor(in_n_x_wo_c_grid_desc,
{ make_tuple(make_merge_transform(make_tuple(N, Wo)),
const index_t Hi = input_spatial_lengths[0]; make_merge_transform(make_tuple(X, C))),
const index_t Wi = input_spatial_lengths[1]; make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const index_t Ho = output_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[1]; return in_gemm_m_k_grid_desc;
}
const index_t ConvStrideH = conv_filter_strides[0]; }
const index_t ConvStrideW = conv_filter_strides[1];
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
if constexpr(ConvForwardSpecialization == static auto GetInputTensorDescriptor(ck::index_t N,
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ck::index_t C,
{ ck::index_t gemm_m,
const auto in_gemm_m_k_grid_desc = ck::index_t gemm_k,
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k)); const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths,
return in_gemm_m_k_grid_desc; const std::vector<ck::index_t>& output_spatial_lengths,
} const std::vector<ck::index_t>& conv_filter_strides,
else if constexpr(ConvForwardSpecialization == const std::vector<ck::index_t>& conv_filter_dilations,
ConvolutionForwardSpecialization_t::Filter1x1Pad0) const std::vector<ck::index_t>& input_left_pads,
{ const std::vector<ck::index_t>& input_right_pads)
const auto in_n_hi_wi_c_grid_desc = {
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); const index_t Hi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[1];
const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc, const index_t Ho = output_spatial_lengths[0];
make_tuple(make_pass_through_transform(N), const index_t Wo = output_spatial_lengths[1];
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), const index_t ConvStrideH = conv_filter_strides[0];
make_pass_through_transform(C)), const index_t ConvStrideW = conv_filter_strides[1];
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
const auto in_gemm_m_k_grid_desc = {
transform_tensor_descriptor(in_n_ho_wo_c_grid_desc, const auto in_gemm_m_k_grid_desc =
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), return in_gemm_m_k_grid_desc;
make_tuple(Sequence<0>{}, Sequence<1>{})); }
else if constexpr(ConvForwardSpecialization ==
return in_gemm_m_k_grid_desc; ConvolutionForwardSpecialization_t::Filter1x1Pad0)
} {
else const auto in_n_hi_wi_c_grid_desc =
{ make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const index_t Y = filter_spatial_lengths[0];
const index_t X = filter_spatial_lengths[1]; const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
const index_t ConvDilationH = conv_filter_dilations[0]; make_tuple(make_pass_through_transform(N),
const index_t ConvDilationW = conv_filter_dilations[1]; make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
const index_t InLeftPadH = input_left_pads[0]; make_pass_through_transform(C)),
const index_t InLeftPadW = input_left_pads[1]; make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1]; const auto in_gemm_m_k_grid_desc =
transform_tensor_descriptor(in_n_ho_wo_c_grid_desc,
const auto in_n_hi_wi_c_grid_desc = make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( make_tuple(Sequence<0>{}, Sequence<1>{}));
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N), return in_gemm_m_k_grid_desc;
make_pad_transform(Hi, InLeftPadH, InRightPadH), }
make_pad_transform(Wi, InLeftPadW, InRightPadW), else
make_pass_through_transform(C)), {
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), const index_t Y = filter_spatial_lengths[0];
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); const index_t X = filter_spatial_lengths[1];
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( const index_t ConvDilationH = conv_filter_dilations[0];
in_n_hip_wip_c_grid_desc, const index_t ConvDilationW = conv_filter_dilations[1];
make_tuple(
make_pass_through_transform(N), const index_t InLeftPadH = input_left_pads[0];
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), const index_t InLeftPadW = input_left_pads[1];
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)), const index_t InRightPadH = input_right_pads[0];
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), const index_t InRightPadW = input_right_pads[1];
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_n_hi_wi_c_grid_desc =
const auto in_gemm_m_k_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
make_merge_transform(make_tuple(Y, X, C))), in_n_hi_wi_c_grid_desc,
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), make_tuple(make_pass_through_transform(N),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
return in_gemm_m_k_grid_desc; make_pass_through_transform(C)),
} make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
} make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false> const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
static auto GetInputTensorDescriptor(ck::index_t N, in_n_hip_wip_c_grid_desc,
ck::index_t C, make_tuple(
ck::index_t gemm_m, make_pass_through_transform(N),
ck::index_t gemm_k, make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
ck::index_t gemm_m_pad, make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
const std::vector<ck::index_t>& input_spatial_lengths, make_pass_through_transform(C)),
const std::vector<ck::index_t>& filter_spatial_lengths, make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
const std::vector<ck::index_t>& output_spatial_lengths, make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_dilations, const auto in_gemm_m_k_grid_desc =
const std::vector<ck::index_t>& input_left_pads, transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
const std::vector<ck::index_t>& input_right_pads) make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
{ make_merge_transform(make_tuple(Y, X, C))),
const index_t Di = input_spatial_lengths[0]; make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
const index_t Hi = input_spatial_lengths[1]; make_tuple(Sequence<0>{}, Sequence<1>{}));
const index_t Wi = input_spatial_lengths[2];
return in_gemm_m_k_grid_desc;
const index_t Do = output_spatial_lengths[0]; }
const index_t Ho = output_spatial_lengths[1]; }
const index_t Wo = output_spatial_lengths[2];
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
const index_t ConvStrideD = conv_filter_strides[0]; static auto GetInputTensorDescriptor(ck::index_t N,
const index_t ConvStrideH = conv_filter_strides[1]; ck::index_t C,
const index_t ConvStrideW = conv_filter_strides[2]; ck::index_t gemm_m,
ck::index_t gemm_k,
if constexpr(ConvForwardSpecialization == ck::index_t gemm_m_pad,
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) const std::vector<ck::index_t>& input_spatial_lengths,
{ const std::vector<ck::index_t>& filter_spatial_lengths,
const auto in_gemm_m_k_grid_desc = const std::vector<ck::index_t>& output_spatial_lengths,
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k)); const std::vector<ck::index_t>& conv_filter_strides,
const std::vector<ck::index_t>& conv_filter_dilations,
return in_gemm_m_k_grid_desc; const std::vector<ck::index_t>& input_left_pads,
} const std::vector<ck::index_t>& input_right_pads)
else if constexpr(ConvForwardSpecialization == {
ConvolutionForwardSpecialization_t::Filter1x1Pad0) const index_t Di = input_spatial_lengths[0];
{ const index_t Hi = input_spatial_lengths[1];
const auto in_n_di_hi_wi_c_grid_desc = const index_t Wi = input_spatial_lengths[2];
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const index_t Do = output_spatial_lengths[0];
const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor( const index_t Ho = output_spatial_lengths[1];
in_n_di_hi_wi_c_grid_desc, const index_t Wo = output_spatial_lengths[2];
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)), const index_t ConvStrideD = conv_filter_strides[0];
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), const index_t ConvStrideH = conv_filter_strides[1];
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), const index_t ConvStrideW = conv_filter_strides[2];
make_pass_through_transform(C)),
make_tuple( if constexpr(ConvForwardSpecialization ==
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
make_tuple( {
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); const auto in_gemm_m_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
in_n_do_ho_wo_c_grid_desc, return in_gemm_m_k_grid_desc;
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), }
make_pass_through_transform(C)), else if constexpr(ConvForwardSpecialization ==
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), ConvolutionForwardSpecialization_t::Filter1x1Pad0)
make_tuple(Sequence<0>{}, Sequence<1>{})); {
const auto in_n_di_hi_wi_c_grid_desc =
return in_gemm_m_k_grid_desc; make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
}
else const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor(
{ in_n_di_hi_wi_c_grid_desc,
const index_t Z = filter_spatial_lengths[0]; make_tuple(make_pass_through_transform(N),
const index_t Y = filter_spatial_lengths[1]; make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
const index_t X = filter_spatial_lengths[2]; make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
const index_t ConvDilationD = conv_filter_dilations[0]; make_pass_through_transform(C)),
const index_t ConvDilationH = conv_filter_dilations[1]; make_tuple(
const index_t ConvDilationW = conv_filter_dilations[2]; Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
const index_t InLeftPadD = input_left_pads[0]; Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[2]; const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
in_n_do_ho_wo_c_grid_desc,
const index_t InRightPadD = input_right_pads[0]; make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
const index_t InRightPadH = input_right_pads[1]; make_pass_through_transform(C)),
const index_t InRightPadW = input_right_pads[2]; make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); return in_gemm_m_k_grid_desc;
}
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( else
in_n_di_hi_wi_c_grid_desc, {
make_tuple(make_pass_through_transform(N), const index_t Z = filter_spatial_lengths[0];
make_pad_transform(Di, InLeftPadD, InRightPadD), const index_t Y = filter_spatial_lengths[1];
make_pad_transform(Hi, InLeftPadH, InRightPadH), const index_t X = filter_spatial_lengths[2];
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)), const index_t ConvDilationD = conv_filter_dilations[0];
make_tuple( const index_t ConvDilationH = conv_filter_dilations[1];
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), const index_t ConvDilationW = conv_filter_dilations[2];
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1];
const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( const index_t InLeftPadW = input_left_pads[2];
in_n_hip_wip_c_grid_desc,
make_tuple( const index_t InRightPadD = input_right_pads[0];
make_pass_through_transform(N), const index_t InRightPadH = input_right_pads[1];
make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), const index_t InRightPadW = input_right_pads[2];
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), const auto in_n_di_hi_wi_c_grid_desc =
make_pass_through_transform(C)), make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
make_tuple(Sequence<0>{}, in_n_di_hi_wi_c_grid_desc,
Sequence<1, 2>{}, make_tuple(make_pass_through_transform(N),
Sequence<3, 4>{}, make_pad_transform(Di, InLeftPadD, InRightPadD),
Sequence<5, 6>{}, make_pad_transform(Hi, InLeftPadH, InRightPadH),
Sequence<7>{})); make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor( make_tuple(
in_n_z_do_y_ho_x_wo_c_grid_desc, Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), make_tuple(
make_merge_transform(make_tuple(Z, Y, X, C))), Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
return in_gemm_m_k_grid_desc; make_tuple(
} make_pass_through_transform(N),
} make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
static index_t GetGemmM(ck::index_t N, const std::vector<ck::index_t>& output_spatial_lengths) make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
{ make_pass_through_transform(C)),
return N * std::accumulate(std::begin(output_spatial_lengths), make_tuple(
std::end(output_spatial_lengths), Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
1, make_tuple(Sequence<0>{},
std::multiplies<ck::index_t>()); Sequence<1, 2>{},
} Sequence<3, 4>{},
Sequence<5, 6>{},
static index_t GetGemmK(ck::index_t C, const std::vector<ck::index_t>& filter_spatial_lengths) Sequence<7>{}));
{
return C * std::accumulate(std::begin(filter_spatial_lengths), const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
std::end(filter_spatial_lengths), in_n_z_do_y_ho_x_wo_c_grid_desc,
1, make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
std::multiplies<ck::index_t>()); make_merge_transform(make_tuple(Z, Y, X, C))),
} make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
static index_t GetGemmN(ck::index_t K)
{ return in_gemm_m_k_grid_desc;
// return ck::math::integer_least_multiple(K, }
// ThreadwiseGemm_Dispatch::MatrixBMinVectorSize); }
return K;
} static index_t GetGemmM(ck::index_t N, const std::vector<ck::index_t>& output_spatial_lengths)
{
static auto MakeABCGridDescriptor(ck::index_t N, return N * std::accumulate(std::begin(output_spatial_lengths),
ck::index_t K, std::end(output_spatial_lengths),
ck::index_t C, 1,
std::vector<ck::index_t> input_spatial_lengths, std::multiplies<ck::index_t>());
std::vector<ck::index_t> filter_spatial_lengths, }
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides, static index_t GetGemmK(ck::index_t C, const std::vector<ck::index_t>& filter_spatial_lengths)
std::vector<ck::index_t> conv_filter_dilations, {
std::vector<ck::index_t> input_left_pads, return C * std::accumulate(std::begin(filter_spatial_lengths),
std::vector<ck::index_t> input_right_pads) std::end(filter_spatial_lengths),
{ 1,
using namespace ck; std::multiplies<ck::index_t>());
}
const index_t GemmM = GetGemmM(N, output_spatial_lengths);
const index_t GemmN = GetGemmN(K); static index_t GetGemmN(ck::index_t K)
const index_t GemmK = GetGemmK(C, filter_spatial_lengths); {
// return ck::math::integer_least_multiple(K,
// A: // ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
const auto in_gemm_m_k_grid_desc = return K;
GetInputTensorDescriptor<NumDimSpatial>(N, }
C,
GemmM, static auto MakeABCGridDescriptor(ck::index_t N,
GemmK, ck::index_t K,
input_spatial_lengths, ck::index_t C,
filter_spatial_lengths, std::vector<ck::index_t> input_spatial_lengths,
output_spatial_lengths, std::vector<ck::index_t> filter_spatial_lengths,
conv_filter_strides, std::vector<ck::index_t> output_spatial_lengths,
conv_filter_dilations, std::vector<ck::index_t> conv_filter_strides,
input_left_pads, std::vector<ck::index_t> conv_filter_dilations,
input_right_pads); std::vector<ck::index_t> input_left_pads,
// B: std::vector<ck::index_t> input_right_pads)
const auto wei_gemm_n0_k_n1_grid_desc = GetWeightTensorDescriptor(GemmK, GemmN); {
// C: using namespace ck;
const auto out_gemm_m_n_grid_desc = GetOutputTensorDescriptor(GemmM, GemmN);
const index_t GemmM = GetGemmM(N, output_spatial_lengths);
return make_tuple( const index_t GemmN = GetGemmN(K);
in_gemm_m_k_grid_desc, wei_gemm_n0_k_n1_grid_desc, out_gemm_m_n_grid_desc); const index_t GemmK = GetGemmK(C, filter_spatial_lengths);
}
// A:
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false> const auto in_gemm_m_k_grid_desc =
static auto GetABCGridDesc() GetInputTensorDescriptor<NumDimSpatial>(N,
{ C,
return MakeABCGridDescriptor(1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}); GemmM,
} GemmK,
input_spatial_lengths,
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false> filter_spatial_lengths,
static auto GetABCGridDesc() output_spatial_lengths,
{ conv_filter_strides,
return MakeABCGridDescriptor( conv_filter_dilations,
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}); input_left_pads,
} input_right_pads);
// B:
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false> const auto wei_gemm_n0_k_n1_grid_desc = GetWeightTensorDescriptor(GemmK, GemmN);
static auto GetABCGridDesc() // C:
{ const auto out_gemm_m_n_grid_desc = GetOutputTensorDescriptor(GemmM, GemmN);
return MakeABCGridDescriptor(
1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}); return make_tuple(
} in_gemm_m_k_grid_desc, wei_gemm_n0_k_n1_grid_desc, out_gemm_m_n_grid_desc);
}
using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>());
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
using AGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I0])>; static auto GetABCGridDesc()
using BGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I1])>; {
using CGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I2])>; return MakeABCGridDescriptor(1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1});
using C0GridDesc = remove_cvref_t<decltype(MakeBiasTensorDescriptor(1, 1))>; }
using C1GridDesc = CGridDesc;
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
// static constexpr bool UseCLocalBuffer = false; static auto GetABCGridDesc()
{
    using AThreadwiseCopy =
        ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC<
            ADataType,
            ADataType,
            AGridDesc,
            decltype(GetInputBlockDescriptor()),
            InElementwiseOperation,
            !UseALocalBuffer,
            ConvForwardSpecialization,
            GemmKSpecialization>;

    using BThreadwiseCopy =
        ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXC<
            BDataType,
            BDataType,
            BGridDesc,
            decltype(GetWeightBlockDescriptor()),
            WeiElementwiseOperation,
            !UseBLocalBuffer,
            ConvForwardSpecialization,
            GemmKSpecialization>;

    using CThreadwiseCopy =
        ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN<
            CDataType,
            C0DataType,
            C1DataType,
            CDataType,
            CGridDesc,
            C0GridDesc,
            C1GridDesc,
            decltype(GetOutputBlockDescriptor()),
            OutElementwiseOperation,
            !UseCLocalBuffer,
            BiasAlongGemmM>;

    using GridwiseGemm = ck::cpu::GridwiseGemmBiasActivationAddAvx2_MxN<
        ADataType,               // InDataType,
        BDataType,               // WeiDataType,
        CDataType,               // OutDataType,
        C0DataType,              // C0DataType
        C1DataType,              // C1DataType
        AGridDesc,               // AGridDesc,
        BGridDesc,               // BGridDesc,
        CGridDesc,               // CGridDesc,
        C0GridDesc,              // C0GridDesc,
        C1GridDesc,              // C1GridDesc,
        AElementwiseOperation,   // AElementwiseOperation,
        BElementwiseOperation,   // BElementwiseOperation,
        CElementwiseOperation,   // CElementwiseOperation,
        MPerBlock,               // MPerBlock,
        NPerBlock,               // NPerBlock,
        KPerBlock,               // KPerBlock,
        ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
        AThreadwiseCopy,         // AThreadwiseCopy
        BThreadwiseCopy,         // BThreadwiseCopy
        CThreadwiseCopy,         // CThreadwiseCopy
        BlockMNKAccessOrder,     // BlockMNKAccessOrder,
        ck::Sequence<0, 1>,      // ThreadMNAccessOrder
        UseALocalBuffer,         // UseALocalBuffer
        UseBLocalBuffer,         // UseBLocalBuffer
        UseCLocalBuffer          // UseCLocalBuffer
        >;
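    // Editor's note -- illustrative sketch only, not part of this commit. AThreadwiseCopy and
    // BThreadwiseCopy now receive !UseALocalBuffer / !UseBLocalBuffer where this position was
    // previously hard-coded to false, i.e. the copy stage is told when it must serve reads
    // directly out of the grid tensor instead of out of a repacked local tile. The reworked
    // IsSupportedArgument() below pairs this with a guard: direct grid reads are only accepted
    // for the Filter1x1Stride1Pad0 specialization, whose input descriptor is a plain
    // (gemm_m, gemm_k) matrix with no padding or gather transforms. That acceptance rule,
    // restated as a tiny standalone predicate:
    static constexpr bool mock_direct_grid_read_ok(bool use_a_local_buffer,
                                                   bool use_b_local_buffer,
                                                   bool is_filter1x1_stride1_pad0)
    {
        // mirrors: if constexpr((!UseALocalBuffer || !UseBLocalBuffer) &&
        //                       ConvForwardSpecialization != Filter1x1Stride1Pad0) return false;
        return (use_a_local_buffer && use_b_local_buffer) || is_filter1x1_stride1_pad0;
    }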
AGridDesc, // AGridDesc,
// Argument BGridDesc, // BGridDesc,
struct Argument : public BaseArgument CGridDesc, // CGridDesc,
{ C0GridDesc, // C0GridDesc,
Argument(const InDataType* p_in_grid, C1GridDesc, // C1GridDesc,
const WeiDataType* p_wei_grid, AElementwiseOperation, // AElementwiseOperation,
OutDataType* p_out_grid, BElementwiseOperation, // BElementwiseOperation,
const BiasDataType* p_bias_grid, CElementwiseOperation, // CElementwiseOperation,
const AddDataType* p_add_grid, MPerBlock, // MPerBlock,
ck::index_t N, NPerBlock, // NPerBlock,
ck::index_t K, KPerBlock, // KPerBlock,
ck::index_t C, ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
std::vector<ck::index_t> input_spatial_lengths, AThreadwiseCopy, // AThreadwiseCopy
std::vector<ck::index_t> filter_spatial_lengths, BThreadwiseCopy, // BThreadwiseCopy
std::vector<ck::index_t> output_spatial_lengths, CThreadwiseCopy, // CThreadwiseCopy
std::vector<ck::index_t> conv_filter_strides, BlockMNKAccessOrder, // BlockMNKAccessOrder,
std::vector<ck::index_t> conv_filter_dilations, ck::Sequence<0, 1>, // ThreadMNAccessOrder
std::vector<ck::index_t> input_left_pads, UseALocalBuffer, // UseALocalBuffer
std::vector<ck::index_t> input_right_pads, UseBLocalBuffer, // UseBLocalBuffer
InElementwiseOperation in_element_op, UseCLocalBuffer // UseCLocalBuffer
WeiElementwiseOperation wei_element_op, >;
OutElementwiseOperation out_element_op)
: p_a_grid_{p_in_grid}, // Argument
p_b_grid_{p_wei_grid}, struct Argument : public BaseArgument
p_c_grid_{p_out_grid}, {
p_c0_grid_{p_bias_grid}, Argument(const InDataType* p_in_grid,
p_c1_grid_{p_add_grid}, const WeiDataType* p_wei_grid,
a_grid_desc_{}, OutDataType* p_out_grid,
b_grid_desc_{}, const BiasDataType* p_bias_grid,
c_grid_desc_{}, const AddDataType* p_add_grid,
c0_grid_desc_{}, ck::index_t N,
c1_grid_desc_{}, ck::index_t K,
a_element_op_{in_element_op}, ck::index_t C,
b_element_op_{wei_element_op}, std::vector<ck::index_t> input_spatial_lengths,
c_element_op_{out_element_op}, std::vector<ck::index_t> filter_spatial_lengths,
Conv_N_{N}, std::vector<ck::index_t> output_spatial_lengths,
Conv_K_{K}, std::vector<ck::index_t> conv_filter_strides,
Conv_C_{C}, std::vector<ck::index_t> conv_filter_dilations,
filter_spatial_lengths_{filter_spatial_lengths}, std::vector<ck::index_t> input_left_pads,
conv_filter_strides_{conv_filter_strides}, std::vector<ck::index_t> input_right_pads,
input_left_pads_{input_left_pads}, InElementwiseOperation in_element_op,
input_right_pads_{input_right_pads} WeiElementwiseOperation wei_element_op,
{ OutElementwiseOperation out_element_op)
const auto descs = DeviceOp::MakeABCGridDescriptor(N, : p_a_grid_{p_in_grid},
K, p_b_grid_{p_wei_grid},
C, p_c_grid_{p_out_grid},
input_spatial_lengths, p_c0_grid_{p_bias_grid},
filter_spatial_lengths, p_c1_grid_{p_add_grid},
output_spatial_lengths, a_grid_desc_{},
conv_filter_strides, b_grid_desc_{},
conv_filter_dilations, c_grid_desc_{},
input_left_pads, c0_grid_desc_{},
input_right_pads); c1_grid_desc_{},
a_grid_desc_ = descs[I0]; a_element_op_{in_element_op},
b_grid_desc_ = descs[I1]; b_element_op_{wei_element_op},
c_grid_desc_ = descs[I2]; c_element_op_{out_element_op},
Conv_N_{N},
c0_grid_desc_ = DeviceOp::MakeBiasTensorDescriptor(GetGemmM(N, output_spatial_lengths), Conv_K_{K},
GetGemmN(K)); Conv_C_{C},
c1_grid_desc_ = descs[I2]; filter_spatial_lengths_{filter_spatial_lengths},
} conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads},
// private: input_right_pads_{input_right_pads}
const ADataType* p_a_grid_; {
const BDataType* p_b_grid_; const auto descs = DeviceOp::MakeABCGridDescriptor(N,
CDataType* p_c_grid_; K,
const C0DataType* p_c0_grid_; C,
const C1DataType* p_c1_grid_; input_spatial_lengths,
AGridDesc a_grid_desc_; filter_spatial_lengths,
BGridDesc b_grid_desc_; output_spatial_lengths,
CGridDesc c_grid_desc_; conv_filter_strides,
C0GridDesc c0_grid_desc_; conv_filter_dilations,
C1GridDesc c1_grid_desc_; input_left_pads,
input_right_pads);
AElementwiseOperation a_element_op_; a_grid_desc_ = descs[I0];
BElementwiseOperation b_element_op_; b_grid_desc_ = descs[I1];
CElementwiseOperation c_element_op_; c_grid_desc_ = descs[I2];
// for checking IsSupportedArgument()
index_t Conv_N_; c0_grid_desc_ = DeviceOp::MakeBiasTensorDescriptor(GetGemmM(N, output_spatial_lengths),
index_t Conv_K_; GetGemmN(K));
index_t Conv_C_; c1_grid_desc_ = descs[I2];
std::vector<index_t> filter_spatial_lengths_; }
std::vector<index_t> conv_filter_strides_;
std::vector<index_t> input_left_pads_; // private:
std::vector<index_t> input_right_pads_; const ADataType* p_a_grid_;
}; const BDataType* p_b_grid_;
CDataType* p_c_grid_;
// Invoker const C0DataType* p_c0_grid_;
struct Invoker : public BaseInvoker const C1DataType* p_c1_grid_;
{ AGridDesc a_grid_desc_;
using Argument = DeviceOp::Argument; BGridDesc b_grid_desc_;
CGridDesc c_grid_desc_;
float Run(const Argument& arg, C0GridDesc c0_grid_desc_;
const StreamConfig& stream_config = StreamConfig{}, C1GridDesc c1_grid_desc_;
int nrepeat = 1)
{ AElementwiseOperation a_element_op_;
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_)) BElementwiseOperation b_element_op_;
{ CElementwiseOperation c_element_op_;
throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting"); // for checking IsSupportedArgument()
} index_t Conv_N_;
index_t Conv_K_;
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize()); index_t Conv_C_;
std::vector<index_t> filter_spatial_lengths_;
const auto kernel = std::vector<index_t> conv_filter_strides_;
ck::cpu::kernel_gemm_bias_activation_add_avx_mxn<GridwiseGemm, std::vector<index_t> input_left_pads_;
ADataType, std::vector<index_t> input_right_pads_;
BDataType, };
CDataType,
C0DataType, // Invoker
C1DataType, struct Invoker : public BaseInvoker
AGridDesc, {
BGridDesc, using Argument = DeviceOp::Argument;
CGridDesc,
C0GridDesc, float Run(const Argument& arg,
C1GridDesc, const StreamConfig& stream_config = StreamConfig{},
AElementwiseOperation, int nrepeat = 1)
BElementwiseOperation, {
CElementwiseOperation>; if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
{
float ave_time = 0; throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
}
if(nrepeat != 1)
ave_time = launch_and_time_cpu_kernel(kernel, memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
nrepeat,
arg.p_a_grid_, const auto kernel =
arg.p_b_grid_, ck::cpu::kernel_gemm_bias_activation_add_avx_mxn<GridwiseGemm,
arg.p_c_grid_, ADataType,
arg.p_c0_grid_, BDataType,
arg.p_c1_grid_, CDataType,
arg.a_grid_desc_, C0DataType,
arg.b_grid_desc_, C1DataType,
arg.c_grid_desc_, AGridDesc,
arg.c0_grid_desc_, BGridDesc,
arg.c1_grid_desc_, CGridDesc,
arg.a_element_op_, C0GridDesc,
arg.b_element_op_, C1GridDesc,
arg.c_element_op_); AElementwiseOperation,
BElementwiseOperation,
// TODO: this is for benchmark purpose, so last time we clear c buffer and calculate the CElementwiseOperation>;
// result
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize()); float ave_time = 0;
launch_cpu_kernel(kernel, if(nrepeat != 1)
arg.p_a_grid_, ave_time = launch_and_time_cpu_kernel(kernel,
arg.p_b_grid_, nrepeat,
arg.p_c_grid_, arg.p_a_grid_,
arg.p_c0_grid_, arg.p_b_grid_,
arg.p_c1_grid_, arg.p_c_grid_,
arg.a_grid_desc_, arg.p_c0_grid_,
arg.b_grid_desc_, arg.p_c1_grid_,
arg.c_grid_desc_, arg.a_grid_desc_,
arg.c0_grid_desc_, arg.b_grid_desc_,
arg.c1_grid_desc_, arg.c_grid_desc_,
arg.a_element_op_, arg.c0_grid_desc_,
arg.b_element_op_, arg.c1_grid_desc_,
arg.c_element_op_); arg.a_element_op_,
arg.b_element_op_,
return ave_time; arg.c_element_op_);
}
// TODO: this is for benchmark purpose, so last time we clear c buffer and calculate the
float Run(const BaseArgument* p_arg, // result
const StreamConfig& stream_config = StreamConfig{}, memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
int nrepeat = 1) override
{ launch_cpu_kernel(kernel,
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config, nrepeat); arg.p_a_grid_,
} arg.p_b_grid_,
}; arg.p_c_grid_,
arg.p_c0_grid_,
static constexpr bool IsValidCompilationParameter() arg.p_c1_grid_,
{ arg.a_grid_desc_,
// TODO: properly implement this check arg.b_grid_desc_,
return true; arg.c_grid_desc_,
} arg.c0_grid_desc_,
arg.c1_grid_desc_,
static bool IsSupportedArgument(const Argument& arg) arg.a_element_op_,
{ arg.b_element_op_,
if constexpr(ConvForwardSpecialization == arg.c_element_op_);
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{ return ave_time;
// check if it's 1x1, stride=1 conv }
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 && float Run(const BaseArgument* p_arg,
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && const StreamConfig& stream_config = StreamConfig{},
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) int nrepeat = 1) override
{ {
return false; return Run(*dynamic_cast<const Argument*>(p_arg), stream_config, nrepeat);
} }
} };
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0) static constexpr bool IsValidCompilationParameter()
{ {
// check if it's 1x1 conv // TODO: properly implement this check
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && return true;
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && }
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
{ static bool IsSupportedArgument(const Argument& arg)
return false; {
} if constexpr(ConvForwardSpecialization ==
} ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
if constexpr(GemmKSpecialization == // check if it's 1x1, stride=1 conv
ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC) if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
{ arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
if(!(arg.Conv_C_ % KPerBlock == 0)) arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
return false; arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
} {
return false;
// Gridwise GEMM size }
return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_); }
} else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
bool IsSupportedArgument(const BaseArgument* p_arg) override {
{ // check if it's 1x1 conv
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
} arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
static auto MakeArgument(const InDataType* p_in_grid, {
const WeiDataType* p_wei_grid, return false;
OutDataType* p_out_grid, }
const BiasDataType* p_bias_grid, }
const AddDataType* p_add_grid,
ck::index_t N, if constexpr(GemmKSpecialization ==
ck::index_t K, ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
ck::index_t C, ConvForwardSpecialization !=
std::vector<ck::index_t> input_spatial_lengths, ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
std::vector<ck::index_t> filter_spatial_lengths, {
std::vector<ck::index_t> output_spatial_lengths, if(!(arg.Conv_C_ % KPerBlock == 0))
std::vector<ck::index_t> conv_filter_strides, return false;
std::vector<ck::index_t> conv_filter_dilations, }
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, if constexpr((!UseALocalBuffer || !UseBLocalBuffer) &&
InElementwiseOperation in_element_op, ConvForwardSpecialization !=
WeiElementwiseOperation wei_element_op, ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
OutElementwiseOperation out_element_op) {
{ // TODO: We can support this in the future, once we figure out how to express the tensor
return Argument{p_in_grid, // transform
p_wei_grid, return false;
p_out_grid, }
p_bias_grid,
p_add_grid, // Gridwise GEMM size
N, return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
K, }
C,
input_spatial_lengths, bool IsSupportedArgument(const BaseArgument* p_arg) override
filter_spatial_lengths, {
output_spatial_lengths, return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
conv_filter_strides, }
conv_filter_dilations,
input_left_pads, static auto MakeArgument(const InDataType* p_in_grid,
input_right_pads, const WeiDataType* p_wei_grid,
in_element_op, OutDataType* p_out_grid,
wei_element_op, const BiasDataType* p_bias_grid,
out_element_op}; const AddDataType* p_add_grid,
} ck::index_t N,
ck::index_t K,
static auto MakeInvoker() { return Invoker{}; } ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::unique_ptr<BaseArgument> std::vector<ck::index_t> filter_spatial_lengths,
MakeArgumentPointer(const void* p_in_grid, std::vector<ck::index_t> output_spatial_lengths,
const void* p_wei_grid, std::vector<ck::index_t> conv_filter_strides,
void* p_out_grid, std::vector<ck::index_t> conv_filter_dilations,
const void* p_bias_grid, std::vector<ck::index_t> input_left_pads,
const void* p_add_grid, std::vector<ck::index_t> input_right_pads,
ck::index_t N, InElementwiseOperation in_element_op,
ck::index_t K, WeiElementwiseOperation wei_element_op,
ck::index_t C, OutElementwiseOperation out_element_op)
std::vector<ck::index_t> input_spatial_lengths, {
std::vector<ck::index_t> filter_spatial_lengths, return Argument{p_in_grid,
std::vector<ck::index_t> output_spatial_lengths, p_wei_grid,
std::vector<ck::index_t> conv_filter_strides, p_out_grid,
std::vector<ck::index_t> conv_filter_dilations, p_bias_grid,
std::vector<ck::index_t> input_left_pads, p_add_grid,
std::vector<ck::index_t> input_right_pads, N,
InElementwiseOperation in_element_op, K,
WeiElementwiseOperation wei_element_op, C,
OutElementwiseOperation out_element_op) override input_spatial_lengths,
{ filter_spatial_lengths,
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid), output_spatial_lengths,
static_cast<const WeiDataType*>(p_wei_grid), conv_filter_strides,
static_cast<OutDataType*>(p_out_grid), conv_filter_dilations,
static_cast<const BiasDataType*>(p_bias_grid), input_left_pads,
static_cast<const AddDataType*>(p_add_grid), input_right_pads,
N, in_element_op,
K, wei_element_op,
C, out_element_op};
input_spatial_lengths, }
filter_spatial_lengths,
output_spatial_lengths, static auto MakeInvoker() { return Invoker{}; }
conv_filter_strides,
conv_filter_dilations, std::unique_ptr<BaseArgument>
input_left_pads, MakeArgumentPointer(const void* p_in_grid,
input_right_pads, const void* p_wei_grid,
in_element_op, void* p_out_grid,
wei_element_op, const void* p_bias_grid,
out_element_op); const void* p_add_grid,
} ck::index_t N,
ck::index_t K,
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override ck::index_t C,
{ std::vector<ck::index_t> input_spatial_lengths,
return std::make_unique<Invoker>(Invoker{}); std::vector<ck::index_t> filter_spatial_lengths,
} std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::string GetTypeString() const override std::vector<ck::index_t> conv_filter_dilations,
{ std::vector<ck::index_t> input_left_pads,
auto str = std::stringstream(); std::vector<ck::index_t> input_right_pads,
auto string_local_buffer = [](bool is_local_buffer) { InElementwiseOperation in_element_op,
if(is_local_buffer) WeiElementwiseOperation wei_element_op,
return "L"; OutElementwiseOperation out_element_op) override
else {
return "G"; return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
}; static_cast<const WeiDataType*>(p_wei_grid),
// clang-format off static_cast<OutDataType*>(p_out_grid),
str << "DeviceConv" << std::to_string(NumDimSpatial) static_cast<const BiasDataType*>(p_bias_grid),
<< "DFwd_BAA_Avx2_NHWC_KYXC" static_cast<const AddDataType*>(p_add_grid),
<<"_FS"<< static_cast<int>(ConvForwardSpecialization) N,
<<"_KS"<< static_cast<int>(GemmKSpecialization) K,
<<"_BS"<< static_cast<int>(BlockLoopOverSpecialization) C,
<< "_BT" << MPerBlock << "x" << NPerBlock << "x" << KPerBlock input_spatial_lengths,
<< "_TT" << MPerThread << "x" << NPerThread filter_spatial_lengths,
<< "_A" << string_local_buffer(UseALocalBuffer) output_spatial_lengths,
<< "_B" << string_local_buffer(UseBLocalBuffer) conv_filter_strides,
<< "_C" << string_local_buffer(UseCLocalBuffer) conv_filter_dilations,
; input_left_pads,
if constexpr (!std::is_same<OutElementwiseOperation, input_right_pads,
ck::tensor_operation::cpu::element_wise::PassThrough>::value) in_element_op,
{ wei_element_op,
str << "_" << OutElementwiseOperation::Name(); out_element_op);
} }
// clang-format on
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
return str.str(); {
} return std::make_unique<Invoker>(Invoker{});
}; }
} // namespace device std::string GetTypeString() const override
} // namespace cpu {
} // namespace tensor_operation auto str = std::stringstream();
} // namespace ck auto string_local_buffer = [](bool is_local_buffer) {
if(is_local_buffer)
#endif return "L";
else
return "G";
};
// clang-format off
str << "DeviceConv" << std::to_string(NumDimSpatial)
<< "DFwd_BAA_Avx2_NHWC_KYXC"
<<"_FS"<< static_cast<int>(ConvForwardSpecialization)
<<"_KS"<< static_cast<int>(GemmKSpecialization)
<<"_BS"<< static_cast<int>(BlockLoopOverSpecialization)
<< "_BT" << MPerBlock << "x" << NPerBlock << "x" << KPerBlock
<< "_TT" << MPerThread << "x" << NPerThread
<< "_A" << string_local_buffer(UseALocalBuffer)
<< "_B" << string_local_buffer(UseBLocalBuffer)
<< "_C" << string_local_buffer(UseCLocalBuffer)
;
if constexpr (!std::is_same<OutElementwiseOperation,
ck::tensor_operation::cpu::element_wise::PassThrough>::value)
{
str << "_" << OutElementwiseOperation::Name();
}
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
#endif
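The new IsSupportedArgument check above only allows the "no local A/B buffer" path for the Filter1x1Stride1Pad0 specialization. The likely reason: for a 1x1 filter with stride 1 and zero padding the NHWC input needs no im2col-style tensor transform, so the implicit-GEMM A matrix is just the input viewed as a packed (N*Ho*Wo) x C row-major matrix and can be read in place by the micro kernel. A minimal standalone sketch of that index equivalence (illustrative shapes and helper names, not part of this diff):

// --- illustrative sketch (not part of the library) ---
#include <cassert>
#include <cstdint>

// For a 1x1 filter, stride 1, pad 0: Ho == Hi, Wo == Wi, and the implicit-GEMM
// A matrix element A[m][k] with m = (n * H + h) * W + w and k = c sits at the
// same offset as the packed NHWC input element, so no repack buffer is needed.
static int64_t nhwc_offset(int64_t n, int64_t h, int64_t w, int64_t c,
                           int64_t H, int64_t W, int64_t C)
{
    return ((n * H + h) * W + w) * C + c;
}

static int64_t gemm_a_offset(int64_t m, int64_t k, int64_t K) { return m * K + k; }

int main()
{
    const int64_t N = 2, H = 3, W = 5, C = 7; // arbitrary small shapes
    for(int64_t n = 0; n < N; ++n)
        for(int64_t h = 0; h < H; ++h)
            for(int64_t w = 0; w < W; ++w)
                for(int64_t c = 0; c < C; ++c)
                {
                    const int64_t m = (n * H + h) * W + w; // GEMM row
                    const int64_t k = c;                   // GEMM column
                    assert(nhwc_offset(n, h, w, c, H, W, C) == gemm_a_offset(m, k, C));
                }
    return 0;
}
// --- end sketch ---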
...@@ -116,20 +116,41 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou ...@@ -116,20 +116,41 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
static constexpr auto GetInputBlockDescriptor() static constexpr auto GetInputBlockDescriptor()
{ {
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock)); if constexpr(UseALocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
}
else
{
return AGridDesc{};
}
} }
static constexpr auto GetWeightBlockDescriptor() static constexpr auto GetWeightBlockDescriptor()
{ {
return make_naive_tensor_descriptor_packed(make_tuple( if constexpr(UseBLocalBuffer)
math::integer_divide_ceil(NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize), {
KPerBlock, return make_naive_tensor_descriptor_packed(make_tuple(
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize)); math::integer_divide_ceil(NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
KPerBlock,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
}
else
{
return BGridDesc{};
}
} }
static constexpr auto GetOutputBlockDescriptor() static constexpr auto GetOutputBlockDescriptor()
{ {
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock)); if constexpr(UseCLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
}
else
{
return CGridDesc{};
}
} }
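With this change each block descriptor picks its type at compile time: when the corresponding local buffer is enabled it is a small packed descriptor sized to the block, otherwise it is simply the grid descriptor type, so the blockwise GEMM addresses the global tensor directly. A minimal standalone sketch of that if constexpr return-type selection (the struct names are stand-ins, not the library's descriptor types):

// --- illustrative sketch (not part of the library) ---
#include <cstdio>

struct PackedBlockDesc { int m, n; };          // stand-in for a packed local-buffer descriptor
struct GridDesc        { int m, n, stride; };  // stand-in for the grid descriptor

template <bool UseLocalBuffer>
static auto get_block_descriptor(int m_per_blk, int n_per_blk, const GridDesc& grid)
{
    if constexpr(UseLocalBuffer)
        return PackedBlockDesc{m_per_blk, n_per_blk}; // compact scratch layout
    else
        return grid; // no local buffer: describe the global tensor itself
}

int main()
{
    GridDesc grid{256, 256, 300};
    auto local  = get_block_descriptor<true>(128, 64, grid);  // deduces PackedBlockDesc
    auto direct = get_block_descriptor<false>(128, 64, grid); // deduces GridDesc
    std::printf("%d x %d vs %d x %d (stride %d)\n", local.m, local.n, direct.m, direct.n, direct.stride);
    return 0;
}
// --- end sketch ---

Because the two branches deduce different types, downstream templates such as the blockwise GEMM are instantiated against whichever descriptor the UseALocalBuffer/UseBLocalBuffer/UseCLocalBuffer flags select.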
static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n) static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
...@@ -563,7 +584,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou ...@@ -563,7 +584,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
AGridDesc, AGridDesc,
decltype(GetInputBlockDescriptor()), decltype(GetInputBlockDescriptor()),
InElementwiseOperation, InElementwiseOperation,
false, !UseALocalBuffer,
ConvForwardSpecialization, ConvForwardSpecialization,
GemmKSpecialization>; GemmKSpecialization>;
...@@ -574,7 +595,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou ...@@ -574,7 +595,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
BGridDesc, BGridDesc,
decltype(GetWeightBlockDescriptor()), decltype(GetWeightBlockDescriptor()),
WeiElementwiseOperation, WeiElementwiseOperation,
false, !UseBLocalBuffer,
ConvForwardSpecialization, ConvForwardSpecialization,
GemmKSpecialization>; GemmKSpecialization>;
...@@ -820,7 +841,9 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou ...@@ -820,7 +841,9 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
} }
if constexpr(GemmKSpecialization == if constexpr(GemmKSpecialization ==
ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC) ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
ConvForwardSpecialization !=
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{ {
if(!(arg.Conv_C_ % KPerBlock == 0)) if(!(arg.Conv_C_ % KPerBlock == 0))
return false; return false;
...@@ -829,6 +852,15 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou ...@@ -829,6 +852,15 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
if(!(arg.Conv_K_ % 8 == 0)) if(!(arg.Conv_K_ % 8 == 0))
return false; return false;
if constexpr(!UseALocalBuffer &&
ConvForwardSpecialization !=
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
// TODO: We can support this in the future, once we figure out how to express the tensor
// transform
return false;
}
// Gridwise GEMM size // Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_); return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
} }
......
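Summarizing the tightened argument checks in the hunk above: Conv_C_ must be a multiple of KPerBlock only when the GEMM K loop really walks over C (the 1x1/stride-1/pad-0 case is exempt), Conv_K_ must be a multiple of 8 to match the 8-float AVX2 vector width, and skipping the local A buffer is likewise restricted to the 1x1 specialization. A tiny self-contained illustration of the two divisibility checks with hypothetical shapes:

// --- illustrative sketch (not part of the library) ---
#include <cstdio>

// Hypothetical shapes, only to illustrate the divisibility requirements above.
static bool k_loop_over_c_ok(int conv_c, int k_per_block) { return conv_c % k_per_block == 0; }
static bool output_k_ok(int conv_k) { return conv_k % 8 == 0; } // 8 floats per AVX2 register

int main()
{
    std::printf("C=256, KPerBlock=64  -> %s\n", k_loop_over_c_ok(256, 64) ? "supported" : "rejected");
    std::printf("C=192, KPerBlock=128 -> %s\n", k_loop_over_c_ok(192, 128) ? "supported" : "rejected");
    std::printf("K=120 -> %s, K=100 -> %s\n", output_k_ok(120) ? "ok" : "rejected",
                output_k_ok(100) ? "ok" : "rejected");
    return 0;
}
// --- end sketch ---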
...@@ -80,46 +80,65 @@ struct GridwiseGemmAvx2_MxN ...@@ -80,46 +80,65 @@ struct GridwiseGemmAvx2_MxN
// static constexpr auto Avx2RegisterVector = 8; // 8 floats // static constexpr auto Avx2RegisterVector = 8; // 8 floats
static constexpr index_t MemAlignmentByte = 32; // 256bit static constexpr index_t MemAlignmentByte = 32; // 256bit
static auto GetABlockDescriptor(const ck::index_t m_per_blk, const ck::index_t k_per_blk) static auto GetABlockDescriptor(const ck::index_t m_per_blk,
const ck::index_t k_per_blk,
const AGridDesc& a_grid_desc)
{ {
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout, if constexpr(UseALocalBuffer)
ck::tensor_layout::gemm::RowMajor>::value)
{ {
// A : M, K if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
auto a_block_desc_m_k = ck::tensor_layout::gemm::RowMajor>::value)
make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, k_per_blk)); {
return a_block_desc_m_k; // A : M, K
auto a_block_desc_m_k =
make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, k_per_blk));
return a_block_desc_m_k;
}
else
{
// A : K, M
auto a_block_desc_k_m = make_naive_tensor_descriptor_packed(
make_tuple(k_per_blk,
math::integer_least_multiple(
m_per_blk, ThreadwiseGemm_Dispatch::MatrixAMinVectorSize)));
return a_block_desc_k_m;
}
} }
else else
{ {
// A : K, M return a_grid_desc;
auto a_block_desc_k_m = make_naive_tensor_descriptor_packed(
make_tuple(k_per_blk,
math::integer_least_multiple(
m_per_blk, ThreadwiseGemm_Dispatch::MatrixAMinVectorSize)));
return a_block_desc_k_m;
} }
} }
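When A is kept in (K, M) order, the M extent of the local block is padded up to MatrixAMinVectorSize with math::integer_least_multiple so the micro kernel can always issue full vector loads. Assuming the usual round-up semantics for these helpers, a small sketch with concrete numbers:

// --- illustrative sketch (not part of the library) ---
#include <cstdio>

// Assumed semantics of the math helpers used above: round-up division and
// round-up to the next multiple.
constexpr int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }
constexpr int integer_least_multiple(int a, int b) { return integer_divide_ceil(a, b) * b; }

int main()
{
    static_assert(integer_divide_ceil(100, 8) == 13, "round-up division");
    static_assert(integer_least_multiple(100, 8) == 104, "a 100-row tail block is padded to 104");
    static_assert(integer_least_multiple(96, 8) == 96, "an exact multiple is left unchanged");
    std::printf("ok\n");
    return 0;
}
// --- end sketch ---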
static auto GetBBlockDescriptor(const ck::index_t k_per_blk, const ck::index_t n_per_blk) static auto GetBBlockDescriptor(const ck::index_t k_per_blk,
const ck::index_t n_per_blk,
const BGridDesc& b_grid_desc)
{ {
// n_per_blk should be 8x if constexpr(UseBLocalBuffer)
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{ {
// B : K, N // n_per_blk should be 8x
auto b_block_desc_k_n = if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
make_naive_tensor_descriptor_packed(make_tuple(k_per_blk, n_per_blk)); ck::tensor_layout::gemm::RowMajor>::value)
return b_block_desc_k_n; {
// B : K, N
auto b_block_desc_k_n =
make_naive_tensor_descriptor_packed(make_tuple(k_per_blk, n_per_blk));
return b_block_desc_k_n;
}
else
{
// B : N/8, K, N8
auto b_block_desc_n0_k_n1 = make_naive_tensor_descriptor_packed(
make_tuple(math::integer_divide_ceil(
n_per_blk, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
k_per_blk,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
return b_block_desc_n0_k_n1;
}
} }
else else
{ {
// B : N/8, K, N8 return b_grid_desc;
auto b_block_desc_n0_k_n1 = make_naive_tensor_descriptor_packed(make_tuple(
math::integer_divide_ceil(n_per_blk, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
k_per_blk,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
return b_block_desc_n0_k_n1;
} }
} }
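The column-major B block is stored as (N/8, K, 8): N is split into groups of MatrixBMinVectorSize (8 floats for AVX2), and the 8 n-values of a group stay contiguous so one ymm load feeds the micro kernel; GetBIndex later in this diff maps a logical (k, n) to (n/8, k, n%8) accordingly. A small standalone sketch of the packed offset this layout implies (assuming the group size is 8):

// --- illustrative sketch (not part of the library) ---
#include <cassert>
#include <cstdint>

// Packed offset for a B block laid out as (N/8, K, 8).
static int64_t b_n0_k_n1_offset(int64_t i_k, int64_t i_n, int64_t K)
{
    const int64_t n0 = i_n / 8, n1 = i_n % 8;
    return (n0 * K + i_k) * 8 + n1;
}

int main()
{
    const int64_t K = 16;
    for(int64_t n = 1; n < 8; ++n) // the 8 n-values of one group are contiguous...
        assert(b_n0_k_n1_offset(3, n, K) == b_n0_k_n1_offset(3, n - 1, K) + 1);
    // ...and stepping k within a group moves by the full 8-float group width.
    assert(b_n0_k_n1_offset(4, 0, K) == b_n0_k_n1_offset(3, 0, K) + 8);
    return 0;
}
// --- end sketch ---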
...@@ -262,10 +281,10 @@ struct GridwiseGemmAvx2_MxN ...@@ -262,10 +281,10 @@ struct GridwiseGemmAvx2_MxN
constexpr auto b_block_copy_dim = BGridDesc::GetNumOfDimension(); constexpr auto b_block_copy_dim = BGridDesc::GetNumOfDimension();
auto a_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>( auto a_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<const FloatA*>(p_a_grid), a_grid_desc.GetElementSpaceSize()); const_cast<FloatA*>(p_a_grid), a_grid_desc.GetElementSpaceSize());
auto b_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>( auto b_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<const FloatB*>(p_b_grid), b_grid_desc.GetElementSpaceSize()); const_cast<FloatB*>(p_b_grid), b_grid_desc.GetElementSpaceSize());
auto c_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>( auto c_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<FloatC*>(p_c_grid), c_grid_desc.GetElementSpaceSize()); reinterpret_cast<FloatC*>(p_c_grid), c_grid_desc.GetElementSpaceSize());
...@@ -274,8 +293,8 @@ struct GridwiseGemmAvx2_MxN ...@@ -274,8 +293,8 @@ struct GridwiseGemmAvx2_MxN
FloatA, // FloatA, FloatA, // FloatA,
FloatB, // FloatB, FloatB, // FloatB,
FloatC, // FloatC, FloatC, // FloatC,
decltype(GetABlockDescriptor(m_per_block, k_per_block)), // ABlockDesc, decltype(GetABlockDescriptor(m_per_block, k_per_block, a_grid_desc)), // ABlockDesc,
decltype(GetBBlockDescriptor(k_per_block, n_per_block)), // BBlockDesc, decltype(GetBBlockDescriptor(k_per_block, n_per_block, b_grid_desc)), // BBlockDesc,
decltype(GetCBlockDescriptor(m_per_block, n_per_block, c_grid_desc)), // CBlockDesc, decltype(GetCBlockDescriptor(m_per_block, n_per_block, c_grid_desc)), // CBlockDesc,
KPerBlock, // KPerBlock, KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch, ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
...@@ -320,14 +339,14 @@ struct GridwiseGemmAvx2_MxN ...@@ -320,14 +339,14 @@ struct GridwiseGemmAvx2_MxN
auto a_threadwise_copy = auto a_threadwise_copy =
AThreadwiseCopy(a_grid_desc, AThreadwiseCopy(a_grid_desc,
ck::make_zero_multi_index<a_block_copy_dim>(), ck::make_zero_multi_index<a_block_copy_dim>(),
GetABlockDescriptor(m_per_block, k_per_block), GetABlockDescriptor(m_per_block, k_per_block, a_grid_desc),
ck::make_zero_multi_index<a_block_copy_dim>(), ck::make_zero_multi_index<a_block_copy_dim>(),
AElementwiseOperation{}); AElementwiseOperation{});
auto b_threadwise_copy = auto b_threadwise_copy =
BThreadwiseCopy(b_grid_desc, BThreadwiseCopy(b_grid_desc,
ck::make_zero_multi_index<b_block_copy_dim>(), ck::make_zero_multi_index<b_block_copy_dim>(),
GetBBlockDescriptor(k_per_block, n_per_block), GetBBlockDescriptor(k_per_block, n_per_block, b_grid_desc),
ck::make_zero_multi_index<b_block_copy_dim>(), ck::make_zero_multi_index<b_block_copy_dim>(),
BElementwiseOperation{}); BElementwiseOperation{});
...@@ -338,21 +357,27 @@ struct GridwiseGemmAvx2_MxN ...@@ -338,21 +357,27 @@ struct GridwiseGemmAvx2_MxN
ck::make_zero_multi_index<2>(), ck::make_zero_multi_index<2>(),
CElementwiseOperation{}); CElementwiseOperation{});
DeviceAlignedMemCPU a_block_mem(m_per_block * k_per_block * sizeof(FloatA), DeviceAlignedMemCPU a_block_mem(
MemAlignmentByte); UseALocalBuffer ? m_per_block * k_per_block * sizeof(FloatA) : 0,
DeviceAlignedMemCPU b_block_mem(k_per_block * n_per_block * sizeof(FloatB), MemAlignmentByte);
MemAlignmentByte); DeviceAlignedMemCPU b_block_mem(
UseBLocalBuffer ? k_per_block * n_per_block * sizeof(FloatB) : 0,
MemAlignmentByte);
DeviceAlignedMemCPU c_block_mem( DeviceAlignedMemCPU c_block_mem(
UseCLocalBuffer ? (m_per_block * n_per_block * sizeof(FloatC)) : 0, UseCLocalBuffer ? (m_per_block * n_per_block * sizeof(FloatC)) : 0,
MemAlignmentByte); MemAlignmentByte);
auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>( auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<FloatA*>(a_block_mem.mpDeviceBuf), UseALocalBuffer ? reinterpret_cast<FloatA*>(a_block_mem.mpDeviceBuf)
a_block_mem.mMemSize / sizeof(FloatA)); : const_cast<FloatA*>(p_a_grid),
UseALocalBuffer ? a_block_mem.mMemSize / sizeof(FloatA)
: a_grid_desc.GetElementSpaceSize());
auto b_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>( auto b_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<FloatB*>(b_block_mem.mpDeviceBuf), UseBLocalBuffer ? reinterpret_cast<FloatB*>(b_block_mem.mpDeviceBuf)
b_block_mem.mMemSize / sizeof(FloatB)); : const_cast<FloatB*>(p_b_grid),
UseBLocalBuffer ? b_block_mem.mMemSize / sizeof(FloatB)
: b_grid_desc.GetElementSpaceSize());
auto c_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>( auto c_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
UseCLocalBuffer ? reinterpret_cast<FloatC*>(c_block_mem.mpDeviceBuf) UseCLocalBuffer ? reinterpret_cast<FloatC*>(c_block_mem.mpDeviceBuf)
...@@ -395,8 +420,8 @@ struct GridwiseGemmAvx2_MxN ...@@ -395,8 +420,8 @@ struct GridwiseGemmAvx2_MxN
{ {
ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block); ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block);
auto a_block_desc = GetABlockDescriptor(mc_size, kc_size); auto a_block_desc = GetABlockDescriptor(mc_size, kc_size, a_grid_desc);
auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size); auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size, b_grid_desc);
a_threadwise_copy.RunRead(a_grid_desc, a_threadwise_copy.RunRead(a_grid_desc,
a_grid_buf, a_grid_buf,
...@@ -412,12 +437,17 @@ struct GridwiseGemmAvx2_MxN ...@@ -412,12 +437,17 @@ struct GridwiseGemmAvx2_MxN
blockwise_gemm.Run(a_block_desc, blockwise_gemm.Run(a_block_desc,
a_block_buf, a_block_buf,
make_zero_multi_index<a_block_copy_dim>(), make_zero_multi_index<a_block_copy_dim>(),
GetASliceLength(mc_size, kc_size),
b_block_desc, b_block_desc,
b_block_buf, b_block_buf,
make_zero_multi_index<b_block_copy_dim>(), make_zero_multi_index<b_block_copy_dim>(),
GetBSliceLength(kc_size, nc_size),
c_block_desc, c_block_desc,
c_block_buf, c_block_buf,
make_zero_multi_index<2>(), make_zero_multi_index<2>(),
GetCSliceLength(mc_size, nc_size),
i_kc != 0); i_kc != 0);
if((i_kc + k_per_block) < GemmK) if((i_kc + k_per_block) < GemmK)
...@@ -450,14 +480,14 @@ struct GridwiseGemmAvx2_MxN ...@@ -450,14 +480,14 @@ struct GridwiseGemmAvx2_MxN
auto a_threadwise_copy = auto a_threadwise_copy =
AThreadwiseCopy(a_grid_desc, AThreadwiseCopy(a_grid_desc,
ck::make_zero_multi_index<a_block_copy_dim>(), ck::make_zero_multi_index<a_block_copy_dim>(),
GetABlockDescriptor(m_per_block, k_per_block), GetABlockDescriptor(m_per_block, k_per_block, a_grid_desc),
ck::make_zero_multi_index<a_block_copy_dim>(), ck::make_zero_multi_index<a_block_copy_dim>(),
AElementwiseOperation{}); AElementwiseOperation{});
auto b_threadwise_copy = auto b_threadwise_copy =
BThreadwiseCopy(b_grid_desc, BThreadwiseCopy(b_grid_desc,
ck::make_zero_multi_index<b_block_copy_dim>(), ck::make_zero_multi_index<b_block_copy_dim>(),
GetBBlockDescriptor(k_per_block, n_per_block), GetBBlockDescriptor(k_per_block, n_per_block, b_grid_desc),
ck::make_zero_multi_index<b_block_copy_dim>(), ck::make_zero_multi_index<b_block_copy_dim>(),
BElementwiseOperation{}); BElementwiseOperation{});
...@@ -468,21 +498,27 @@ struct GridwiseGemmAvx2_MxN ...@@ -468,21 +498,27 @@ struct GridwiseGemmAvx2_MxN
ck::make_zero_multi_index<2>(), ck::make_zero_multi_index<2>(),
CElementwiseOperation{}); CElementwiseOperation{});
DeviceAlignedMemCPU a_block_mem(m_per_block * k_per_block * sizeof(FloatA), DeviceAlignedMemCPU a_block_mem(
MemAlignmentByte); UseALocalBuffer ? m_per_block * k_per_block * sizeof(FloatA) : 0,
DeviceAlignedMemCPU b_block_mem(k_per_block * n_per_block * sizeof(FloatB), MemAlignmentByte);
MemAlignmentByte); DeviceAlignedMemCPU b_block_mem(
UseBLocalBuffer ? k_per_block * n_per_block * sizeof(FloatB) : 0,
MemAlignmentByte);
DeviceAlignedMemCPU c_block_mem( DeviceAlignedMemCPU c_block_mem(
UseCLocalBuffer ? (m_per_block * n_per_block * sizeof(FloatC)) : 0, UseCLocalBuffer ? (m_per_block * n_per_block * sizeof(FloatC)) : 0,
MemAlignmentByte); MemAlignmentByte);
auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>( auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<FloatA*>(a_block_mem.mpDeviceBuf), UseALocalBuffer ? reinterpret_cast<FloatA*>(a_block_mem.mpDeviceBuf)
a_block_mem.mMemSize / sizeof(FloatA)); : const_cast<FloatA*>(p_a_grid),
UseALocalBuffer ? a_block_mem.mMemSize / sizeof(FloatA)
: a_grid_desc.GetElementSpaceSize());
auto b_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>( auto b_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<FloatB*>(b_block_mem.mpDeviceBuf), UseBLocalBuffer ? reinterpret_cast<FloatB*>(b_block_mem.mpDeviceBuf)
b_block_mem.mMemSize / sizeof(FloatB)); : const_cast<FloatB*>(p_b_grid),
UseBLocalBuffer ? b_block_mem.mMemSize / sizeof(FloatB)
: b_grid_desc.GetElementSpaceSize());
auto c_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>( auto c_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
UseCLocalBuffer ? reinterpret_cast<FloatC*>(c_block_mem.mpDeviceBuf) UseCLocalBuffer ? reinterpret_cast<FloatC*>(c_block_mem.mpDeviceBuf)
...@@ -503,7 +539,7 @@ struct GridwiseGemmAvx2_MxN ...@@ -503,7 +539,7 @@ struct GridwiseGemmAvx2_MxN
{ {
ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block); ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block);
auto a_block_desc = GetABlockDescriptor(mc_size, kc_size); auto a_block_desc = GetABlockDescriptor(mc_size, kc_size, a_grid_desc);
a_threadwise_copy.RunRead(a_grid_desc, a_threadwise_copy.RunRead(a_grid_desc,
a_grid_buf, a_grid_buf,
a_block_desc, a_block_desc,
...@@ -519,7 +555,7 @@ struct GridwiseGemmAvx2_MxN ...@@ -519,7 +555,7 @@ struct GridwiseGemmAvx2_MxN
ck::math::min(GemmN - i_nc, n_per_block); // TODO: nc needs to be a multiple of 8 ck::math::min(GemmN - i_nc, n_per_block); // TODO: nc needs to be a multiple of 8
nc_size = math::integer_least_multiple( nc_size = math::integer_least_multiple(
nc_size, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize); nc_size, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size); auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size, b_grid_desc);
b_threadwise_copy.RunRead(b_grid_desc, b_threadwise_copy.RunRead(b_grid_desc,
b_grid_buf, b_grid_buf,
...@@ -543,12 +579,18 @@ struct GridwiseGemmAvx2_MxN ...@@ -543,12 +579,18 @@ struct GridwiseGemmAvx2_MxN
blockwise_gemm.Run(a_block_desc, blockwise_gemm.Run(a_block_desc,
a_block_buf, a_block_buf,
make_zero_multi_index<a_block_copy_dim>(), make_zero_multi_index<a_block_copy_dim>(),
GetASliceLength(mc_size, kc_size),
b_block_desc, b_block_desc,
b_block_buf, b_block_buf,
make_zero_multi_index<b_block_copy_dim>(), make_zero_multi_index<b_block_copy_dim>(),
GetBSliceLength(kc_size, nc_size),
c_block_desc, c_block_desc,
c_block_buf, c_block_buf,
make_zero_multi_index<2>(), make_zero_multi_index<2>(),
GetCSliceLength(mc_size, nc_size),
i_kc != 0); i_kc != 0);
if((i_nc + n_per_block) < GemmN) if((i_nc + n_per_block) < GemmN)
......
#ifndef CK_GRIDWISE_GEMM_BIAS_ACTIVATION_ADD_AVX2_HPP #ifndef CK_GRIDWISE_GEMM_BIAS_ACTIVATION_ADD_AVX2_HPP
#define CK_GRIDWISE_GEMM_BIAS_ACTIVATION_ADD_AVX2_HPP #define CK_GRIDWISE_GEMM_BIAS_ACTIVATION_ADD_AVX2_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "multi_index_transform_helper.hpp" #include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_avx2.hpp" #include "blockwise_gemm_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2.hpp" #include "threadwise_tensor_slice_transfer_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp" #include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"
#include "dynamic_buffer_cpu.hpp" #include "dynamic_buffer_cpu.hpp"
#include <utility> #include <utility>
#include <unistd.h> #include <unistd.h>
#include <omp.h> #include <omp.h>
#include <pthread.h> #include <pthread.h>
namespace ck { namespace ck {
namespace cpu { namespace cpu {
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatA, typename FloatA,
typename FloatB, typename FloatB,
typename FloatC, typename FloatC,
typename FloatC0, typename FloatC0,
typename FloatC1, typename FloatC1,
typename AGridDesc, typename AGridDesc,
typename BGridDesc, typename BGridDesc,
typename CGridDesc, typename CGridDesc,
typename C0GridDesc, typename C0GridDesc,
typename C1GridDesc, typename C1GridDesc,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation>
void kernel_gemm_bias_activation_add_avx_mxn(const FloatA* __restrict__ p_a_grid, void kernel_gemm_bias_activation_add_avx_mxn(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid, const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const FloatC0* __restrict__ p_c0_grid, const FloatC0* __restrict__ p_c0_grid,
const FloatC1* __restrict__ p_c1_grid, const FloatC1* __restrict__ p_c1_grid,
const AGridDesc& a_grid_desc, const AGridDesc& a_grid_desc,
const BGridDesc& b_grid_desc, const BGridDesc& b_grid_desc,
const CGridDesc& c_grid_desc, const CGridDesc& c_grid_desc,
const C0GridDesc& c0_grid_desc, const C0GridDesc& c0_grid_desc,
const C1GridDesc& c1_grid_desc, const C1GridDesc& c1_grid_desc,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op) const CElementwiseOperation& c_element_op)
{ {
GridwiseGemm::Run(p_a_grid, GridwiseGemm::Run(p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
p_c0_grid, p_c0_grid,
p_c1_grid, p_c1_grid,
a_grid_desc, a_grid_desc,
b_grid_desc, b_grid_desc,
c_grid_desc, c_grid_desc,
c0_grid_desc, c0_grid_desc,
c1_grid_desc, c1_grid_desc,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op); c_element_op);
} }
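kernel_gemm_bias_activation_add_avx_mxn is a thin free-function wrapper that the Invoker hands to launch_cpu_kernel / launch_and_time_cpu_kernel together with nrepeat. As a rough idea of what such a timing helper does (a hypothetical sketch, not the library's actual helper or its signature):

// --- illustrative sketch (not part of the library) ---
#include <chrono>
#include <cstdio>

// Run a callable nrepeat times and return the average wall time in milliseconds.
template <typename Kernel, typename... Args>
static float time_cpu_kernel(Kernel&& kernel, int nrepeat, const Args&... args)
{
    const auto t0 = std::chrono::steady_clock::now();
    for(int i = 0; i < nrepeat; ++i)
        kernel(args...);
    const auto t1 = std::chrono::steady_clock::now();
    return std::chrono::duration<float, std::milli>(t1 - t0).count() / static_cast<float>(nrepeat);
}

int main()
{
    auto dummy_kernel = [](int n) {
        volatile long s = 0;
        for(int i = 0; i < n; ++i)
            s += i;
    };
    std::printf("avg %.3f ms\n", time_cpu_kernel(dummy_kernel, 5, 1 << 20));
    return 0;
}
// --- end sketch ---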
template <typename FloatA, template <typename FloatA,
typename FloatB, typename FloatB,
typename FloatC, typename FloatC,
typename FloatC0, typename FloatC0,
typename FloatC1, typename FloatC1,
typename AGridDesc, typename AGridDesc,
typename BGridDesc, typename BGridDesc,
typename CGridDesc, typename CGridDesc,
typename C0GridDesc, typename C0GridDesc,
typename C1GridDesc, typename C1GridDesc,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3) ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
ck::index_t NPerBlock, ck::index_t NPerBlock,
ck::index_t KPerBlock, ck::index_t KPerBlock,
typename ThreadwiseGemm_Dispatch, typename ThreadwiseGemm_Dispatch,
typename AThreadwiseCopy, typename AThreadwiseCopy,
typename BThreadwiseCopy, typename BThreadwiseCopy,
typename CThreadwiseCopy, typename CThreadwiseCopy,
typename BlockMNKAccessOrder, // how we access gemm MNK to better fit in cache typename BlockMNKAccessOrder, // how we access gemm MNK to better fit in cache
typename ThreadMNAccessOrder, // how we access gemm MN to utilize the micro kernel typename ThreadMNAccessOrder, // how we access gemm MN to utilize the micro kernel
bool UseALocalBuffer, bool UseALocalBuffer,
bool UseBLocalBuffer, bool UseBLocalBuffer,
bool UseCLocalBuffer // if true, will allocate a buffer and write to it in kernel, then bool UseCLocalBuffer // if true, will allocate a buffer and write to it in kernel, then
// copy back to block buffer (need CThreadwiseCopy). // copy back to block buffer (need CThreadwiseCopy).
// if false, will write to C directly (no need CThreadwiseCopy) // if false, will write to C directly (no need CThreadwiseCopy)
> >
struct GridwiseGemmBiasActivationAddAvx2_MxN struct GridwiseGemmBiasActivationAddAvx2_MxN
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
// static constexpr auto Avx2RegisterVector = 8; // 8 floats // static constexpr auto Avx2RegisterVector = 8; // 8 floats
static constexpr index_t MemAlignmentByte = 32; // 256bit static constexpr index_t MemAlignmentByte = 32; // 256bit
static auto GetABlockDescriptor(const ck::index_t m_per_blk, const ck::index_t k_per_blk) static auto GetABlockDescriptor(const ck::index_t m_per_blk,
{ const ck::index_t k_per_blk,
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout, const AGridDesc& a_grid_desc)
ck::tensor_layout::gemm::RowMajor>::value) {
{ if constexpr(UseALocalBuffer)
// A : M, K {
auto a_block_desc_m_k = if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, k_per_blk)); ck::tensor_layout::gemm::RowMajor>::value)
return a_block_desc_m_k; {
} // A : M, K
else auto a_block_desc_m_k =
{ make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, k_per_blk));
// A : K, M return a_block_desc_m_k;
auto a_block_desc_k_m = make_naive_tensor_descriptor_packed( }
make_tuple(k_per_blk, else
math::integer_least_multiple( {
m_per_blk, ThreadwiseGemm_Dispatch::MatrixAMinVectorSize))); // A : K, M
return a_block_desc_k_m; auto a_block_desc_k_m = make_naive_tensor_descriptor_packed(
} make_tuple(k_per_blk,
} math::integer_least_multiple(
m_per_blk, ThreadwiseGemm_Dispatch::MatrixAMinVectorSize)));
static auto GetBBlockDescriptor(const ck::index_t k_per_blk, const ck::index_t n_per_blk) return a_block_desc_k_m;
{ }
// n_per_blk should be 8x }
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout, else
ck::tensor_layout::gemm::RowMajor>::value) {
{ return a_grid_desc;
// B : K, N }
auto b_block_desc_k_n = }
make_naive_tensor_descriptor_packed(make_tuple(k_per_blk, n_per_blk));
return b_block_desc_k_n; static auto GetBBlockDescriptor(const ck::index_t k_per_blk,
} const ck::index_t n_per_blk,
else const BGridDesc& b_grid_desc)
{ {
// B : N/8, K, N8 if constexpr(UseBLocalBuffer)
auto b_block_desc_n0_k_n1 = make_naive_tensor_descriptor_packed(make_tuple( {
math::integer_divide_ceil(n_per_blk, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize), // n_per_blk should be 8x
k_per_blk, if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize)); ck::tensor_layout::gemm::RowMajor>::value)
return b_block_desc_n0_k_n1; {
} // B : K, N
} auto b_block_desc_k_n =
make_naive_tensor_descriptor_packed(make_tuple(k_per_blk, n_per_blk));
static auto GetCBlockDescriptor(const ck::index_t m_per_blk, return b_block_desc_k_n;
const ck::index_t n_per_blk, }
const CGridDesc& c_grid_desc) else
{ {
if constexpr(UseCLocalBuffer) // B : N/8, K, N8
{ auto b_block_desc_n0_k_n1 = make_naive_tensor_descriptor_packed(
return make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, n_per_blk)); make_tuple(math::integer_divide_ceil(
} n_per_blk, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
else k_per_blk,
return c_grid_desc; ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
} return b_block_desc_n0_k_n1;
}
static auto GetASliceLength(const ck::index_t m_per_blk, const ck::index_t k_per_blk) }
{ else
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout, {
ck::tensor_layout::gemm::RowMajor>::value) return b_grid_desc;
{ }
// A : M, K }
return ck::make_multi_index(m_per_blk, k_per_blk);
} static auto GetCBlockDescriptor(const ck::index_t m_per_blk,
else const ck::index_t n_per_blk,
{ const CGridDesc& c_grid_desc)
// A : K, M {
return ck::make_multi_index( if constexpr(UseCLocalBuffer)
k_per_blk, {
math::integer_least_multiple(m_per_blk, return make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, n_per_blk));
ThreadwiseGemm_Dispatch::MatrixAMinVectorSize)); }
} else
} return c_grid_desc;
}
static auto GetBSliceLength(const ck::index_t k_per_blk, const ck::index_t n_per_blk)
{ static auto GetASliceLength(const ck::index_t m_per_blk, const ck::index_t k_per_blk)
// n_per_blk should be 8x {
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout, if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value) ck::tensor_layout::gemm::RowMajor>::value)
{ {
// B : K, N // A : M, K
return ck::make_multi_index( return ck::make_multi_index(m_per_blk, k_per_blk);
k_per_blk, }
math::integer_least_multiple(n_per_blk, else
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize)); {
} // A : K, M
else return ck::make_multi_index(
{ k_per_blk,
// B : N/8, K, N8 math::integer_least_multiple(m_per_blk,
return ck::make_multi_index( ThreadwiseGemm_Dispatch::MatrixAMinVectorSize));
math::integer_divide_ceil(n_per_blk, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize), }
k_per_blk, }
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
} static auto GetBSliceLength(const ck::index_t k_per_blk, const ck::index_t n_per_blk)
} {
// n_per_blk should be 8x
static auto GetCSliceLength(const ck::index_t m_per_blk, const ck::index_t n_per_blk) if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
{ ck::tensor_layout::gemm::RowMajor>::value)
return ck::make_multi_index(m_per_blk, n_per_blk); {
} // B : K, N
return ck::make_multi_index(
static auto GetAIndex(const ck::index_t i_m, const ck::index_t i_k) k_per_blk,
{ math::integer_least_multiple(n_per_blk,
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
ck::tensor_layout::gemm::RowMajor>::value) }
{ else
// A : M, K {
return ck::make_multi_index(i_m, i_k); // B : N/8, K, N8
} return ck::make_multi_index(
else math::integer_divide_ceil(n_per_blk, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
{ k_per_blk,
// A : K, M ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
return ck::make_multi_index(i_k, i_m); }
} }
}
static auto GetCSliceLength(const ck::index_t m_per_blk, const ck::index_t n_per_blk)
static auto GetBIndex(const ck::index_t i_k, const ck::index_t i_n) {
{ return ck::make_multi_index(m_per_blk, n_per_blk);
// i_n should be 8x }
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value) static auto GetAIndex(const ck::index_t i_m, const ck::index_t i_k)
{ {
// B : K, N if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
return ck::make_multi_index(i_k, i_n); ck::tensor_layout::gemm::RowMajor>::value)
} {
else // A : M, K
{ return ck::make_multi_index(i_m, i_k);
// B : N/8, K, N8 }
return ck::make_multi_index(i_n / ThreadwiseGemm_Dispatch::MatrixBMinVectorSize, else
i_k, {
i_n % ThreadwiseGemm_Dispatch::MatrixBMinVectorSize); // A : K, M
} return ck::make_multi_index(i_k, i_m);
} }
}
static auto GetCIndex(const ck::index_t i_m, const ck::index_t i_n)
{ static auto GetBIndex(const ck::index_t i_k, const ck::index_t i_n)
return ck::make_multi_index(i_m, i_n); {
} // i_n should be 8x
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
static constexpr bool CheckValidity(const AGridDesc& a_grid_desc, ck::tensor_layout::gemm::RowMajor>::value)
const BGridDesc& b_grid_desc, {
const CGridDesc& c_grid_desc) // B : K, N
{ return ck::make_multi_index(i_k, i_n);
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) }
bool is_valid = true; else
const auto GemmN = c_grid_desc.GetLength(I1); {
if constexpr(UseCLocalBuffer) // B : N/8, K, N8
{ return ck::make_multi_index(i_n / ThreadwiseGemm_Dispatch::MatrixBMinVectorSize,
if(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value && NPerBlock < GemmN) i_k,
is_valid &= false; i_n % ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
} }
else }
{
// TODO: need to check whether the c grid is a simple transform static auto GetCIndex(const ck::index_t i_m, const ck::index_t i_n)
if(GemmN % 8 != 0) {
is_valid &= false; return ck::make_multi_index(i_m, i_n);
} }
return is_valid;
} static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
const BGridDesc& b_grid_desc,
static void Run(const FloatA* __restrict__ p_a_grid, const CGridDesc& c_grid_desc)
const FloatB* __restrict__ p_b_grid, {
FloatC* __restrict__ p_c_grid, // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
const FloatC0* __restrict__ p_c0_grid, bool is_valid = true;
const FloatC1* __restrict__ p_c1_grid, const auto GemmN = c_grid_desc.GetLength(I1);
const AGridDesc& a_grid_desc, if constexpr(UseCLocalBuffer)
const BGridDesc& b_grid_desc, {
const CGridDesc& c_grid_desc, if(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value && NPerBlock < GemmN)
const C0GridDesc& c0_grid_desc, is_valid &= false;
const C1GridDesc& c1_grid_desc, }
const AElementwiseOperation& a_element_op, else
const BElementwiseOperation& b_element_op, {
const CElementwiseOperation& c_element_op) // TODO: need check c grid is simple transform?
{ if(GemmN % 8 != 0)
ck::index_t m_per_block = MPerBlock; is_valid &= false;
ck::index_t n_per_block = NPerBlock; }
ck::index_t k_per_block = KPerBlock; return is_valid;
}
const auto GemmM = c_grid_desc.GetLength(I0);
const auto GemmN = c_grid_desc.GetLength(I1); static void Run(const FloatA* __restrict__ p_a_grid,
const auto GemmK = a_grid_desc.GetLength(I1); const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
constexpr auto a_block_copy_dim = AGridDesc::GetNumOfDimension(); const FloatC0* __restrict__ p_c0_grid,
const FloatC1* __restrict__ p_c1_grid,
constexpr auto b_block_copy_dim = BGridDesc::GetNumOfDimension(); const AGridDesc& a_grid_desc,
const BGridDesc& b_grid_desc,
auto a_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>( const CGridDesc& c_grid_desc,
reinterpret_cast<const FloatA*>(p_a_grid), a_grid_desc.GetElementSpaceSize()); const C0GridDesc& c0_grid_desc,
const C1GridDesc& c1_grid_desc,
auto b_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>( const AElementwiseOperation& a_element_op,
reinterpret_cast<const FloatB*>(p_b_grid), b_grid_desc.GetElementSpaceSize()); const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op)
auto c_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>( {
reinterpret_cast<FloatC*>(p_c_grid), c_grid_desc.GetElementSpaceSize()); ck::index_t m_per_block = MPerBlock;
ck::index_t n_per_block = NPerBlock;
auto c0_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>( ck::index_t k_per_block = KPerBlock;
reinterpret_cast<const FloatC0*>(p_c0_grid), c0_grid_desc.GetElementSpaceSize());
const auto GemmM = c_grid_desc.GetLength(I0);
auto c1_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>( const auto GemmN = c_grid_desc.GetLength(I1);
reinterpret_cast<const FloatC1*>(p_c1_grid), c1_grid_desc.GetElementSpaceSize()); const auto GemmK = a_grid_desc.GetLength(I1);
auto blockwise_gemm = BlockwiseGemmAvx2_MxN< constexpr auto a_block_copy_dim = AGridDesc::GetNumOfDimension();
FloatA, // FloatA,
FloatB, // FloatB, constexpr auto b_block_copy_dim = BGridDesc::GetNumOfDimension();
FloatC, // FloatC,
decltype(GetABlockDescriptor(m_per_block, k_per_block)), // ABlockDesc, auto a_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
decltype(GetBBlockDescriptor(k_per_block, n_per_block)), // BBlockDesc, const_cast<FloatA*>(p_a_grid), a_grid_desc.GetElementSpaceSize());
decltype(GetCBlockDescriptor(m_per_block, n_per_block, c_grid_desc)), // CBlockDesc,
KPerBlock, // KPerBlock, auto b_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch, const_cast<FloatB*>(p_b_grid), b_grid_desc.GetElementSpaceSize());
ThreadMNAccessOrder>{}; // ThreadMNAccessOrder // how we access
// gemm MN to utilize micro kernel>{}; auto c_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<FloatC*>(p_c_grid), c_grid_desc.GetElementSpaceSize());
int total_threads = omp_get_max_threads();
auto c0_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
#if 0 reinterpret_cast<const FloatC0*>(p_c0_grid), c0_grid_desc.GetElementSpaceSize());
if(total_threads > 1){
#pragma omp parallel auto c1_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
{ reinterpret_cast<const FloatC1*>(p_c1_grid), c1_grid_desc.GetElementSpaceSize());
int tid = omp_get_thread_num();
cpu_set_t set; auto blockwise_gemm = BlockwiseGemmAvx2_MxN<
CPU_ZERO(&set); FloatA, // FloatA,
FloatB, // FloatB,
CPU_SET(tid, &set); FloatC, // FloatC,
decltype(GetABlockDescriptor(m_per_block, k_per_block, a_grid_desc)), // ABlockDesc,
if (sched_setaffinity(0, sizeof(set), &set) == -1) { decltype(GetBBlockDescriptor(k_per_block, n_per_block, b_grid_desc)), // BBlockDesc,
throw std::runtime_error("wrong! fail to set thread affinity"); decltype(GetCBlockDescriptor(m_per_block, n_per_block, c_grid_desc)), // CBlockDesc,
} KPerBlock, // KPerBlock,
} ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
} ThreadMNAccessOrder>{}; // ThreadMNAccessOrder // how we access
#endif // gemm MN to utilize micro kernel>{};
// TODO: openmp aware ordering int total_threads = omp_get_max_threads();
//
if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 1, 2>>::value) #if 0
{ if(total_threads > 1){
auto a_move_k_step = GetAIndex(0, k_per_block); #pragma omp parallel
auto b_move_k_step = GetBIndex(k_per_block, 0); {
int tid = omp_get_thread_num();
const ck::index_t grid_m = math::integer_divide_ceil(GemmM, m_per_block); cpu_set_t set;
const ck::index_t grid_n = math::integer_divide_ceil(GemmN, n_per_block); CPU_ZERO(&set);
const ck::index_t grid_size = grid_m * grid_n;
const ck::index_t grids_per_thread = CPU_SET(tid, &set);
math::integer_divide_ceil(grid_size, total_threads);
if (sched_setaffinity(0, sizeof(set), &set) == -1) {
// This version does not consider K-panel reuse; it keeps the OpenMP parallelization simple. throw std::runtime_error("wrong! failed to set thread affinity");
#pragma omp parallel }
{ }
auto a_threadwise_copy = }
AThreadwiseCopy(a_grid_desc, #endif
ck::make_zero_multi_index<a_block_copy_dim>(),
GetABlockDescriptor(m_per_block, k_per_block), // TODO: openmp aware ordering
ck::make_zero_multi_index<a_block_copy_dim>(), //
AElementwiseOperation{}); if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 1, 2>>::value)
{
auto b_threadwise_copy = auto a_move_k_step = GetAIndex(0, k_per_block);
BThreadwiseCopy(b_grid_desc, auto b_move_k_step = GetBIndex(k_per_block, 0);
ck::make_zero_multi_index<b_block_copy_dim>(),
GetBBlockDescriptor(k_per_block, n_per_block), const ck::index_t grid_m = math::integer_divide_ceil(GemmM, m_per_block);
ck::make_zero_multi_index<b_block_copy_dim>(), const ck::index_t grid_n = math::integer_divide_ceil(GemmN, n_per_block);
BElementwiseOperation{}); const ck::index_t grid_size = grid_m * grid_n;
const ck::index_t grids_per_thread =
auto c_threadwise_copy = math::integer_divide_ceil(grid_size, total_threads);
CThreadwiseCopy(GetCBlockDescriptor(m_per_block, n_per_block, c_grid_desc),
ck::make_zero_multi_index<2>(), // This version does not consider K panel re-usage. simple for openmp
c_grid_desc, #pragma omp parallel
ck::make_zero_multi_index<2>(), {
CElementwiseOperation{}); auto a_threadwise_copy =
AThreadwiseCopy(a_grid_desc,
DeviceAlignedMemCPU a_block_mem(m_per_block * k_per_block * sizeof(FloatA), ck::make_zero_multi_index<a_block_copy_dim>(),
MemAlignmentByte); GetABlockDescriptor(m_per_block, k_per_block, a_grid_desc),
DeviceAlignedMemCPU b_block_mem(k_per_block * n_per_block * sizeof(FloatB), ck::make_zero_multi_index<a_block_copy_dim>(),
MemAlignmentByte); AElementwiseOperation{});
DeviceAlignedMemCPU c_block_mem(
UseCLocalBuffer ? (m_per_block * n_per_block * sizeof(FloatC)) : 0, auto b_threadwise_copy =
MemAlignmentByte); BThreadwiseCopy(b_grid_desc,
ck::make_zero_multi_index<b_block_copy_dim>(),
auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>( GetBBlockDescriptor(k_per_block, n_per_block, b_grid_desc),
reinterpret_cast<FloatA*>(a_block_mem.mpDeviceBuf), ck::make_zero_multi_index<b_block_copy_dim>(),
a_block_mem.mMemSize / sizeof(FloatA)); BElementwiseOperation{});
auto b_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>( auto c_threadwise_copy =
reinterpret_cast<FloatB*>(b_block_mem.mpDeviceBuf), CThreadwiseCopy(GetCBlockDescriptor(m_per_block, n_per_block, c_grid_desc),
b_block_mem.mMemSize / sizeof(FloatB)); ck::make_zero_multi_index<2>(),
c_grid_desc,
auto c_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>( ck::make_zero_multi_index<2>(),
UseCLocalBuffer ? reinterpret_cast<FloatC*>(c_block_mem.mpDeviceBuf) CElementwiseOperation{});
: reinterpret_cast<FloatC*>(p_c_grid),
UseCLocalBuffer ? c_block_mem.mMemSize / sizeof(FloatC) DeviceAlignedMemCPU a_block_mem(
: c_grid_desc.GetElementSpaceSize()); UseALocalBuffer ? m_per_block * k_per_block * sizeof(FloatA) : 0,
MemAlignmentByte);
const ck::index_t tid = omp_get_thread_num(); DeviceAlignedMemCPU b_block_mem(
UseBLocalBuffer ? k_per_block * n_per_block * sizeof(FloatB) : 0,
for(ck::index_t i_gpt = 0; i_gpt < grids_per_thread; i_gpt++) MemAlignmentByte);
{ DeviceAlignedMemCPU c_block_mem(
ck::index_t gid = i_gpt * total_threads + tid; UseCLocalBuffer ? (m_per_block * n_per_block * sizeof(FloatC)) : 0,
if(gid >= grid_size) MemAlignmentByte);
break;
auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
ck::index_t i_mc = (gid / grid_n) * m_per_block; UseALocalBuffer ? reinterpret_cast<FloatA*>(a_block_mem.mpDeviceBuf)
ck::index_t i_nc = (gid % grid_n) * n_per_block; : const_cast<FloatA*>(p_a_grid),
UseALocalBuffer ? a_block_mem.mMemSize / sizeof(FloatA)
ck::index_t mc_size = ck::math::min(GemmM - i_mc, m_per_block); : a_grid_desc.GetElementSpaceSize());
ck::index_t nc_size =
ck::math::min(GemmN - i_nc, n_per_block); // TODO: nc need be 8x auto b_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
nc_size = math::integer_least_multiple( UseBLocalBuffer ? reinterpret_cast<FloatB*>(b_block_mem.mpDeviceBuf)
nc_size, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize); : const_cast<FloatB*>(p_b_grid),
UseBLocalBuffer ? b_block_mem.mMemSize / sizeof(FloatB)
a_threadwise_copy.SetSrcSliceOrigin(a_grid_desc, GetAIndex(i_mc, 0)); : b_grid_desc.GetElementSpaceSize());
b_threadwise_copy.SetSrcSliceOrigin(b_grid_desc, GetBIndex(0, i_nc));
auto c_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
auto c_block_desc = GetCBlockDescriptor(mc_size, nc_size, c_grid_desc); UseCLocalBuffer ? reinterpret_cast<FloatC*>(c_block_mem.mpDeviceBuf)
: reinterpret_cast<FloatC*>(p_c_grid),
c_threadwise_copy.SetSrc1SliceOrigin(c_block_desc, GetCIndex(i_mc, i_nc)); UseCLocalBuffer ? c_block_mem.mMemSize / sizeof(FloatC)
c_threadwise_copy.SetSrc2SliceOrigin(c_block_desc, GetCIndex(i_mc, i_nc)); : c_grid_desc.GetElementSpaceSize());
if constexpr(!UseCLocalBuffer)
{ const ck::index_t tid = omp_get_thread_num();
c_threadwise_copy.SetSrcSliceOrigin(c_block_desc, GetCIndex(i_mc, i_nc));
c_threadwise_copy.RunRead(c_grid_desc, for(ck::index_t i_gpt = 0; i_gpt < grids_per_thread; i_gpt++)
c_grid_buf, {
c0_grid_desc, ck::index_t gid = i_gpt * total_threads + tid;
c0_grid_buf, if(gid >= grid_size)
c1_grid_desc, break;
c1_grid_buf,
c_block_desc, ck::index_t i_mc = (gid / grid_n) * m_per_block;
c_block_buf, ck::index_t i_nc = (gid % grid_n) * n_per_block;
GetCSliceLength(mc_size, nc_size));
} ck::index_t mc_size = ck::math::min(GemmM - i_mc, m_per_block);
ck::index_t nc_size =
for(ck::index_t i_kc = 0; i_kc < GemmK; i_kc += k_per_block) ck::math::min(GemmN - i_nc, n_per_block); // TODO: nc need be 8x
{ nc_size = math::integer_least_multiple(
ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block); nc_size, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
auto a_block_desc = GetABlockDescriptor(mc_size, kc_size); a_threadwise_copy.SetSrcSliceOrigin(a_grid_desc, GetAIndex(i_mc, 0));
auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size); b_threadwise_copy.SetSrcSliceOrigin(b_grid_desc, GetBIndex(0, i_nc));
a_threadwise_copy.RunRead(a_grid_desc, auto c_block_desc = GetCBlockDescriptor(mc_size, nc_size, c_grid_desc);
a_grid_buf,
a_block_desc, c_threadwise_copy.SetSrc1SliceOrigin(c_block_desc, GetCIndex(i_mc, i_nc));
a_block_buf, c_threadwise_copy.SetSrc2SliceOrigin(c_block_desc, GetCIndex(i_mc, i_nc));
GetASliceLength(mc_size, kc_size)); if constexpr(!UseCLocalBuffer)
b_threadwise_copy.RunRead(b_grid_desc, {
b_grid_buf, c_threadwise_copy.SetSrcSliceOrigin(c_block_desc, GetCIndex(i_mc, i_nc));
b_block_desc, c_threadwise_copy.RunRead(c_grid_desc,
b_block_buf, c_grid_buf,
GetBSliceLength(kc_size, nc_size)); c0_grid_desc,
c0_grid_buf,
blockwise_gemm.Run(a_block_desc, c1_grid_desc,
a_block_buf, c1_grid_buf,
make_zero_multi_index<a_block_copy_dim>(), c_block_desc,
b_block_desc, c_block_buf,
b_block_buf, GetCSliceLength(mc_size, nc_size));
make_zero_multi_index<b_block_copy_dim>(), }
c_block_desc,
c_block_buf, for(ck::index_t i_kc = 0; i_kc < GemmK; i_kc += k_per_block)
make_zero_multi_index<2>(), {
i_kc != 0); ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block);
if((i_kc + k_per_block) < GemmK) auto a_block_desc = GetABlockDescriptor(mc_size, kc_size, a_grid_desc);
{ auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size, b_grid_desc);
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc, a_move_k_step);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc, b_move_k_step); a_threadwise_copy.RunRead(a_grid_desc,
} a_grid_buf,
} a_block_desc,
a_block_buf,
c_threadwise_copy.SetDstSliceOrigin(c_grid_desc, GetCIndex(i_mc, i_nc)); GetASliceLength(mc_size, kc_size));
c_threadwise_copy.RunWrite(c_block_desc, b_threadwise_copy.RunRead(b_grid_desc,
c_block_buf, b_grid_buf,
c0_grid_desc, b_block_desc,
c0_grid_buf, b_block_buf,
c1_grid_desc, GetBSliceLength(kc_size, nc_size));
c1_grid_buf,
c_grid_desc, blockwise_gemm.Run(a_block_desc,
c_grid_buf, a_block_buf,
GetCSliceLength(mc_size, nc_size)); make_zero_multi_index<a_block_copy_dim>(),
} GetASliceLength(mc_size, kc_size),
}
} b_block_desc,
else if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value) b_block_buf,
{ make_zero_multi_index<b_block_copy_dim>(),
auto a_move_k_step = GetAIndex(0, k_per_block); GetBSliceLength(kc_size, nc_size),
auto b_move_k_step = GetBIndex(0, n_per_block);
c_block_desc,
const ck::index_t grid_m = math::integer_divide_ceil(GemmM, m_per_block); c_block_buf,
const ck::index_t grid_m_per_thread = math::integer_divide_ceil(grid_m, total_threads); make_zero_multi_index<2>(),
GetCSliceLength(mc_size, nc_size),
// only parallel in gemm m dim
#pragma omp parallel i_kc != 0);
{
auto a_threadwise_copy = if((i_kc + k_per_block) < GemmK)
AThreadwiseCopy(a_grid_desc, {
ck::make_zero_multi_index<a_block_copy_dim>(), a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc, a_move_k_step);
GetABlockDescriptor(m_per_block, k_per_block), b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc, b_move_k_step);
ck::make_zero_multi_index<a_block_copy_dim>(), }
AElementwiseOperation{}); }
auto b_threadwise_copy = c_threadwise_copy.SetDstSliceOrigin(c_grid_desc, GetCIndex(i_mc, i_nc));
BThreadwiseCopy(b_grid_desc, c_threadwise_copy.RunWrite(c_block_desc,
ck::make_zero_multi_index<b_block_copy_dim>(), c_block_buf,
GetBBlockDescriptor(k_per_block, n_per_block), c0_grid_desc,
ck::make_zero_multi_index<b_block_copy_dim>(), c0_grid_buf,
BElementwiseOperation{}); c1_grid_desc,
c1_grid_buf,
auto c_threadwise_copy = c_grid_desc,
CThreadwiseCopy(GetCBlockDescriptor(m_per_block, n_per_block, c_grid_desc), c_grid_buf,
ck::make_zero_multi_index<2>(), GetCSliceLength(mc_size, nc_size));
c_grid_desc, }
ck::make_zero_multi_index<2>(), }
CElementwiseOperation{}); }
else if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value)
DeviceAlignedMemCPU a_block_mem(m_per_block * k_per_block * sizeof(FloatA), {
MemAlignmentByte); auto a_move_k_step = GetAIndex(0, k_per_block);
DeviceAlignedMemCPU b_block_mem(k_per_block * n_per_block * sizeof(FloatB), auto b_move_k_step = GetBIndex(0, n_per_block);
MemAlignmentByte);
DeviceAlignedMemCPU c_block_mem( const ck::index_t grid_m = math::integer_divide_ceil(GemmM, m_per_block);
UseCLocalBuffer ? (m_per_block * n_per_block * sizeof(FloatC)) : 0, const ck::index_t grid_m_per_thread = math::integer_divide_ceil(grid_m, total_threads);
MemAlignmentByte);
// only parallel in gemm m dim
auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>( #pragma omp parallel
reinterpret_cast<FloatA*>(a_block_mem.mpDeviceBuf), {
a_block_mem.mMemSize / sizeof(FloatA)); auto a_threadwise_copy =
AThreadwiseCopy(a_grid_desc,
auto b_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>( ck::make_zero_multi_index<a_block_copy_dim>(),
reinterpret_cast<FloatB*>(b_block_mem.mpDeviceBuf), GetABlockDescriptor(m_per_block, k_per_block, a_grid_desc),
b_block_mem.mMemSize / sizeof(FloatB)); ck::make_zero_multi_index<a_block_copy_dim>(),
AElementwiseOperation{});
auto c_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
UseCLocalBuffer ? reinterpret_cast<FloatC*>(c_block_mem.mpDeviceBuf) auto b_threadwise_copy =
: reinterpret_cast<FloatC*>(p_c_grid), BThreadwiseCopy(b_grid_desc,
UseCLocalBuffer ? c_block_mem.mMemSize / sizeof(FloatC) ck::make_zero_multi_index<b_block_copy_dim>(),
: c_grid_desc.GetElementSpaceSize()); GetBBlockDescriptor(k_per_block, n_per_block, b_grid_desc),
ck::make_zero_multi_index<b_block_copy_dim>(),
const ck::index_t tid = omp_get_thread_num(); BElementwiseOperation{});
for(ck::index_t i_gmpt = 0; i_gmpt < grid_m_per_thread; i_gmpt++) auto c_threadwise_copy =
{ CThreadwiseCopy(GetCBlockDescriptor(m_per_block, n_per_block, c_grid_desc),
ck::index_t i_mc = (i_gmpt * total_threads + tid) * m_per_block; ck::make_zero_multi_index<2>(),
if(i_mc >= GemmM) c_grid_desc,
break; ck::make_zero_multi_index<2>(),
ck::index_t mc_size = ck::math::min(GemmM - i_mc, m_per_block); CElementwiseOperation{});
a_threadwise_copy.SetSrcSliceOrigin(a_grid_desc, GetAIndex(i_mc, 0));
for(ck::index_t i_kc = 0; i_kc < GemmK; i_kc += k_per_block) DeviceAlignedMemCPU a_block_mem(
{ UseALocalBuffer ? m_per_block * k_per_block * sizeof(FloatA) : 0,
ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block); MemAlignmentByte);
DeviceAlignedMemCPU b_block_mem(
auto a_block_desc = GetABlockDescriptor(mc_size, kc_size); UseBLocalBuffer ? k_per_block * n_per_block * sizeof(FloatB) : 0,
a_threadwise_copy.RunRead(a_grid_desc, MemAlignmentByte);
a_grid_buf, DeviceAlignedMemCPU c_block_mem(
a_block_desc, UseCLocalBuffer ? (m_per_block * n_per_block * sizeof(FloatC)) : 0,
a_block_buf, MemAlignmentByte);
GetASliceLength(mc_size, kc_size));
auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
b_threadwise_copy.SetSrcSliceOrigin(b_grid_desc, GetBIndex(i_kc, 0)); UseALocalBuffer ? reinterpret_cast<FloatA*>(a_block_mem.mpDeviceBuf)
: const_cast<FloatA*>(p_a_grid),
// TODO: if use local C buffer, then this nc loop need to loop only once UseALocalBuffer ? a_block_mem.mMemSize / sizeof(FloatA)
for(ck::index_t i_nc = 0; i_nc < GemmN; i_nc += n_per_block) : a_grid_desc.GetElementSpaceSize());
{
ck::index_t nc_size = auto b_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
ck::math::min(GemmN - i_nc, n_per_block); // TODO: nc need be 8x UseBLocalBuffer ? reinterpret_cast<FloatB*>(b_block_mem.mpDeviceBuf)
nc_size = math::integer_least_multiple( : const_cast<FloatB*>(p_b_grid),
nc_size, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize); UseBLocalBuffer ? b_block_mem.mMemSize / sizeof(FloatB)
auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size); : b_grid_desc.GetElementSpaceSize());
b_threadwise_copy.RunRead(b_grid_desc, auto c_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
b_grid_buf, UseCLocalBuffer ? reinterpret_cast<FloatC*>(c_block_mem.mpDeviceBuf)
b_block_desc, : reinterpret_cast<FloatC*>(p_c_grid),
b_block_buf, UseCLocalBuffer ? c_block_mem.mMemSize / sizeof(FloatC)
GetBSliceLength(kc_size, nc_size)); : c_grid_desc.GetElementSpaceSize());
auto c_block_desc = GetCBlockDescriptor(mc_size, nc_size, c_grid_desc); const ck::index_t tid = omp_get_thread_num();
c_threadwise_copy.SetSrc1SliceOrigin(c_block_desc, for(ck::index_t i_gmpt = 0; i_gmpt < grid_m_per_thread; i_gmpt++)
GetCIndex(i_mc, i_nc)); {
c_threadwise_copy.SetSrc2SliceOrigin(c_block_desc, ck::index_t i_mc = (i_gmpt * total_threads + tid) * m_per_block;
GetCIndex(i_mc, i_nc)); if(i_mc >= GemmM)
break;
if constexpr(!UseCLocalBuffer) ck::index_t mc_size = ck::math::min(GemmM - i_mc, m_per_block);
{ a_threadwise_copy.SetSrcSliceOrigin(a_grid_desc, GetAIndex(i_mc, 0));
c_threadwise_copy.SetSrcSliceOrigin(c_block_desc, for(ck::index_t i_kc = 0; i_kc < GemmK; i_kc += k_per_block)
GetCIndex(i_mc, i_nc)); {
ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block);
c_threadwise_copy.RunRead(c_grid_desc,
c_grid_buf, auto a_block_desc = GetABlockDescriptor(mc_size, kc_size, a_grid_desc);
c0_grid_desc, a_threadwise_copy.RunRead(a_grid_desc,
c0_grid_buf, a_grid_buf,
c1_grid_desc, a_block_desc,
c1_grid_buf, a_block_buf,
c_block_desc, GetASliceLength(mc_size, kc_size));
c_block_buf,
GetCSliceLength(mc_size, nc_size)); b_threadwise_copy.SetSrcSliceOrigin(b_grid_desc, GetBIndex(i_kc, 0));
}
// TODO: if using a local C buffer, this nc loop only needs to run once
blockwise_gemm.Run(a_block_desc, for(ck::index_t i_nc = 0; i_nc < GemmN; i_nc += n_per_block)
a_block_buf, {
make_zero_multi_index<a_block_copy_dim>(), ck::index_t nc_size =
b_block_desc, ck::math::min(GemmN - i_nc, n_per_block); // TODO: nc need be 8x
b_block_buf, nc_size = math::integer_least_multiple(
make_zero_multi_index<b_block_copy_dim>(), nc_size, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
c_block_desc, auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size, b_grid_desc);
c_block_buf,
make_zero_multi_index<2>(), b_threadwise_copy.RunRead(b_grid_desc,
i_kc != 0); b_grid_buf,
b_block_desc,
if((i_nc + n_per_block) < GemmN) b_block_buf,
{ GetBSliceLength(kc_size, nc_size));
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc, b_move_k_step);
} auto c_block_desc = GetCBlockDescriptor(mc_size, nc_size, c_grid_desc);
if constexpr(UseCLocalBuffer) c_threadwise_copy.SetSrc1SliceOrigin(c_block_desc,
{ GetCIndex(i_mc, i_nc));
c_threadwise_copy.SetDstSliceOrigin(c_grid_desc, c_threadwise_copy.SetSrc2SliceOrigin(c_block_desc,
GetCIndex(i_mc, i_nc)); GetCIndex(i_mc, i_nc));
c_threadwise_copy.RunWrite(c_block_desc, if constexpr(!UseCLocalBuffer)
c_block_buf, {
c0_grid_desc, c_threadwise_copy.SetSrcSliceOrigin(c_block_desc,
c0_grid_buf, GetCIndex(i_mc, i_nc));
c1_grid_desc,
c1_grid_buf, c_threadwise_copy.RunRead(c_grid_desc,
c_grid_desc, c_grid_buf,
c_grid_buf, c0_grid_desc,
GetCSliceLength(mc_size, nc_size)); c0_grid_buf,
} c1_grid_desc,
else c1_grid_buf,
{ c_block_desc,
// only write for last K, since the RunWrite here is just doing c_block_buf,
// elementwise op from global to global GetCSliceLength(mc_size, nc_size));
if((i_kc + k_per_block) >= GemmK) }
{
c_threadwise_copy.SetDstSliceOrigin(c_grid_desc, blockwise_gemm.Run(a_block_desc,
GetCIndex(i_mc, i_nc)); a_block_buf,
make_zero_multi_index<a_block_copy_dim>(),
c_threadwise_copy.RunWrite(c_block_desc, GetASliceLength(mc_size, kc_size),
c_block_buf,
c0_grid_desc, b_block_desc,
c0_grid_buf, b_block_buf,
c1_grid_desc, make_zero_multi_index<b_block_copy_dim>(),
c1_grid_buf, GetBSliceLength(kc_size, nc_size),
c_grid_desc,
c_grid_buf, c_block_desc,
GetCSliceLength(mc_size, nc_size)); c_block_buf,
} make_zero_multi_index<2>(),
} GetCSliceLength(mc_size, nc_size),
}
i_kc != 0);
if((i_kc + k_per_block) < GemmK)
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc, a_move_k_step); if((i_nc + n_per_block) < GemmN)
} {
} b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc, b_move_k_step);
} }
}
} if constexpr(UseCLocalBuffer)
}; {
c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
} // namespace cpu GetCIndex(i_mc, i_nc));
} // namespace ck
c_threadwise_copy.RunWrite(c_block_desc,
#endif c_block_buf,
c0_grid_desc,
c0_grid_buf,
c1_grid_desc,
c1_grid_buf,
c_grid_desc,
c_grid_buf,
GetCSliceLength(mc_size, nc_size));
}
else
{
// only write on the last K iteration, since the RunWrite here just does an
// elementwise op from global to global
if((i_kc + k_per_block) >= GemmK)
{
c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
GetCIndex(i_mc, i_nc));
c_threadwise_copy.RunWrite(c_block_desc,
c_block_buf,
c0_grid_desc,
c0_grid_buf,
c1_grid_desc,
c1_grid_buf,
c_grid_desc,
c_grid_buf,
GetCSliceLength(mc_size, nc_size));
}
}
}
if((i_kc + k_per_block) < GemmK)
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc, a_move_k_step);
}
}
}
}
}
};
} // namespace cpu
} // namespace ck
#endif
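Note: the loop body above selects, per tensor, between a packed local block buffer and a zero-copy view of the grid memory (UseALocalBuffer / UseBLocalBuffer / UseCLocalBuffer). Below is a minimal standalone sketch of that selection idea, not the CK API; names such as BlockBuffer and Prepare are illustrative only.

#include <cstddef>
#include <cstdio>
#include <vector>

template <bool UseLocalBuffer>
struct BlockBuffer
{
    float* p_data_ = nullptr;    // points at either owned storage or the grid
    std::vector<float> storage_; // only used when UseLocalBuffer is true

    void Prepare(const float* p_grid, std::size_t grid_offset, std::size_t block_elems)
    {
        if constexpr(UseLocalBuffer)
        {
            // pack the block into private storage (what the copy path of RunRead does)
            storage_.assign(p_grid + grid_offset, p_grid + grid_offset + block_elems);
            p_data_ = storage_.data();
        }
        else
        {
            // bypass: alias the grid memory directly, no copy at all
            p_data_ = const_cast<float*>(p_grid) + grid_offset;
        }
    }
};

int main()
{
    std::vector<float> grid(64);
    for(std::size_t i = 0; i < grid.size(); ++i)
        grid[i] = static_cast<float>(i);

    BlockBuffer<true> local;   // packed copy, as with a_block_mem / b_block_mem
    BlockBuffer<false> bypass; // zero-copy view, as when the local-buffer flag is false
    local.Prepare(grid.data(), 16, 8);
    bypass.Prepare(grid.data(), 16, 8);

    std::printf("local[0]=%f bypass[0]=%f\n", local.p_data_[0], bypass.p_data_[0]);
    return 0;
}

The bypass path only makes sense when the grid layout already matches what the micro-kernel expects for that operand, which is presumably why the instance macros below enable it only for selected specializations.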
...@@ -519,7 +519,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC ...@@ -519,7 +519,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
template <typename SrcBuffer, typename DstBuffer, typename SliceLengths> template <typename SrcBuffer, typename DstBuffer, typename SliceLengths>
void RunRead(const SrcDesc& src_desc, void RunRead(const SrcDesc& src_desc,
const SrcBuffer& src_buf, SrcBuffer& src_buf,
const DstDesc& dst_desc, const DstDesc& dst_desc,
DstBuffer& dst_buf, DstBuffer& dst_buf,
const SliceLengths& slice_length) const SliceLengths& slice_length)
...@@ -917,14 +917,15 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXC ...@@ -917,14 +917,15 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXC
template <typename SrcBuffer, typename DstBuffer, typename SliceLengths> template <typename SrcBuffer, typename DstBuffer, typename SliceLengths>
void RunRead(const SrcDesc&, void RunRead(const SrcDesc&,
const SrcBuffer& src_buf, SrcBuffer& src_buf,
const DstDesc& dst_desc, const DstDesc& dst_desc,
DstBuffer& dst_buf, DstBuffer& dst_buf,
const SliceLengths& slice_length) const SliceLengths& slice_length)
{ {
if constexpr(BypassTransfer) if constexpr(BypassTransfer)
{ {
// TODO: weight NHWC not support this // KYXC weight should not support this
dst_buf.p_data_ = reinterpret_cast<float*>(src_buf.p_data_) + src_offset;
} }
else else
{ {
...@@ -1132,12 +1133,15 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8 ...@@ -1132,12 +1133,15 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8
template <typename SrcBuffer, typename DstBuffer, typename SliceLengths> template <typename SrcBuffer, typename DstBuffer, typename SliceLengths>
void RunRead(const SrcDesc&, void RunRead(const SrcDesc&,
const SrcBuffer& src_buf, SrcBuffer& src_buf,
const DstDesc& dst_desc, const DstDesc& dst_desc,
DstBuffer& dst_buf, DstBuffer& dst_buf,
const SliceLengths& slice_length) const SliceLengths& slice_length)
{ {
if constexpr(BypassTransfer) {} if constexpr(BypassTransfer)
{
dst_buf.p_data_ = reinterpret_cast<float*>(src_buf.p_data_) + src_offset;
}
else else
{ {
const ck::index_t n0_per_block = slice_length[Number<0>{}]; const ck::index_t n0_per_block = slice_length[Number<0>{}];
......
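A toy model of the BypassTransfer branch added above (ToyBuffer and ToyTransfer are illustrative, not CK types): RunRead simply re-points the destination buffer into the source at src_offset, which is presumably also why src_buf is no longer taken by const reference.

#include <cassert>
#include <vector>

struct ToyBuffer
{
    float* p_data_ = nullptr;
};

struct ToyTransfer
{
    bool bypass_transfer = true;
    long src_offset      = 0; // element offset of the current slice window

    void RunRead(ToyBuffer& src_buf, ToyBuffer& dst_buf, long n_elems) const
    {
        if(bypass_transfer)
        {
            // zero-copy: alias the source slice (mirrors dst_buf.p_data_ = src + src_offset)
            dst_buf.p_data_ = src_buf.p_data_ + src_offset;
            (void)n_elems;
        }
        else
        {
            // copy path: element-wise transfer into pre-allocated destination storage
            for(long i = 0; i < n_elems; ++i)
                dst_buf.p_data_[i] = src_buf.p_data_[src_offset + i];
        }
    }
};

int main()
{
    std::vector<float> grid(32, 1.0f);
    ToyBuffer src{grid.data()};
    ToyBuffer dst{};

    ToyTransfer t;
    t.src_offset = 8;
    t.RunRead(src, dst, 8);
    assert(dst.p_data_ == grid.data() + 8); // dst is a view, nothing was copied
    return 0;
}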
...@@ -47,121 +47,138 @@ static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver ...@@ -47,121 +47,138 @@ static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN; static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
// clang-format off // clang-format off
#define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf) \ #define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf) \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \ \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf> DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, true, c_local_buf>
// clang-format on // clang-format on
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances = std::tuple< using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances = std::tuple<
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true, true, false)>; DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, false)>;
// clang-format on // clang-format on
// use this in a single thread when gemm_n is not a multiple of 8 // use this in a single thread when gemm_n is not a multiple of 8
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_instances = std::tuple< using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_instances = std::tuple<
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true, true, true)>; DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true)>;
// clang-format on // clang-format on
// use this in a multi-thread environment (a local C buffer is needed to avoid cache-coherence issues, although // use this in a multi-thread environment (a local C buffer is needed to avoid cache-coherence issues, although
// sometimes no local C buffer is better...) // sometimes no local C buffer is better...)
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_instances = std::tuple< using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_instances = std::tuple<
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 48, 24, 128, 4, 24, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 24, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 72, 16, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 32, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 72, 32, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 40, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 96, 32, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 96, 64, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 48, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 120, 32, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 48, 48, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 120, 64, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 56, 24, 256, 4, 24, false),
// DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 72, 16, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 72, 16, 256, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 72, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 72, 32, 256, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 96, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 96, 64, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 120, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 120, 64, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true, true, true)>; DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true)>;
// clang-format on // clang-format on
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_relu_instances = std::tuple< using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_relu_instances = std::tuple<
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, false)>; DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, false)>;
// clang-format on // clang-format on
// use this in a single thread when gemm_n is not a multiple of 8 // use this in a single thread when gemm_n is not a multiple of 8
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_relu_instances = std::tuple< using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_relu_instances = std::tuple<
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, true)>; DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)>;
// clang-format on // clang-format on
// use this in a multi-thread environment (a local C buffer is needed to avoid cache-coherence issues, although // use this in a multi-thread environment (a local C buffer is needed to avoid cache-coherence issues, although
// sometimes no local C buffer is better...) // sometimes no local C buffer is better...)
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_relu_instances = std::tuple< using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_relu_instances = std::tuple<
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 48, 24, 128, 4, 24, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 24, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 72, 16, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 32, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 72, 32, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 40, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 96, 32, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 96, 64, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 48, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 120, 32, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 48, 48, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 120, 64, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 56, 24, 256, 4, 24, false),
// DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 72, 16, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 72, 16, 256, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 72, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 72, 32, 256, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 96, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 96, 64, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 120, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 120, 64, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, true)>; DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)>;
// clang-format on // clang-format on
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances) void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
......
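The instance lists above (plain, local-C, and mt, each with a Relu variant) differ mainly in whether a local C buffer is used and in block sizes. A hedged sketch of how a caller might choose among them, based only on the comments in this file (pick_conv2d_instances and InstanceKind are hypothetical helpers, not part of CK; compile with -fopenmp):

#include <cstdio>
#include <omp.h>

enum class InstanceKind { Default, LocalC, MultiThread };

// plain list:      single thread and gemm_n a multiple of 8
// *_local_c list:  single thread, gemm_n not a multiple of 8
// *_mt list:       multiple threads (local C buffer avoids cache-coherence issues)
InstanceKind pick_conv2d_instances(int gemm_n)
{
    const int threads = omp_get_max_threads();
    if(threads > 1)
        return InstanceKind::MultiThread;
    if(gemm_n % 8 != 0)
        return InstanceKind::LocalC;
    return InstanceKind::Default;
}

int main()
{
    std::printf("selected instance list: %d\n",
                static_cast<int>(pick_conv2d_instances(240)));
    return 0;
}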
...@@ -40,121 +40,146 @@ static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver ...@@ -40,121 +40,146 @@ static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN; static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
// clang-format off // clang-format off
#define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf) \ #define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf) \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, false, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \ \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf> DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, false, c_local_buf>
// clang-format on // clang-format on
using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_instances = std::tuple< using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_instances = std::tuple<
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true, true, false)>; DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, false)>;
// clang-format on // clang-format on
// use this in a single thread when gemm_n is not a multiple of 8 // use this in a single thread when gemm_n is not a multiple of 8
using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_instances = std::tuple< using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_instances = std::tuple<
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true, true, true)>; DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true)>;
// clang-format on // clang-format on
// use this in a multi-thread environment (a local C buffer is needed to avoid cache-coherence issues, although // use this in a multi-thread environment (a local C buffer is needed to avoid cache-coherence issues, although
// sometimes no local C buffer is better...) // sometimes no local C buffer is better...)
using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_instances = std::tuple< using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_instances = std::tuple<
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 48, 24, 128, 4, 24, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 24, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 72, 16, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 32, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 72, 32, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 40, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 96, 32, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 96, 64, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 48, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 120, 32, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 48, 48, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 120, 64, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 56, 24, 256, 4, 24, false),
// DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 72, 16, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 72, 16, 256, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 72, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 72, 32, 256, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 96, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 96, 64, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 120, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 120, 64, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true, true, true)>; // DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true)>;
// clang-format on // clang-format on
using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_relu_instances = std::tuple< using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_relu_instances = std::tuple<
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, false)>; DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, false)>;
// clang-format on // clang-format on
// use this in a single thread when gemm_n is not a multiple of 8 // use this in a single thread when gemm_n is not a multiple of 8
using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_relu_instances = std::tuple< using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_relu_instances = std::tuple<
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, true)>; DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)>;
// clang-format on // clang-format on
// use this in a multi-thread environment (a local C buffer is needed to avoid cache-coherence issues, although // use this in a multi-thread environment (a local C buffer is needed to avoid cache-coherence issues, although
// sometimes no local C buffer is better...) // sometimes no local C buffer is better...)
using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_relu_instances = std::tuple< using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_relu_instances = std::tuple<
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 48, 24, 128, 4, 24, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 24, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 72, 16, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 32, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 72, 32, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 40, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 96, 32, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 96, 64, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 24, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 120, 32, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 32, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 120, 64, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 40, 24, 256, 4, 24, false),
// DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 48, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 48, 48, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 56, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 72, 16, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 72, 16, 256, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 72, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 72, 32, 256, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true, true, true), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 96, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, true)>; DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 96, 64, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 120, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 120, 64, 128, 6, 16, false),
// DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)>;
// clang-format on // clang-format on
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk( void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk(
......
#include <stdlib.h> #include <stdlib.h>
#include "convolution_forward_specialization_cpu.hpp" #include "convolution_forward_specialization_cpu.hpp"
#include "config.hpp" #include "config.hpp"
#include "device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk.hpp" #include "device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation_cpu.hpp" #include "element_wise_operation_cpu.hpp"
#include "device_operation_instance.hpp" #include "device_operation_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace cpu { namespace cpu {
namespace device { namespace device {
namespace device_conv2d_fwd_bias_activation_add_avx2_instance { namespace device_conv2d_fwd_bias_activation_add_avx2_instance {
using InType = float; using InType = float;
using WeiType = float; using WeiType = float;
using OutType = float; using OutType = float;
using AccType = float; using AccType = float;
using InLayout = ck::tensor_layout::gemm::RowMajor; // NHWC using InLayout = ck::tensor_layout::gemm::RowMajor; // NHWC
using WeiLayout = ck::tensor_layout::gemm::ColumnMajor; // KYXC using WeiLayout = ck::tensor_layout::gemm::ColumnMajor; // KYXC
static constexpr bool NonTemporalStore = false; static constexpr bool NonTemporalStore = false;
using PT = ck::tensor_operation::cpu::element_wise::PassThrough; using PT = ck::tensor_operation::cpu::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::cpu::element_wise::AddReluAdd; using AddReluAdd = ck::tensor_operation::cpu::element_wise::AddReluAdd;
static constexpr auto ConvFwdDefault = static constexpr auto ConvFwdDefault =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Default; ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Default;
static constexpr auto ConvFwd1x1P0 = static constexpr auto ConvFwd1x1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0; ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 = static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0; ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;
static constexpr auto DefaultGemmKLoop = static constexpr auto DefaultGemmKLoop =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop; ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop;
static constexpr auto GemmKLoopOverC = static constexpr auto GemmKLoopOverC =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC; ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC;
static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK; static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK;
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN; static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
// clang-format off // clang-format off
-#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m) \
-    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
-    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
-    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
-    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
-    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
-    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
-    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
-    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>
+#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
+    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
+    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
+    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
+    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, true , c_local_buf, bias_along_m>, \
+    \
+    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
+    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
+    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
+    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, true , c_local_buf, bias_along_m>
// clang-format on
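Each DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32 invocation expands to a comma-separated list of device-op instantiations, one per hand-picked (conv specialization, gemm-K loop scheme, loop order, A/B local buffer) combination, so a single macro call can be dropped straight into the std::tuple instance lists below; with this change the A/B local-buffer choice is hard-coded per instantiation (the new "true"/"false" template arguments) instead of being macro parameters. The following is a minimal, self-contained sketch of that pattern with made-up names (GemmInstance and MAKE_GEMM_INSTANCES are illustrative stand-ins, not the real DeviceConvNDFwdBiasActivationAddAvx2_* templates):

#include <cstdio>
#include <tuple>

// Stand-in for the real device-op template: block sizes plus local-buffer flags.
template <int MPerBlock, int NPerBlock, bool UseALocal, bool UseBLocal, bool UseCLocal>
struct GemmInstance
{
    static void Describe()
    {
        std::printf("%4dx%-4d block, local A/B/C buffers: %d/%d/%d\n",
                    MPerBlock, NPerBlock, UseALocal, UseBLocal, UseCLocal);
    }
};

// One macro invocation contributes several tuning candidates to a std::tuple; the
// A/B local-buffer flags are fixed per line, mirroring the hard-coded arguments in
// the real macro after this change.
#define MAKE_GEMM_INSTANCES(m_per_block, n_per_block, c_local_buf)       \
    GemmInstance<m_per_block, n_per_block, true, true, c_local_buf>,     \
        GemmInstance<m_per_block, n_per_block, false, true, c_local_buf>

using gemm_instances = std::tuple<MAKE_GEMM_INSTANCES(256, 128, false),
                                  MAKE_GEMM_INSTANCES(48, 24, true)>;

int main()
{
    // Walk the tuple of (empty) instance types and print each configuration.
    std::apply([](auto... inst) { (decltype(inst)::Describe(), ...); }, gemm_instances{});
    return 0;
}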
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_instances = std::tuple<
// clang-format off
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, true, true, false, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, true, false, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, true, false, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, true, false, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, true, false, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, true, false, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, true, false, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, true, false, false)>;
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, false, false)>;
// clang-format on
// use this in single thread, but gemm_n is not multiple of 8
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_local_c_instances = std::tuple<
// clang-format off
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, true, true, true, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, true, true, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, true, true, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, true, true, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, true, true, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, true, true, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, true, true, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, true, true, false)>;
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, true, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)>;
// clang-format on
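The _local_c_instances list above keeps the same block sizes but turns the C local buffer on; the comment recommends it for single-threaded runs where gemm_n is not a multiple of 8 (the AVX2 f32 vector width). A rough, self-contained illustration of that idea, accumulating into a padded scratch tile so the inner kernel can always issue full-width stores and copying only the valid columns back (store_tile_via_local_c and all sizes here are made up for illustration, not CK kernel code):

#include <cstring>
#include <vector>

// Copy an m x n tile out of a per-tile scratch buffer whose row stride local_ldn
// is rounded up to a multiple of 8, into the real strided C tensor.
void store_tile_via_local_c(float* c, int ldc, int m, int n, const float* local_c, int local_ldn)
{
    for(int i = 0; i < m; ++i)
        std::memcpy(c + i * ldc, local_c + i * local_ldn, n * sizeof(float));
}

int main()
{
    const int m = 4, n = 6;                   // n is not a multiple of 8
    const int local_ldn = (n + 7) / 8 * 8;    // padded width: full 8-wide stores are safe
    std::vector<float> local_c(m * local_ldn, 1.0f);
    std::vector<float> c(m * n, 0.0f);
    store_tile_via_local_c(c.data(), n, m, n, local_c.data(), local_ldn);
    return 0;
}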
// use this in multi thread environment (need local C buffer to avoid cache coherence, although some
// time no local c is better...)
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_mt_instances = std::tuple<
// clang-format off
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 48, 24, 128, 4, 24, true, true, true, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 128, 6, 16, true, true, true, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 128, 6, 16, true, true, true, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 96, 32, 128, 6, 16, true, true, true, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 96, 64, 128, 6, 16, true, true, true, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 120, 32, 128, 6, 16, true, true, true, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 120, 64, 128, 6, 16, true, true, true, false),
-    // DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, true, true, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, true, true, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, true, true, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, true, true, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, true, true, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, true, true, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, true, true, false)>;
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 24, 24, 256, 4, 24, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 32, 24, 256, 4, 24, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 40, 24, 256, 4, 24, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 48, 24, 256, 4, 24, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 48, 48, 256, 4, 24, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 56, 24, 256, 4, 24, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 128, 6, 16, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 256, 6, 16, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 128, 6, 16, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 256, 6, 16, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 96, 32, 128, 6, 16, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 96, 64, 128, 6, 16, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 120, 32, 128, 6, 16, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 120, 64, 128, 6, 16, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)>;
// clang-format on
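The cache-coherence remark above is about several threads read-modify-writing neighbouring pieces of the shared C tensor, which makes the owning cache lines bounce between cores; accumulating into a per-thread local buffer and writing back once avoids most of that traffic. A small stand-alone sketch of that pattern (worker, the sizes, and the dummy K loop are illustrative, not CK code):

#include <thread>
#include <vector>

// Each thread accumulates its slice of C in a private buffer, then performs a
// single write-back into the shared output.
void worker(float* c, int offset, int len, int k_iters)
{
    std::vector<float> local_c(len, 0.0f); // per-thread accumulator
    for(int k = 0; k < k_iters; ++k)
        for(int i = 0; i < len; ++i)
            local_c[i] += 1.0f;            // stand-in for the K-loop FMAs
    for(int i = 0; i < len; ++i)           // one pass over shared memory at the end
        c[offset + i] += local_c[i];
}

int main()
{
    std::vector<float> c(128, 0.0f);
    std::thread t0(worker, c.data(), 0, 64, 100);
    std::thread t1(worker, c.data(), 64, 64, 100);
    t0.join();
    t1.join();
    return 0;
}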
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk(
    std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
    ck::tensor_operation::device::add_device_operation_instances(
        instances, device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_instances{});
}

void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_local_c(
    std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
    ck::tensor_operation::device::add_device_operation_instances(
        instances,
        device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_local_c_instances{});
}

void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_mt(
    std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
    ck::tensor_operation::device::add_device_operation_instances(
        instances, device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_mt_instances{});
}

} // namespace device_conv2d_fwd_bias_activation_add_avx2_instance
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
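These three functions are the factory entry points: each appends its preconfigured tuple of instances to a caller-owned vector. A hedged usage sketch follows; the include path is a guess, and the PT / AddReluAdd / DeviceConvFwdBiasActivationAddPtr aliases are assumed to be the same ones used in the signatures above and made visible by that header:

#include <cstdio>
#include <vector>

// Assumed header name; the real include path for these factory declarations may differ.
#include "device_conv2d_fwd_bias_activation_add_avx2_instance.hpp"

using namespace ck::tensor_operation::cpu::device;
using namespace ck::tensor_operation::cpu::device::device_conv2d_fwd_bias_activation_add_avx2_instance;

int main()
{
    std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>> instances;

    // Collect the generic list plus the multi-thread tuned list.
    add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk(instances);
    add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_mt(instances);

    std::printf("registered %zu conv2d fwd bias+activation+add instances\n", instances.size());
    // A real caller would next query each instance for argument support and launch the
    // best candidate; that part of the API is outside this diff.
    return 0;
}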
...@@ -40,69 +40,81 @@ static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver ...@@ -40,69 +40,81 @@ static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN; static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
// clang-format off // clang-format off
-#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m) \
-    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
-    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
-    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
-    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
-    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
-    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
-    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
-    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>
+#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
+    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
+    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
+    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
+    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>, \
+    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , false, c_local_buf, bias_along_m>, \
+    \
+    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
+    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
+    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
+    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>, \
+    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , false, c_local_buf, bias_along_m>
// clang-format on
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_instances = std::tuple< using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_instances = std::tuple<
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, true, true, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, true, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, true, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, true, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, true, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, true, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, true, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, true, false, false)>; DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, false, false)>;
// clang-format on // clang-format on
// use this in single thread, but gemm_n is not multiple of 8 // use this in single thread, but gemm_n is not multiple of 8
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_local_c_instances = using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_local_c_instances =
std::tuple< std::tuple<
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, true, true, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, true, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, true, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, true, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, true, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, true, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, true, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, true, true, false)>; DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)>;
// clang-format on // clang-format on
// use this in multi thread environment (need local C buffer to avoid cache coherence, although some // use this in multi thread environment (need local C buffer to avoid cache coherence, although some
// time no local c is better...) // time no local c is better...)
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_mt_instances = std::tuple< using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_mt_instances = std::tuple<
// clang-format off // clang-format off
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 48, 24, 128, 4, 24, true, true, true, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 128, 6, 16, true, true, true, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 128, 6, 16, true, true, true, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 96, 32, 128, 6, 16, true, true, true, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 96, 64, 128, 6, 16, true, true, true, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 120, 32, 128, 6, 16, true, true, true, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 120, 64, 128, 6, 16, true, true, true, false),
-    // DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, true, true, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, true, true, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, true, true, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, true, true, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, true, true, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, true, true, false),
-    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, true, true, false)>;
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 24, 24, 256, 4, 24, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 32, 24, 256, 4, 24, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 40, 24, 256, 4, 24, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 48, 24, 256, 4, 24, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 48, 48, 256, 4, 24, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 56, 24, 256, 4, 24, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 128, 6, 16, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 256, 6, 16, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 128, 6, 16, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 256, 6, 16, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 96, 32, 128, 6, 16, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 96, 64, 128, 6, 16, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 120, 32, 128, 6, 16, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 120, 64, 128, 6, 16, false, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false),
+    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)>;
// clang-format on // clang-format on
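One way to read the new multi-thread lists above: the m/n blocks shrink (24x24 up to about 120x64) while k_per_block grows to 256, presumably so a threaded launch has many more independent C tiles to distribute across cores. Illustrative tile-count arithmetic with assumed GEMM sizes (not taken from this diff):

#include <cstdio>
#include <utility>

int main()
{
    const int gemm_m = 1 * 28 * 28; // e.g. N*Ho*Wo for a small conv problem (assumed)
    const int gemm_n = 256;         // output channels K (assumed)

    // Compare a large single-thread-style block with a small multi-thread-style block.
    for(auto [mb, nb] : {std::pair{256, 128}, std::pair{48, 24}})
    {
        const int tiles = ((gemm_m + mb - 1) / mb) * ((gemm_n + nb - 1) / nb);
        std::printf("%4dx%-4d blocks -> %d parallel C tiles\n", mb, nb, tiles);
    }
    return 0;
}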
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk( void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk(
...