Commit 9cefc261 authored by carlushuang's avatar carlushuang
Browse files

refactor device instance to use less template, more dynamic tunable

parent 6dfb4e78
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#define TEST_LAYOUT_NHWC_KYXC_NHWK 0 #define TEST_LAYOUT_NHWC_KYXC_NHWK 0
#define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1 #define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1
#define TEST_LAYOUT_NHWC_YXCK_NHWK 2 #define TEST_LAYOUT_NHWC_YXCK_NHWK 2
#define TEST_LAYOUT TEST_LAYOUT_NHWC_YXCK_NHWK #define TEST_LAYOUT TEST_LAYOUT_NHWC_KYXCK8_NHWK
using F32 = float; using F32 = float;
using F16 = ck::half_t; using F16 = ck::half_t;
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#define TEST_LAYOUT_NHWC_KYXC_NHWK 0 #define TEST_LAYOUT_NHWC_KYXC_NHWK 0
#define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1 #define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1
#define TEST_LAYOUT_NHWC_YXCK_NHWK 2 #define TEST_LAYOUT_NHWC_YXCK_NHWK 2
#define TEST_LAYOUT TEST_LAYOUT_NHWC_YXCK_NHWK #define TEST_LAYOUT TEST_LAYOUT_NHWC_KYXCK8_NHWK
using F32 = float; using F32 = float;
using F16 = ck::half_t; using F16 = ck::half_t;
......
...@@ -18,8 +18,6 @@ template <typename FloatA, ...@@ -18,8 +18,6 @@ template <typename FloatA,
typename BBlockDesc, typename BBlockDesc,
typename CDesc, typename CDesc,
ck::index_t KPerBlock,
typename ThreadwiseGemm_Dispatch, typename ThreadwiseGemm_Dispatch,
typename ThreadMNAccessOrder // how we acces gemm MN to utilize micro kernel typename ThreadMNAccessOrder // how we acces gemm MN to utilize micro kernel
> >
...@@ -83,8 +81,11 @@ struct BlockwiseGemmAvx2_MxN ...@@ -83,8 +81,11 @@ struct BlockwiseGemmAvx2_MxN
else else
{ {
// N/8 * K * 8 // N/8 * K * 8
return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}] * // return b_block_desc.GetTransforms()[Number<BBlockDesc::GetNumOfTransform() -
b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}]; // 1>{}].GetUpperLengths()[Number<1>{}] *
// b_block_desc.GetTransforms()[Number<BBlockDesc::GetNumOfTransform() -
// 1>{}].GetUpperLengths()[Number<2>{}];
return b_block_desc.GetLength(Number<1>{}) * b_block_desc.GetLength(Number<2>{});
} }
} }
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <iostream> #include <iostream>
#include "device_base_cpu.hpp" #include "device_base_cpu.hpp"
#include "convolution_forward_specialization_cpu.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -77,6 +78,24 @@ using DeviceConvFwdBiasActivationAddPtr = ...@@ -77,6 +78,24 @@ using DeviceConvFwdBiasActivationAddPtr =
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation>>; OutElementwiseOperation>>;
struct DeviceConvFwdDynamicTunable
{
ck::index_t m_per_block;
ck::index_t n_per_block;
ck::index_t k_per_block;
// ck::index_t m_per_thread;
// ck::index_t n_per_thread;
// bool use_a_local_buffer;
// bool use_b_local_buffer;
// bool use_c_local_buffer;
// ConvolutionForwardSpecialization_t forward_spec;
// ConvolutionForwardGemmKSpecialization_t gemm_k_spec;
ConvolutionForwardBlockLoopOverSpecialization_t loop_over_spec;
};
} // namespace device } // namespace device
} // namespace cpu } // namespace cpu
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -30,11 +30,7 @@ template <typename InDataType, ...@@ -30,11 +30,7 @@ template <typename InDataType,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization, ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization, ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
ck::index_t NumDimSpatial, ck::index_t NumDimSpatial,
ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t MPerThread, ck::index_t MPerThread,
ck::index_t NPerThread, ck::index_t NPerThread,
bool UseALocalBuffer, bool UseALocalBuffer,
...@@ -65,17 +61,12 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -65,17 +61,12 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
static constexpr bool NonTemporalStore = false; static constexpr bool NonTemporalStore = false;
static constexpr auto GetBlockMNKAccessOrder() DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K(
const DeviceConvFwdDynamicTunable& dtune)
: gridwise_gemm(dtune)
{ {
if constexpr(BlockLoopOverSpecialization == DefaultBlockLoopOver ||
BlockLoopOverSpecialization == LoopOver_MNK)
return ck::Sequence<0, 1, 2>{};
else if constexpr(BlockLoopOverSpecialization == LoopOver_MKN)
return ck::Sequence<0, 2, 1>{};
} }
using BlockMNKAccessOrder = decltype(GetBlockMNKAccessOrder());
static constexpr auto GetThreadwiseGemm_Dispatch() static constexpr auto GetThreadwiseGemm_Dispatch()
{ {
if constexpr(MPerThread == 4 && NPerThread == 24) if constexpr(MPerThread == 4 && NPerThread == 24)
...@@ -106,45 +97,6 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -106,45 +97,6 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch()); using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());
static constexpr auto GetInputBlockDescriptor()
{
if constexpr(UseALocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
}
else
{
return AGridDesc{};
}
}
static constexpr auto GetWeightBlockDescriptor()
{
if constexpr(UseBLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(
math::integer_divide_ceil(NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
KPerBlock,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
}
else
{
return BGridDesc{};
}
}
static constexpr auto GetOutputBlockDescriptor()
{
if constexpr(UseCLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
}
else
{
return CGridDesc{};
}
}
static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n) static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
{ {
ck::index_t gemm_n_padded = ck::index_t gemm_n_padded =
...@@ -576,6 +528,48 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -576,6 +528,48 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
using BGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I1])>; using BGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I2])>; using CGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
static constexpr auto GetInputBlockDescriptor()
{
if constexpr(UseALocalBuffer)
{
// return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
}
else
{
return AGridDesc{};
}
}
static constexpr auto GetWeightBlockDescriptor()
{
if constexpr(UseBLocalBuffer)
{
// return make_naive_tensor_descriptor_packed(make_tuple(
// math::integer_divide_ceil(NPerBlock,
// ThreadwiseGemm_Dispatch::MatrixBMinVectorSize), KPerBlock,
// ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
return make_naive_tensor_descriptor_packed(make_tuple(0, 0, 0));
}
else
{
return BGridDesc{};
}
}
static constexpr auto GetOutputBlockDescriptor()
{
if constexpr(UseCLocalBuffer)
{
// return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
}
else
{
return CGridDesc{};
}
}
// static constexpr bool UseCLocalBuffer = false; // static constexpr bool UseCLocalBuffer = false;
using AThreadwiseCopy = using AThreadwiseCopy =
...@@ -620,20 +614,18 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -620,20 +614,18 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
AElementwiseOperation, // AElementwiseOperation, AElementwiseOperation, // AElementwiseOperation,
BElementwiseOperation, // BElementwiseOperation, BElementwiseOperation, // BElementwiseOperation,
CElementwiseOperation, // CElementwiseOperation, CElementwiseOperation, // CElementwiseOperation,
MPerBlock, // MPerBlock,
NPerBlock, // NPerBlock,
KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch, ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
AThreadwiseCopy, // AThreadwiseCopy AThreadwiseCopy, // AThreadwiseCopy
BThreadwiseCopy, // BThreadwiseCopy BThreadwiseCopy, // BThreadwiseCopy
CThreadwiseCopy, // CThreadwiseCopy CThreadwiseCopy, // CThreadwiseCopy
BlockMNKAccessOrder, // BlockMNKAccessOrder,
ck::Sequence<0, 1>, // ThreadMNAccessOrder ck::Sequence<0, 1>, // ThreadMNAccessOrder
UseALocalBuffer, // UseALocalBuffer UseALocalBuffer, // UseALocalBuffer
UseBLocalBuffer, // UseBLocalBuffer UseBLocalBuffer, // UseBLocalBuffer
UseCLocalBuffer // UseCLocalBuffer UseCLocalBuffer // UseCLocalBuffer
>; >;
GridwiseGemm gridwise_gemm;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -711,11 +703,15 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -711,11 +703,15 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
GridwiseGemm gridwise_gemm;
Invoker(const GridwiseGemm& gridwise_gemm_) : gridwise_gemm(gridwise_gemm_) {}
float Run(const Argument& arg, float Run(const Argument& arg,
const StreamConfig& stream_config = StreamConfig{}, const StreamConfig& stream_config = StreamConfig{},
int nrepeat = 1) int nrepeat = 1)
{ {
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_)) if(!gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
{ {
throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting"); throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
} }
...@@ -738,6 +734,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -738,6 +734,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
if(nrepeat != 1) if(nrepeat != 1)
ave_time = launch_and_time_cpu_kernel(kernel, ave_time = launch_and_time_cpu_kernel(kernel,
nrepeat, nrepeat,
gridwise_gemm,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
...@@ -753,6 +750,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -753,6 +750,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize()); memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
launch_cpu_kernel(kernel, launch_cpu_kernel(kernel,
gridwise_gemm,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
...@@ -780,7 +778,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -780,7 +778,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
return true; return true;
} }
static bool IsSupportedArgument(const Argument& arg) bool IsSupportedArgument(const Argument& arg)
{ {
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
...@@ -811,7 +809,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -811,7 +809,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
ConvForwardSpecialization != ConvForwardSpecialization !=
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{ {
if(!(arg.Conv_C_ % KPerBlock == 0)) if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0))
return false; return false;
} }
...@@ -825,7 +823,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -825,7 +823,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
} }
// Gridwise GEMM size // Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_); return gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
} }
bool IsSupportedArgument(const BaseArgument* p_arg) override bool IsSupportedArgument(const BaseArgument* p_arg) override
...@@ -868,7 +866,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -868,7 +866,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
out_element_op}; out_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } auto MakeInvoker() { return Invoker{gridwise_gemm}; }
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_grid, MakeArgumentPointer(const void* p_in_grid,
...@@ -908,7 +906,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -908,7 +906,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{ {
return std::make_unique<Invoker>(Invoker{}); return std::make_unique<Invoker>(Invoker{gridwise_gemm});
} }
std::string GetTypeString() const override std::string GetTypeString() const override
...@@ -925,8 +923,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -925,8 +923,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<< "DFwdAvx2_NHWC_KYXC" << "DFwdAvx2_NHWC_KYXC"
<<"_FS"<< static_cast<int>(ConvForwardSpecialization) <<"_FS"<< static_cast<int>(ConvForwardSpecialization)
<<"_KS"<< static_cast<int>(GemmKSpecialization) <<"_KS"<< static_cast<int>(GemmKSpecialization)
<<"_BS"<< static_cast<int>(BlockLoopOverSpecialization) <<"_BS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.loop_over_spec)
<< "_BT" << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block
<< "_TT" << MPerThread << "x" << NPerThread << "_TT" << MPerThread << "x" << NPerThread
<< "_A" << string_local_buffer(UseALocalBuffer) << "_A" << string_local_buffer(UseALocalBuffer)
<< "_B" << string_local_buffer(UseBLocalBuffer) << "_B" << string_local_buffer(UseBLocalBuffer)
......
...@@ -30,11 +30,8 @@ template <typename InDataType, ...@@ -30,11 +30,8 @@ template <typename InDataType,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization, ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization, ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization, // ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
ck::index_t NumDimSpatial, ck::index_t NumDimSpatial,
ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t MPerThread, ck::index_t MPerThread,
ck::index_t NPerThread, ck::index_t NPerThread,
bool UseALocalBuffer, bool UseALocalBuffer,
...@@ -65,17 +62,12 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K ...@@ -65,17 +62,12 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
static constexpr bool NonTemporalStore = false; static constexpr bool NonTemporalStore = false;
static constexpr auto GetBlockMNKAccessOrder() DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K(
const DeviceConvFwdDynamicTunable& dtune)
: gridwise_gemm(dtune)
{ {
if constexpr(BlockLoopOverSpecialization == DefaultBlockLoopOver ||
BlockLoopOverSpecialization == LoopOver_MNK)
return ck::Sequence<0, 1, 2>{};
else if constexpr(BlockLoopOverSpecialization == LoopOver_MKN)
return ck::Sequence<0, 2, 1>{};
} }
using BlockMNKAccessOrder = decltype(GetBlockMNKAccessOrder());
static constexpr auto GetThreadwiseGemm_Dispatch() static constexpr auto GetThreadwiseGemm_Dispatch()
{ {
if constexpr(MPerThread == 4 && NPerThread == 24) if constexpr(MPerThread == 4 && NPerThread == 24)
...@@ -518,7 +510,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K ...@@ -518,7 +510,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
{ {
if constexpr(UseALocalBuffer) if constexpr(UseALocalBuffer)
{ {
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock)); // return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
} }
else else
{ {
...@@ -530,10 +523,11 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K ...@@ -530,10 +523,11 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
{ {
if constexpr(UseBLocalBuffer) if constexpr(UseBLocalBuffer)
{ {
return make_naive_tensor_descriptor_packed(make_tuple( // return make_naive_tensor_descriptor_packed(make_tuple(
math::integer_divide_ceil(NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize), // math::integer_divide_ceil(NPerBlock,
KPerBlock, // ThreadwiseGemm_Dispatch::MatrixBMinVectorSize), KPerBlock,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize)); // ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
return make_naive_tensor_descriptor_packed(make_tuple(0, 0, 0));
} }
else else
{ {
...@@ -545,7 +539,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K ...@@ -545,7 +539,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
{ {
if constexpr(UseCLocalBuffer) if constexpr(UseCLocalBuffer)
{ {
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock)); // return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
} }
else else
{ {
...@@ -597,20 +592,18 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K ...@@ -597,20 +592,18 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
AElementwiseOperation, // AElementwiseOperation, AElementwiseOperation, // AElementwiseOperation,
BElementwiseOperation, // BElementwiseOperation, BElementwiseOperation, // BElementwiseOperation,
CElementwiseOperation, // CElementwiseOperation, CElementwiseOperation, // CElementwiseOperation,
MPerBlock, // MPerBlock,
NPerBlock, // NPerBlock,
KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch, ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
AThreadwiseCopy, // AThreadwiseCopy AThreadwiseCopy, // AThreadwiseCopy
BThreadwiseCopy, // BThreadwiseCopy BThreadwiseCopy, // BThreadwiseCopy
CThreadwiseCopy, // CThreadwiseCopy CThreadwiseCopy, // CThreadwiseCopy
BlockMNKAccessOrder, // BlockMNKAccessOrder,
ck::Sequence<0, 1>, // ThreadMNAccessOrder ck::Sequence<0, 1>, // ThreadMNAccessOrder
UseALocalBuffer, // UseALocalBuffer UseALocalBuffer, // UseALocalBuffer
UseBLocalBuffer, // UseBLocalBuffer UseBLocalBuffer, // UseBLocalBuffer
UseCLocalBuffer // UseCLocalBuffer UseCLocalBuffer // UseCLocalBuffer
>; >;
GridwiseGemm gridwise_gemm;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -687,12 +680,15 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K ...@@ -687,12 +680,15 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
GridwiseGemm gridwise_gemm;
Invoker(const GridwiseGemm& gridwise_gemm_) : gridwise_gemm(gridwise_gemm_) {}
float Run(const Argument& arg, float Run(const Argument& arg,
const StreamConfig& stream_config = StreamConfig{}, const StreamConfig& stream_config = StreamConfig{},
int nrepeat = 1) int nrepeat = 1)
{ {
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_)) if(!gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
{ {
throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting"); throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
} }
...@@ -715,6 +711,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K ...@@ -715,6 +711,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
if(nrepeat != 1) if(nrepeat != 1)
ave_time = launch_and_time_cpu_kernel(kernel, ave_time = launch_and_time_cpu_kernel(kernel,
nrepeat, nrepeat,
gridwise_gemm,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
...@@ -730,6 +727,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K ...@@ -730,6 +727,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize()); memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
launch_cpu_kernel(kernel, launch_cpu_kernel(kernel,
gridwise_gemm,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
...@@ -757,7 +755,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K ...@@ -757,7 +755,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
return true; return true;
} }
static bool IsSupportedArgument(const Argument& arg) bool IsSupportedArgument(const Argument& arg)
{ {
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
...@@ -788,7 +786,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K ...@@ -788,7 +786,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
ConvForwardSpecialization != ConvForwardSpecialization !=
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{ {
if(!(arg.Conv_C_ % KPerBlock == 0)) if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0))
return false; return false;
} }
...@@ -805,7 +803,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K ...@@ -805,7 +803,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
} }
// Gridwise GEMM size // Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_); return gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
} }
bool IsSupportedArgument(const BaseArgument* p_arg) override bool IsSupportedArgument(const BaseArgument* p_arg) override
...@@ -848,7 +846,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K ...@@ -848,7 +846,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
out_element_op}; out_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } auto MakeInvoker() { return Invoker{gridwise_gemm}; }
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_grid, MakeArgumentPointer(const void* p_in_grid,
...@@ -888,7 +886,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K ...@@ -888,7 +886,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{ {
return std::make_unique<Invoker>(Invoker{}); return std::make_unique<Invoker>(Invoker{gridwise_gemm});
} }
std::string GetTypeString() const override std::string GetTypeString() const override
...@@ -905,8 +903,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K ...@@ -905,8 +903,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
<< "DFwdAvx2_NHWC_KYXCK8" << "DFwdAvx2_NHWC_KYXCK8"
<<"_FS"<< static_cast<int>(ConvForwardSpecialization) <<"_FS"<< static_cast<int>(ConvForwardSpecialization)
<<"_KS"<< static_cast<int>(GemmKSpecialization) <<"_KS"<< static_cast<int>(GemmKSpecialization)
<<"_BS"<< static_cast<int>(BlockLoopOverSpecialization) <<"_BS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.loop_over_spec)
<< "_BT" << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block
<< "_TT" << MPerThread << "x" << NPerThread << "_TT" << MPerThread << "x" << NPerThread
<< "_A" << string_local_buffer(UseALocalBuffer) << "_A" << string_local_buffer(UseALocalBuffer)
<< "_B" << string_local_buffer(UseBLocalBuffer) << "_B" << string_local_buffer(UseBLocalBuffer)
......
...@@ -29,11 +29,7 @@ template <typename InDataType, ...@@ -29,11 +29,7 @@ template <typename InDataType,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization, ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization, ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
ck::index_t NumDimSpatial, ck::index_t NumDimSpatial,
ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t MPerThread, ck::index_t MPerThread,
ck::index_t NPerThread, ck::index_t NPerThread,
bool UseALocalBuffer, bool UseALocalBuffer,
...@@ -64,17 +60,12 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K ...@@ -64,17 +60,12 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
static constexpr bool NonTemporalStore = false; static constexpr bool NonTemporalStore = false;
static constexpr auto GetBlockMNKAccessOrder() DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K(
const DeviceConvFwdDynamicTunable& dtune)
: gridwise_gemm(dtune)
{ {
if constexpr(BlockLoopOverSpecialization == DefaultBlockLoopOver ||
BlockLoopOverSpecialization == LoopOver_MNK)
return ck::Sequence<0, 1, 2>{};
else if constexpr(BlockLoopOverSpecialization == LoopOver_MKN)
return ck::Sequence<0, 2, 1>{};
} }
using BlockMNKAccessOrder = decltype(GetBlockMNKAccessOrder());
static constexpr auto GetThreadwiseGemm_Dispatch() static constexpr auto GetThreadwiseGemm_Dispatch()
{ {
if constexpr(MPerThread == 4 && NPerThread == 24) if constexpr(MPerThread == 4 && NPerThread == 24)
...@@ -514,7 +505,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K ...@@ -514,7 +505,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
{ {
if constexpr(UseALocalBuffer) if constexpr(UseALocalBuffer)
{ {
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock)); // return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
} }
else else
{ {
...@@ -526,7 +518,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K ...@@ -526,7 +518,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
{ {
if constexpr(UseBLocalBuffer) if constexpr(UseBLocalBuffer)
{ {
return make_naive_tensor_descriptor_packed(make_tuple(KPerBlock, NPerBlock)); // return make_naive_tensor_descriptor_packed(make_tuple(KPerBlock, NPerBlock));
return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
} }
else else
{ {
...@@ -538,7 +531,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K ...@@ -538,7 +531,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
{ {
if constexpr(UseCLocalBuffer) if constexpr(UseCLocalBuffer)
{ {
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock)); // return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
} }
else else
{ {
...@@ -590,20 +584,18 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K ...@@ -590,20 +584,18 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
AElementwiseOperation, // AElementwiseOperation, AElementwiseOperation, // AElementwiseOperation,
BElementwiseOperation, // BElementwiseOperation, BElementwiseOperation, // BElementwiseOperation,
CElementwiseOperation, // CElementwiseOperation, CElementwiseOperation, // CElementwiseOperation,
MPerBlock, // MPerBlock,
NPerBlock, // NPerBlock,
KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch, ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
AThreadwiseCopy, // AThreadwiseCopy AThreadwiseCopy, // AThreadwiseCopy
BThreadwiseCopy, // BThreadwiseCopy BThreadwiseCopy, // BThreadwiseCopy
CThreadwiseCopy, // CThreadwiseCopy CThreadwiseCopy, // CThreadwiseCopy
BlockMNKAccessOrder, // BlockMNKAccessOrder,
ck::Sequence<0, 1>, // ThreadMNAccessOrder ck::Sequence<0, 1>, // ThreadMNAccessOrder
UseALocalBuffer, // UseALocalBuffer UseALocalBuffer, // UseALocalBuffer
UseBLocalBuffer, // UseBLocalBuffer UseBLocalBuffer, // UseBLocalBuffer
UseCLocalBuffer // UseCLocalBuffer UseCLocalBuffer // UseCLocalBuffer
>; >;
GridwiseGemm gridwise_gemm;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -680,12 +672,15 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K ...@@ -680,12 +672,15 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
GridwiseGemm gridwise_gemm;
Invoker(const GridwiseGemm& gridwise_gemm_) : gridwise_gemm(gridwise_gemm_) {}
float Run(const Argument& arg, float Run(const Argument& arg,
const StreamConfig& stream_config = StreamConfig{}, const StreamConfig& stream_config = StreamConfig{},
int nrepeat = 1) int nrepeat = 1)
{ {
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_)) if(!gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
{ {
throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting"); throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
} }
...@@ -708,6 +703,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K ...@@ -708,6 +703,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
if(nrepeat != 1) if(nrepeat != 1)
ave_time = launch_and_time_cpu_kernel(kernel, ave_time = launch_and_time_cpu_kernel(kernel,
nrepeat, nrepeat,
gridwise_gemm,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
...@@ -723,6 +719,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K ...@@ -723,6 +719,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize()); memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
launch_cpu_kernel(kernel, launch_cpu_kernel(kernel,
gridwise_gemm,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
...@@ -750,7 +747,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K ...@@ -750,7 +747,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
return true; return true;
} }
static bool IsSupportedArgument(const Argument& arg) bool IsSupportedArgument(const Argument& arg)
{ {
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
...@@ -781,7 +778,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K ...@@ -781,7 +778,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
ConvForwardSpecialization != ConvForwardSpecialization !=
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{ {
if(!(arg.Conv_C_ % KPerBlock == 0)) if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0))
return false; return false;
} }
...@@ -801,7 +798,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K ...@@ -801,7 +798,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
} }
// Gridwise GEMM size // Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_); return gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
} }
bool IsSupportedArgument(const BaseArgument* p_arg) override bool IsSupportedArgument(const BaseArgument* p_arg) override
...@@ -844,7 +841,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K ...@@ -844,7 +841,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
out_element_op}; out_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } auto MakeInvoker() { return Invoker{gridwise_gemm}; }
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_grid, MakeArgumentPointer(const void* p_in_grid,
...@@ -884,7 +881,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K ...@@ -884,7 +881,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{ {
return std::make_unique<Invoker>(Invoker{}); return std::make_unique<Invoker>(Invoker{gridwise_gemm});
} }
std::string GetTypeString() const override std::string GetTypeString() const override
...@@ -901,9 +898,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K ...@@ -901,9 +898,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<< "DFwdAvx2_NHWC_YXCK" << "DFwdAvx2_NHWC_YXCK"
<<"_FS"<< static_cast<int>(ConvForwardSpecialization) <<"_FS"<< static_cast<int>(ConvForwardSpecialization)
<<"_KS"<< static_cast<int>(GemmKSpecialization) <<"_KS"<< static_cast<int>(GemmKSpecialization)
<<"_BS"<< static_cast<int>(BlockLoopOverSpecialization) <<"_BS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.loop_over_spec)
<< "_BT" << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block
<< "_TT" << MPerThread << "x" << NPerThread
<< "_A" << string_local_buffer(UseALocalBuffer) << "_A" << string_local_buffer(UseALocalBuffer)
<< "_B" << string_local_buffer(UseBLocalBuffer) << "_B" << string_local_buffer(UseBLocalBuffer)
<< "_C" << string_local_buffer(UseCLocalBuffer) << "_C" << string_local_buffer(UseCLocalBuffer)
......
...@@ -32,11 +32,7 @@ template <typename InDataType, ...@@ -32,11 +32,7 @@ template <typename InDataType,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization, ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization, ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
ck::index_t NumDimSpatial, ck::index_t NumDimSpatial,
ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t MPerThread, ck::index_t MPerThread,
ck::index_t NPerThread, ck::index_t NPerThread,
bool UseALocalBuffer, bool UseALocalBuffer,
...@@ -73,17 +69,12 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu ...@@ -73,17 +69,12 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
static constexpr bool NonTemporalStore = false; static constexpr bool NonTemporalStore = false;
static constexpr auto GetBlockMNKAccessOrder() DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K(
const DeviceConvFwdDynamicTunable& dtune)
: gridwise_gemm(dtune)
{ {
if constexpr(BlockLoopOverSpecialization == DefaultBlockLoopOver ||
BlockLoopOverSpecialization == LoopOver_MNK)
return ck::Sequence<0, 1, 2>{};
else if constexpr(BlockLoopOverSpecialization == LoopOver_MKN)
return ck::Sequence<0, 2, 1>{};
} }
using BlockMNKAccessOrder = decltype(GetBlockMNKAccessOrder());
static constexpr auto GetThreadwiseGemm_Dispatch() static constexpr auto GetThreadwiseGemm_Dispatch()
{ {
if constexpr(MPerThread == 4 && NPerThread == 24) if constexpr(MPerThread == 4 && NPerThread == 24)
...@@ -114,45 +105,6 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu ...@@ -114,45 +105,6 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch()); using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());
static constexpr auto GetInputBlockDescriptor()
{
if constexpr(UseALocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
}
else
{
return AGridDesc{};
}
}
static constexpr auto GetWeightBlockDescriptor()
{
if constexpr(UseBLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(
math::integer_divide_ceil(NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
KPerBlock,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
}
else
{
return BGridDesc{};
}
}
static constexpr auto GetOutputBlockDescriptor()
{
if constexpr(UseCLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
}
else
{
return CGridDesc{};
}
}
static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n) static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
{ {
ck::index_t gemm_n_padded = ck::index_t gemm_n_padded =
...@@ -598,6 +550,42 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu ...@@ -598,6 +550,42 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
using C0GridDesc = remove_cvref_t<decltype(MakeBiasTensorDescriptor(1, 1))>; using C0GridDesc = remove_cvref_t<decltype(MakeBiasTensorDescriptor(1, 1))>;
using C1GridDesc = CGridDesc; using C1GridDesc = CGridDesc;
static constexpr auto GetInputBlockDescriptor()
{
if constexpr(UseALocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
}
else
{
return AGridDesc{};
}
}
static constexpr auto GetWeightBlockDescriptor()
{
if constexpr(UseBLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(0, 0, 0));
}
else
{
return BGridDesc{};
}
}
static constexpr auto GetOutputBlockDescriptor()
{
if constexpr(UseCLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
}
else
{
return CGridDesc{};
}
}
// static constexpr bool UseCLocalBuffer = false; // static constexpr bool UseCLocalBuffer = false;
using AThreadwiseCopy = using AThreadwiseCopy =
...@@ -650,20 +638,18 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu ...@@ -650,20 +638,18 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
AElementwiseOperation, // AElementwiseOperation, AElementwiseOperation, // AElementwiseOperation,
BElementwiseOperation, // BElementwiseOperation, BElementwiseOperation, // BElementwiseOperation,
CElementwiseOperation, // CElementwiseOperation, CElementwiseOperation, // CElementwiseOperation,
MPerBlock, // MPerBlock,
NPerBlock, // NPerBlock,
KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch, ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
AThreadwiseCopy, // AThreadwiseCopy AThreadwiseCopy, // AThreadwiseCopy
BThreadwiseCopy, // BThreadwiseCopy BThreadwiseCopy, // BThreadwiseCopy
CThreadwiseCopy, // CThreadwiseCopy CThreadwiseCopy, // CThreadwiseCopy
BlockMNKAccessOrder, // BlockMNKAccessOrder,
ck::Sequence<0, 1>, // ThreadMNAccessOrder ck::Sequence<0, 1>, // ThreadMNAccessOrder
UseALocalBuffer, // UseALocalBuffer UseALocalBuffer, // UseALocalBuffer
UseBLocalBuffer, // UseBLocalBuffer UseBLocalBuffer, // UseBLocalBuffer
UseCLocalBuffer // UseCLocalBuffer UseCLocalBuffer // UseCLocalBuffer
>; >;
GridwiseGemm gridwise_gemm;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -755,11 +741,15 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu ...@@ -755,11 +741,15 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
GridwiseGemm gridwise_gemm;
Invoker(const GridwiseGemm& gridwise_gemm_) : gridwise_gemm(gridwise_gemm_) {}
float Run(const Argument& arg, float Run(const Argument& arg,
const StreamConfig& stream_config = StreamConfig{}, const StreamConfig& stream_config = StreamConfig{},
int nrepeat = 1) int nrepeat = 1)
{ {
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_)) if(!gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
{ {
throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting"); throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
} }
...@@ -787,6 +777,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu ...@@ -787,6 +777,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
if(nrepeat != 1) if(nrepeat != 1)
ave_time = launch_and_time_cpu_kernel(kernel, ave_time = launch_and_time_cpu_kernel(kernel,
nrepeat, nrepeat,
gridwise_gemm,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
...@@ -806,6 +797,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu ...@@ -806,6 +797,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize()); memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
launch_cpu_kernel(kernel, launch_cpu_kernel(kernel,
gridwise_gemm,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
...@@ -837,7 +829,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu ...@@ -837,7 +829,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
return true; return true;
} }
static bool IsSupportedArgument(const Argument& arg) bool IsSupportedArgument(const Argument& arg)
{ {
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
...@@ -868,7 +860,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu ...@@ -868,7 +860,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
ConvForwardSpecialization != ConvForwardSpecialization !=
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{ {
if(!(arg.Conv_C_ % KPerBlock == 0)) if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0))
return false; return false;
} }
...@@ -882,7 +874,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu ...@@ -882,7 +874,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
} }
// Gridwise GEMM size // Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_); return gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
} }
bool IsSupportedArgument(const BaseArgument* p_arg) override bool IsSupportedArgument(const BaseArgument* p_arg) override
...@@ -929,7 +921,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu ...@@ -929,7 +921,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
out_element_op}; out_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } auto MakeInvoker() { return Invoker{gridwise_gemm}; }
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_grid, MakeArgumentPointer(const void* p_in_grid,
...@@ -973,7 +965,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu ...@@ -973,7 +965,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{ {
return std::make_unique<Invoker>(Invoker{}); return std::make_unique<Invoker>(Invoker{gridwise_gemm});
} }
std::string GetTypeString() const override std::string GetTypeString() const override
...@@ -990,8 +982,8 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu ...@@ -990,8 +982,8 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
<< "DFwd_BAA_Avx2_NHWC_KYXC" << "DFwd_BAA_Avx2_NHWC_KYXC"
<<"_FS"<< static_cast<int>(ConvForwardSpecialization) <<"_FS"<< static_cast<int>(ConvForwardSpecialization)
<<"_KS"<< static_cast<int>(GemmKSpecialization) <<"_KS"<< static_cast<int>(GemmKSpecialization)
<<"_BS"<< static_cast<int>(BlockLoopOverSpecialization) <<"_BS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.loop_over_spec)
<< "_BT" << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block
<< "_TT" << MPerThread << "x" << NPerThread << "_TT" << MPerThread << "x" << NPerThread
<< "_A" << string_local_buffer(UseALocalBuffer) << "_A" << string_local_buffer(UseALocalBuffer)
<< "_B" << string_local_buffer(UseBLocalBuffer) << "_B" << string_local_buffer(UseBLocalBuffer)
......
...@@ -32,11 +32,7 @@ template <typename InDataType, ...@@ -32,11 +32,7 @@ template <typename InDataType,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization, ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization, ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
ck::index_t NumDimSpatial, ck::index_t NumDimSpatial,
ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t MPerThread, ck::index_t MPerThread,
ck::index_t NPerThread, ck::index_t NPerThread,
bool UseALocalBuffer, bool UseALocalBuffer,
...@@ -73,17 +69,12 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou ...@@ -73,17 +69,12 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
static constexpr bool NonTemporalStore = false; static constexpr bool NonTemporalStore = false;
static constexpr auto GetBlockMNKAccessOrder() DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K(
const DeviceConvFwdDynamicTunable& dtune)
: gridwise_gemm(dtune)
{ {
if constexpr(BlockLoopOverSpecialization == DefaultBlockLoopOver ||
BlockLoopOverSpecialization == LoopOver_MNK)
return ck::Sequence<0, 1, 2>{};
else if constexpr(BlockLoopOverSpecialization == LoopOver_MKN)
return ck::Sequence<0, 2, 1>{};
} }
using BlockMNKAccessOrder = decltype(GetBlockMNKAccessOrder());
static constexpr auto GetThreadwiseGemm_Dispatch() static constexpr auto GetThreadwiseGemm_Dispatch()
{ {
if constexpr(MPerThread == 4 && NPerThread == 24) if constexpr(MPerThread == 4 && NPerThread == 24)
...@@ -114,45 +105,6 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou ...@@ -114,45 +105,6 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch()); using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());
static constexpr auto GetInputBlockDescriptor()
{
if constexpr(UseALocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
}
else
{
return AGridDesc{};
}
}
static constexpr auto GetWeightBlockDescriptor()
{
if constexpr(UseBLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(
math::integer_divide_ceil(NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
KPerBlock,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
}
else
{
return BGridDesc{};
}
}
static constexpr auto GetOutputBlockDescriptor()
{
if constexpr(UseCLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
}
else
{
return CGridDesc{};
}
}
static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n) static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
{ {
return make_naive_tensor_descriptor_packed(make_tuple(gemm_n / 8, gemm_k, 8)); return make_naive_tensor_descriptor_packed(make_tuple(gemm_n / 8, gemm_k, 8));
...@@ -575,6 +527,42 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou ...@@ -575,6 +527,42 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
using C0GridDesc = remove_cvref_t<decltype(MakeBiasTensorDescriptor(1, 1))>; using C0GridDesc = remove_cvref_t<decltype(MakeBiasTensorDescriptor(1, 1))>;
using C1GridDesc = CGridDesc; using C1GridDesc = CGridDesc;
static constexpr auto GetInputBlockDescriptor()
{
if constexpr(UseALocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
}
else
{
return AGridDesc{};
}
}
static constexpr auto GetWeightBlockDescriptor()
{
if constexpr(UseBLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(0, 0, 0));
}
else
{
return BGridDesc{};
}
}
static constexpr auto GetOutputBlockDescriptor()
{
if constexpr(UseCLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
}
else
{
return CGridDesc{};
}
}
// static constexpr bool UseCLocalBuffer = false; // static constexpr bool UseCLocalBuffer = false;
using AThreadwiseCopy = using AThreadwiseCopy =
...@@ -627,20 +615,18 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou ...@@ -627,20 +615,18 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
AElementwiseOperation, // AElementwiseOperation, AElementwiseOperation, // AElementwiseOperation,
BElementwiseOperation, // BElementwiseOperation, BElementwiseOperation, // BElementwiseOperation,
CElementwiseOperation, // CElementwiseOperation, CElementwiseOperation, // CElementwiseOperation,
MPerBlock, // MPerBlock,
NPerBlock, // NPerBlock,
KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch, ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
AThreadwiseCopy, // AThreadwiseCopy AThreadwiseCopy, // AThreadwiseCopy
BThreadwiseCopy, // BThreadwiseCopy BThreadwiseCopy, // BThreadwiseCopy
CThreadwiseCopy, // CThreadwiseCopy CThreadwiseCopy, // CThreadwiseCopy
BlockMNKAccessOrder, // BlockMNKAccessOrder,
ck::Sequence<0, 1>, // ThreadMNAccessOrder ck::Sequence<0, 1>, // ThreadMNAccessOrder
UseALocalBuffer, // UseALocalBuffer UseALocalBuffer, // UseALocalBuffer
UseBLocalBuffer, // UseBLocalBuffer UseBLocalBuffer, // UseBLocalBuffer
UseCLocalBuffer // UseCLocalBuffer UseCLocalBuffer // UseCLocalBuffer
>; >;
GridwiseGemm gridwise_gemm;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -732,11 +718,15 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou ...@@ -732,11 +718,15 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
GridwiseGemm gridwise_gemm;
Invoker(const GridwiseGemm& gridwise_gemm_) : gridwise_gemm(gridwise_gemm_) {}
float Run(const Argument& arg, float Run(const Argument& arg,
const StreamConfig& stream_config = StreamConfig{}, const StreamConfig& stream_config = StreamConfig{},
int nrepeat = 1) int nrepeat = 1)
{ {
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_)) if(!gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
{ {
throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting"); throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
} }
...@@ -764,6 +754,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou ...@@ -764,6 +754,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
if(nrepeat != 1) if(nrepeat != 1)
ave_time = launch_and_time_cpu_kernel(kernel, ave_time = launch_and_time_cpu_kernel(kernel,
nrepeat, nrepeat,
gridwise_gemm,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
...@@ -783,6 +774,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou ...@@ -783,6 +774,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize()); memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
launch_cpu_kernel(kernel, launch_cpu_kernel(kernel,
gridwise_gemm,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
...@@ -814,7 +806,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou ...@@ -814,7 +806,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
return true; return true;
} }
static bool IsSupportedArgument(const Argument& arg) bool IsSupportedArgument(const Argument& arg)
{ {
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
...@@ -845,7 +837,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou ...@@ -845,7 +837,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
ConvForwardSpecialization != ConvForwardSpecialization !=
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{ {
if(!(arg.Conv_C_ % KPerBlock == 0)) if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0))
return false; return false;
} }
...@@ -862,7 +854,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou ...@@ -862,7 +854,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
} }
// Gridwise GEMM size // Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_); return gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
} }
bool IsSupportedArgument(const BaseArgument* p_arg) override bool IsSupportedArgument(const BaseArgument* p_arg) override
...@@ -909,7 +901,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou ...@@ -909,7 +901,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
out_element_op}; out_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } auto MakeInvoker() { return Invoker{gridwise_gemm}; }
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_grid, MakeArgumentPointer(const void* p_in_grid,
...@@ -953,7 +945,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou ...@@ -953,7 +945,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{ {
return std::make_unique<Invoker>(Invoker{}); return std::make_unique<Invoker>(Invoker{gridwise_gemm});
} }
std::string GetTypeString() const override std::string GetTypeString() const override
...@@ -970,8 +962,8 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou ...@@ -970,8 +962,8 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
<< "DFwd_BAA_Avx2_NHWC_KYXCK8" << "DFwd_BAA_Avx2_NHWC_KYXCK8"
<<"_FS"<< static_cast<int>(ConvForwardSpecialization) <<"_FS"<< static_cast<int>(ConvForwardSpecialization)
<<"_KS"<< static_cast<int>(GemmKSpecialization) <<"_KS"<< static_cast<int>(GemmKSpecialization)
<<"_BS"<< static_cast<int>(BlockLoopOverSpecialization) <<"_BS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.loop_over_spec)
<< "_BT" << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block
<< "_TT" << MPerThread << "x" << NPerThread << "_TT" << MPerThread << "x" << NPerThread
<< "_A" << string_local_buffer(UseALocalBuffer) << "_A" << string_local_buffer(UseALocalBuffer)
<< "_B" << string_local_buffer(UseBLocalBuffer) << "_B" << string_local_buffer(UseBLocalBuffer)
......
...@@ -31,11 +31,7 @@ template <typename InDataType, ...@@ -31,11 +31,7 @@ template <typename InDataType,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization, ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization, ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
ck::index_t NumDimSpatial, ck::index_t NumDimSpatial,
ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t MPerThread, ck::index_t MPerThread,
ck::index_t NPerThread, ck::index_t NPerThread,
bool UseALocalBuffer, bool UseALocalBuffer,
...@@ -72,17 +68,12 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu ...@@ -72,17 +68,12 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
static constexpr bool NonTemporalStore = false; static constexpr bool NonTemporalStore = false;
static constexpr auto GetBlockMNKAccessOrder() DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K(
const DeviceConvFwdDynamicTunable& dtune)
: gridwise_gemm(dtune)
{ {
if constexpr(BlockLoopOverSpecialization == DefaultBlockLoopOver ||
BlockLoopOverSpecialization == LoopOver_MNK)
return ck::Sequence<0, 1, 2>{};
else if constexpr(BlockLoopOverSpecialization == LoopOver_MKN)
return ck::Sequence<0, 2, 1>{};
} }
using BlockMNKAccessOrder = decltype(GetBlockMNKAccessOrder());
static constexpr auto GetThreadwiseGemm_Dispatch() static constexpr auto GetThreadwiseGemm_Dispatch()
{ {
if constexpr(MPerThread == 4 && NPerThread == 24) if constexpr(MPerThread == 4 && NPerThread == 24)
...@@ -111,42 +102,6 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu ...@@ -111,42 +102,6 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch()); using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());
static constexpr auto GetInputBlockDescriptor()
{
if constexpr(UseALocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
}
else
{
return AGridDesc{};
}
}
static constexpr auto GetWeightBlockDescriptor()
{
if constexpr(UseBLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(KPerBlock, NPerBlock));
}
else
{
return BGridDesc{};
}
}
static constexpr auto GetOutputBlockDescriptor()
{
if constexpr(UseCLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
}
else
{
return CGridDesc{};
}
}
static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n) static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
{ {
return make_naive_tensor_descriptor_packed(make_tuple(gemm_k, gemm_n)); return make_naive_tensor_descriptor_packed(make_tuple(gemm_k, gemm_n));
...@@ -568,6 +523,42 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu ...@@ -568,6 +523,42 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
using C0GridDesc = remove_cvref_t<decltype(MakeBiasTensorDescriptor(1, 1))>; using C0GridDesc = remove_cvref_t<decltype(MakeBiasTensorDescriptor(1, 1))>;
using C1GridDesc = CGridDesc; using C1GridDesc = CGridDesc;
static constexpr auto GetInputBlockDescriptor()
{
if constexpr(UseALocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
}
else
{
return AGridDesc{};
}
}
static constexpr auto GetWeightBlockDescriptor()
{
if constexpr(UseBLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
}
else
{
return BGridDesc{};
}
}
static constexpr auto GetOutputBlockDescriptor()
{
if constexpr(UseCLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
}
else
{
return CGridDesc{};
}
}
// static constexpr bool UseCLocalBuffer = false; // static constexpr bool UseCLocalBuffer = false;
using AThreadwiseCopy = using AThreadwiseCopy =
...@@ -620,20 +611,18 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu ...@@ -620,20 +611,18 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
AElementwiseOperation, // AElementwiseOperation, AElementwiseOperation, // AElementwiseOperation,
BElementwiseOperation, // BElementwiseOperation, BElementwiseOperation, // BElementwiseOperation,
CElementwiseOperation, // CElementwiseOperation, CElementwiseOperation, // CElementwiseOperation,
MPerBlock, // MPerBlock,
NPerBlock, // NPerBlock,
KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch, ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
AThreadwiseCopy, // AThreadwiseCopy AThreadwiseCopy, // AThreadwiseCopy
BThreadwiseCopy, // BThreadwiseCopy BThreadwiseCopy, // BThreadwiseCopy
CThreadwiseCopy, // CThreadwiseCopy CThreadwiseCopy, // CThreadwiseCopy
BlockMNKAccessOrder, // BlockMNKAccessOrder,
ck::Sequence<0, 1>, // ThreadMNAccessOrder ck::Sequence<0, 1>, // ThreadMNAccessOrder
UseALocalBuffer, // UseALocalBuffer UseALocalBuffer, // UseALocalBuffer
UseBLocalBuffer, // UseBLocalBuffer UseBLocalBuffer, // UseBLocalBuffer
UseCLocalBuffer // UseCLocalBuffer UseCLocalBuffer // UseCLocalBuffer
>; >;
GridwiseGemm gridwise_gemm;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -725,11 +714,15 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu ...@@ -725,11 +714,15 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
GridwiseGemm gridwise_gemm;
Invoker(const GridwiseGemm& gridwise_gemm_) : gridwise_gemm(gridwise_gemm_) {}
float Run(const Argument& arg, float Run(const Argument& arg,
const StreamConfig& stream_config = StreamConfig{}, const StreamConfig& stream_config = StreamConfig{},
int nrepeat = 1) int nrepeat = 1)
{ {
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_)) if(!gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
{ {
throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting"); throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
} }
...@@ -757,6 +750,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu ...@@ -757,6 +750,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
if(nrepeat != 1) if(nrepeat != 1)
ave_time = launch_and_time_cpu_kernel(kernel, ave_time = launch_and_time_cpu_kernel(kernel,
nrepeat, nrepeat,
gridwise_gemm,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
...@@ -776,6 +770,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu ...@@ -776,6 +770,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize()); memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
launch_cpu_kernel(kernel, launch_cpu_kernel(kernel,
gridwise_gemm,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
...@@ -807,7 +802,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu ...@@ -807,7 +802,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
return true; return true;
} }
static bool IsSupportedArgument(const Argument& arg) bool IsSupportedArgument(const Argument& arg)
{ {
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
...@@ -838,7 +833,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu ...@@ -838,7 +833,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
ConvForwardSpecialization != ConvForwardSpecialization !=
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{ {
if(!(arg.Conv_C_ % KPerBlock == 0)) if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0))
return false; return false;
} }
...@@ -858,7 +853,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu ...@@ -858,7 +853,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
} }
// Gridwise GEMM size // Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_); return gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
} }
bool IsSupportedArgument(const BaseArgument* p_arg) override bool IsSupportedArgument(const BaseArgument* p_arg) override
...@@ -905,7 +900,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu ...@@ -905,7 +900,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
out_element_op}; out_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } auto MakeInvoker() { return Invoker{gridwise_gemm}; }
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_grid, MakeArgumentPointer(const void* p_in_grid,
...@@ -949,7 +944,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu ...@@ -949,7 +944,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{ {
return std::make_unique<Invoker>(Invoker{}); return std::make_unique<Invoker>(Invoker{gridwise_gemm});
} }
std::string GetTypeString() const override std::string GetTypeString() const override
...@@ -966,8 +961,8 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu ...@@ -966,8 +961,8 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
<< "DFwd_BAA_Avx2_NHWC_YXCK" << "DFwd_BAA_Avx2_NHWC_YXCK"
<<"_FS"<< static_cast<int>(ConvForwardSpecialization) <<"_FS"<< static_cast<int>(ConvForwardSpecialization)
<<"_KS"<< static_cast<int>(GemmKSpecialization) <<"_KS"<< static_cast<int>(GemmKSpecialization)
<<"_BS"<< static_cast<int>(BlockLoopOverSpecialization) <<"_BS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.loop_over_spec)
<< "_BT" << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block
<< "_TT" << MPerThread << "x" << NPerThread << "_TT" << MPerThread << "x" << NPerThread
<< "_A" << string_local_buffer(UseALocalBuffer) << "_A" << string_local_buffer(UseALocalBuffer)
<< "_B" << string_local_buffer(UseBLocalBuffer) << "_B" << string_local_buffer(UseBLocalBuffer)
......
...@@ -28,7 +28,8 @@ template <typename GridwiseGemm, ...@@ -28,7 +28,8 @@ template <typename GridwiseGemm,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation>
void kernel_gemm_avx_mxn(const FloatA* __restrict__ p_a_grid, void kernel_gemm_avx_mxn(const GridwiseGemm& gridwise_gemm,
const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid, const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const AGridDesc& a_grid_desc, const AGridDesc& a_grid_desc,
...@@ -38,7 +39,7 @@ void kernel_gemm_avx_mxn(const FloatA* __restrict__ p_a_grid, ...@@ -38,7 +39,7 @@ void kernel_gemm_avx_mxn(const FloatA* __restrict__ p_a_grid,
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op) const CElementwiseOperation& c_element_op)
{ {
GridwiseGemm::Run(p_a_grid, gridwise_gemm.Run(p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
a_grid_desc, a_grid_desc,
...@@ -58,14 +59,10 @@ template <typename FloatA, ...@@ -58,14 +59,10 @@ template <typename FloatA,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
ck::index_t NPerBlock,
ck::index_t KPerBlock,
typename ThreadwiseGemm_Dispatch, typename ThreadwiseGemm_Dispatch,
typename AThreadwiseCopy, typename AThreadwiseCopy,
typename BThreadwiseCopy, typename BThreadwiseCopy,
typename CThreadwiseCopy, typename CThreadwiseCopy,
typename BlockMNKAccessOrder, // how we accss gemm MNK to better fit in cache
typename ThreadMNAccessOrder, // how we acces gemm MN to utilize micro kernel typename ThreadMNAccessOrder, // how we acces gemm MN to utilize micro kernel
bool UseALocalBuffer, bool UseALocalBuffer,
bool UseBLocalBuffer, bool UseBLocalBuffer,
...@@ -75,12 +72,19 @@ template <typename FloatA, ...@@ -75,12 +72,19 @@ template <typename FloatA,
> >
struct GridwiseGemmAvx2_MxN struct GridwiseGemmAvx2_MxN
{ {
ck::tensor_operation::cpu::device::DeviceConvFwdDynamicTunable dynamic_tunable;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
// static constexpr auto Avx2RegisterVector = 8; // 8 floats // static constexpr auto Avx2RegisterVector = 8; // 8 floats
static constexpr index_t MemAlignmentByte = 32; // 256bit static constexpr index_t MemAlignmentByte = 32; // 256bit
GridwiseGemmAvx2_MxN(
const ck::tensor_operation::cpu::device::DeviceConvFwdDynamicTunable dynamic_tunable_)
: dynamic_tunable(dynamic_tunable_)
{
}
static auto GetABlockDescriptor(const ck::index_t m_per_blk, static auto GetABlockDescriptor(const ck::index_t m_per_blk,
const ck::index_t k_per_blk, const ck::index_t k_per_blk,
const AGridDesc& a_grid_desc) const AGridDesc& a_grid_desc)
...@@ -238,16 +242,21 @@ struct GridwiseGemmAvx2_MxN ...@@ -238,16 +242,21 @@ struct GridwiseGemmAvx2_MxN
return ck::make_multi_index(i_m, i_n); return ck::make_multi_index(i_m, i_n);
} }
static constexpr bool CheckValidity(const AGridDesc& a_grid_desc, bool CheckValidity(const AGridDesc& a_grid_desc,
const BGridDesc& b_grid_desc, const BGridDesc& b_grid_desc,
const CGridDesc& c_grid_desc) const CGridDesc& c_grid_desc)
{ {
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
bool is_valid = true; bool is_valid = true;
const auto GemmN = c_grid_desc.GetLength(I1); const auto GemmN = c_grid_desc.GetLength(I1);
if constexpr(UseCLocalBuffer) if constexpr(UseCLocalBuffer)
{ {
if(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value && NPerBlock < GemmN) // if(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value &&
// dynamic_tunable.gemm_n_per_block < GemmN)
if(dynamic_tunable.loop_over_spec ==
ck::tensor_operation::cpu::device::
ConvolutionForwardBlockLoopOverSpecialization_t::LoopOver_MKN &&
dynamic_tunable.n_per_block < GemmN)
is_valid &= false; is_valid &= false;
} }
else else
...@@ -259,19 +268,19 @@ struct GridwiseGemmAvx2_MxN ...@@ -259,19 +268,19 @@ struct GridwiseGemmAvx2_MxN
return is_valid; return is_valid;
} }
static void Run(const FloatA* __restrict__ p_a_grid, void Run(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid, const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const AGridDesc& a_grid_desc, const AGridDesc& a_grid_desc,
const BGridDesc& b_grid_desc, const BGridDesc& b_grid_desc,
const CGridDesc& c_grid_desc, const CGridDesc& c_grid_desc,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op) const CElementwiseOperation& c_element_op) const
{ {
ck::index_t m_per_block = MPerBlock; ck::index_t m_per_block = dynamic_tunable.m_per_block;
ck::index_t n_per_block = NPerBlock; ck::index_t n_per_block = dynamic_tunable.n_per_block;
ck::index_t k_per_block = KPerBlock; ck::index_t k_per_block = dynamic_tunable.k_per_block;
const auto GemmM = c_grid_desc.GetLength(I0); const auto GemmM = c_grid_desc.GetLength(I0);
const auto GemmN = c_grid_desc.GetLength(I1); const auto GemmN = c_grid_desc.GetLength(I1);
...@@ -297,7 +306,6 @@ struct GridwiseGemmAvx2_MxN ...@@ -297,7 +306,6 @@ struct GridwiseGemmAvx2_MxN
decltype(GetABlockDescriptor(m_per_block, k_per_block, a_grid_desc)), // ABlockDesc, decltype(GetABlockDescriptor(m_per_block, k_per_block, a_grid_desc)), // ABlockDesc,
decltype(GetBBlockDescriptor(k_per_block, n_per_block, b_grid_desc)), // BBlockDesc, decltype(GetBBlockDescriptor(k_per_block, n_per_block, b_grid_desc)), // BBlockDesc,
decltype(GetCBlockDescriptor(m_per_block, n_per_block, c_grid_desc)), // CBlockDesc, decltype(GetCBlockDescriptor(m_per_block, n_per_block, c_grid_desc)), // CBlockDesc,
KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch, ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
ThreadMNAccessOrder>{}; // ThreadMNAccessOrder // how we acces ThreadMNAccessOrder>{}; // ThreadMNAccessOrder // how we acces
// gemm MN to utilize micro kernel>{}; // gemm MN to utilize micro kernel>{};
...@@ -323,7 +331,9 @@ struct GridwiseGemmAvx2_MxN ...@@ -323,7 +331,9 @@ struct GridwiseGemmAvx2_MxN
// TODO: openmp aware ordering // TODO: openmp aware ordering
// //
if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 1, 2>>::value) if(dynamic_tunable.loop_over_spec ==
ck::tensor_operation::cpu::device::ConvolutionForwardBlockLoopOverSpecialization_t::
LoopOver_MNK)
{ {
auto a_move_k_step = GetAIndex(0, k_per_block); auto a_move_k_step = GetAIndex(0, k_per_block);
auto b_move_k_step = GetBIndex(k_per_block, 0); auto b_move_k_step = GetBIndex(k_per_block, 0);
...@@ -467,7 +477,9 @@ struct GridwiseGemmAvx2_MxN ...@@ -467,7 +477,9 @@ struct GridwiseGemmAvx2_MxN
} }
} }
} }
else if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value) else if(dynamic_tunable.loop_over_spec ==
ck::tensor_operation::cpu::device::ConvolutionForwardBlockLoopOverSpecialization_t::
LoopOver_MKN)
{ {
auto a_move_k_step = GetAIndex(0, k_per_block); auto a_move_k_step = GetAIndex(0, k_per_block);
auto b_move_k_step = GetBIndex(0, n_per_block); auto b_move_k_step = GetBIndex(0, n_per_block);
......
...@@ -32,7 +32,8 @@ template <typename GridwiseGemm, ...@@ -32,7 +32,8 @@ template <typename GridwiseGemm,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation>
void kernel_gemm_bias_activation_add_avx_mxn(const FloatA* __restrict__ p_a_grid, void kernel_gemm_bias_activation_add_avx_mxn(const GridwiseGemm& gridwise_gemm,
const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid, const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const FloatC0* __restrict__ p_c0_grid, const FloatC0* __restrict__ p_c0_grid,
...@@ -46,7 +47,7 @@ void kernel_gemm_bias_activation_add_avx_mxn(const FloatA* __restrict__ p_a_grid ...@@ -46,7 +47,7 @@ void kernel_gemm_bias_activation_add_avx_mxn(const FloatA* __restrict__ p_a_grid
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op) const CElementwiseOperation& c_element_op)
{ {
GridwiseGemm::Run(p_a_grid, gridwise_gemm.Run(p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
p_c0_grid, p_c0_grid,
...@@ -74,14 +75,10 @@ template <typename FloatA, ...@@ -74,14 +75,10 @@ template <typename FloatA,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
ck::index_t NPerBlock,
ck::index_t KPerBlock,
typename ThreadwiseGemm_Dispatch, typename ThreadwiseGemm_Dispatch,
typename AThreadwiseCopy, typename AThreadwiseCopy,
typename BThreadwiseCopy, typename BThreadwiseCopy,
typename CThreadwiseCopy, typename CThreadwiseCopy,
typename BlockMNKAccessOrder, // how we accss gemm MNK to better fit in cache
typename ThreadMNAccessOrder, // how we acces gemm MN to utilize micro kernel typename ThreadMNAccessOrder, // how we acces gemm MN to utilize micro kernel
bool UseALocalBuffer, bool UseALocalBuffer,
bool UseBLocalBuffer, bool UseBLocalBuffer,
...@@ -91,12 +88,19 @@ template <typename FloatA, ...@@ -91,12 +88,19 @@ template <typename FloatA,
> >
struct GridwiseGemmBiasActivationAddAvx2_MxN struct GridwiseGemmBiasActivationAddAvx2_MxN
{ {
ck::tensor_operation::cpu::device::DeviceConvFwdDynamicTunable dynamic_tunable;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
// static constexpr auto Avx2RegisterVector = 8; // 8 floats // static constexpr auto Avx2RegisterVector = 8; // 8 floats
static constexpr index_t MemAlignmentByte = 32; // 256bit static constexpr index_t MemAlignmentByte = 32; // 256bit
GridwiseGemmBiasActivationAddAvx2_MxN(
const ck::tensor_operation::cpu::device::DeviceConvFwdDynamicTunable dynamic_tunable_)
: dynamic_tunable(dynamic_tunable_)
{
}
static auto GetABlockDescriptor(const ck::index_t m_per_blk, static auto GetABlockDescriptor(const ck::index_t m_per_blk,
const ck::index_t k_per_blk, const ck::index_t k_per_blk,
const AGridDesc& a_grid_desc) const AGridDesc& a_grid_desc)
...@@ -254,16 +258,20 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN ...@@ -254,16 +258,20 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
return ck::make_multi_index(i_m, i_n); return ck::make_multi_index(i_m, i_n);
} }
static constexpr bool CheckValidity(const AGridDesc& a_grid_desc, bool CheckValidity(const AGridDesc& a_grid_desc,
const BGridDesc& b_grid_desc, const BGridDesc& b_grid_desc,
const CGridDesc& c_grid_desc) const CGridDesc& c_grid_desc)
{ {
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
bool is_valid = true; bool is_valid = true;
const auto GemmN = c_grid_desc.GetLength(I1); const auto GemmN = c_grid_desc.GetLength(I1);
if constexpr(UseCLocalBuffer) if constexpr(UseCLocalBuffer)
{ {
if(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value && NPerBlock < GemmN) if(dynamic_tunable.loop_over_spec ==
ck::tensor_operation::cpu::device::
ConvolutionForwardBlockLoopOverSpecialization_t::LoopOver_MKN &&
dynamic_tunable.n_per_block < GemmN)
is_valid &= false; is_valid &= false;
} }
else else
...@@ -275,23 +283,23 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN ...@@ -275,23 +283,23 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
return is_valid; return is_valid;
} }
static void Run(const FloatA* __restrict__ p_a_grid, void Run(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid, const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const FloatC0* __restrict__ p_c0_grid, const FloatC0* __restrict__ p_c0_grid,
const FloatC1* __restrict__ p_c1_grid, const FloatC1* __restrict__ p_c1_grid,
const AGridDesc& a_grid_desc, const AGridDesc& a_grid_desc,
const BGridDesc& b_grid_desc, const BGridDesc& b_grid_desc,
const CGridDesc& c_grid_desc, const CGridDesc& c_grid_desc,
const C0GridDesc& c0_grid_desc, const C0GridDesc& c0_grid_desc,
const C1GridDesc& c1_grid_desc, const C1GridDesc& c1_grid_desc,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op) const CElementwiseOperation& c_element_op) const
{ {
ck::index_t m_per_block = MPerBlock; ck::index_t m_per_block = dynamic_tunable.m_per_block;
ck::index_t n_per_block = NPerBlock; ck::index_t n_per_block = dynamic_tunable.n_per_block;
ck::index_t k_per_block = KPerBlock; ck::index_t k_per_block = dynamic_tunable.k_per_block;
const auto GemmM = c_grid_desc.GetLength(I0); const auto GemmM = c_grid_desc.GetLength(I0);
const auto GemmN = c_grid_desc.GetLength(I1); const auto GemmN = c_grid_desc.GetLength(I1);
...@@ -323,7 +331,6 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN ...@@ -323,7 +331,6 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
decltype(GetABlockDescriptor(m_per_block, k_per_block, a_grid_desc)), // ABlockDesc, decltype(GetABlockDescriptor(m_per_block, k_per_block, a_grid_desc)), // ABlockDesc,
decltype(GetBBlockDescriptor(k_per_block, n_per_block, b_grid_desc)), // BBlockDesc, decltype(GetBBlockDescriptor(k_per_block, n_per_block, b_grid_desc)), // BBlockDesc,
decltype(GetCBlockDescriptor(m_per_block, n_per_block, c_grid_desc)), // CBlockDesc, decltype(GetCBlockDescriptor(m_per_block, n_per_block, c_grid_desc)), // CBlockDesc,
KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch, ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
ThreadMNAccessOrder>{}; // ThreadMNAccessOrder // how we acces ThreadMNAccessOrder>{}; // ThreadMNAccessOrder // how we acces
// gemm MN to utilize micro kernel>{}; // gemm MN to utilize micro kernel>{};
...@@ -349,7 +356,10 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN ...@@ -349,7 +356,10 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
// TODO: openmp aware ordering // TODO: openmp aware ordering
// //
if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 1, 2>>::value)
if(dynamic_tunable.loop_over_spec ==
ck::tensor_operation::cpu::device::ConvolutionForwardBlockLoopOverSpecialization_t::
LoopOver_MNK)
{ {
auto a_move_k_step = GetAIndex(0, k_per_block); auto a_move_k_step = GetAIndex(0, k_per_block);
auto b_move_k_step = GetBIndex(k_per_block, 0); auto b_move_k_step = GetBIndex(k_per_block, 0);
...@@ -505,7 +515,9 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN ...@@ -505,7 +515,9 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
} }
} }
} }
else if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value) else if(dynamic_tunable.loop_over_spec ==
ck::tensor_operation::cpu::device::ConvolutionForwardBlockLoopOverSpecialization_t::
LoopOver_MKN)
{ {
auto a_move_k_step = GetAIndex(0, k_per_block); auto a_move_k_step = GetAIndex(0, k_per_block);
auto b_move_k_step = GetBIndex(0, n_per_block); auto b_move_k_step = GetBIndex(0, n_per_block);
......
#include <stdlib.h> #include <stdlib.h>
#include <utility>
#include "convolution_forward_specialization_cpu.hpp" #include "convolution_forward_specialization_cpu.hpp"
#include "config.hpp" #include "config.hpp"
#include "device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk.hpp" #include "device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk.hpp"
...@@ -41,83 +42,38 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver ...@@ -41,83 +42,38 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
// clang-format off // clang-format off
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \ #define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, true , c_local_buf, bias_along_m>, \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
\ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \ \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, true , c_local_buf, bias_along_m> DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
// clang-format on DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN})
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, false, false)>;
// clang-format on
// use this in single thread, but gemm_n is not multiple of 8
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_local_c_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)>;
// clang-format on
// use this in multi thread environment (need local C buffer to avoid cache coherence, although some
// time no local c is better...)
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_mt_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 24, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 32, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 40, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 48, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 48, 48, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 56, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 256, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 256, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 96, 32, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 96, 64, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 120, 32, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 120, 64, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)>;
// clang-format on // clang-format on
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk( void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances) std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_instances{}); instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, false, false)
// clang-format on
));
} }
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_local_c( void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_local_c(
...@@ -125,14 +81,57 @@ void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_local_c( ...@@ -125,14 +81,57 @@ void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_local_c(
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, instances,
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_local_c_instances{}); std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)
// clang-format on
));
} }
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_mt( void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances) std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_mt_instances{}); instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 24, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 32, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 40, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 48, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 48, 48, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 56, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 256, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 256, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 96, 32, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 96, 64, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 120, 32, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 120, 64, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)
// clang-format on
));
} }
} // namespace device_conv2d_fwd_bias_activation_add_avx2_instance } // namespace device_conv2d_fwd_bias_activation_add_avx2_instance
......
#include <stdlib.h> #include <stdlib.h>
#include <utility>
#include "config.hpp" #include "config.hpp"
#include "convolution_forward_specialization_cpu.hpp" #include "convolution_forward_specialization_cpu.hpp"
#include "device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk.hpp" #include "device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk.hpp"
...@@ -41,87 +42,38 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver ...@@ -41,87 +42,38 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
// clang-format off // clang-format off
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \ #define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>, \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , false, c_local_buf, bias_along_m>, \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
\ \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>, \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , false, c_local_buf, bias_along_m> DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN})
// clang-format on
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, false, false)>;
// clang-format on
// use this in single thread, but gemm_n is not multiple of 8
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_local_c_instances =
std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)>;
// clang-format on
// use this in multi thread environment (need local C buffer to avoid cache coherence, although some
// time no local c is better...)
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_mt_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 24, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 32, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 40, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 48, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 48, 48, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 56, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 256, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 256, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 96, 32, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 96, 64, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 120, 32, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 120, 64, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)>;
// clang-format on // clang-format on
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk( void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances) std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_instances{}); instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, false, false)
// clang-format on
));
} }
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_local_c( void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_local_c(
...@@ -129,14 +81,57 @@ void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_local_c( ...@@ -129,14 +81,57 @@ void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_local_c(
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, instances,
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_local_c_instances{}); std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)
// clang-format on
));
} }
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_mt( void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances) std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_mt_instances{}); instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 24, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 32, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 40, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 48, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 48, 48, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 56, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 256, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 256, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 96, 32, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 96, 64, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 120, 32, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 120, 64, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)
// clang-format on
));
} }
} // namespace device_conv2d_fwd_bias_activation_add_avx2_instance } // namespace device_conv2d_fwd_bias_activation_add_avx2_instance
......
#include <stdlib.h> #include <stdlib.h>
#include <utility>
#include "config.hpp" #include "config.hpp"
#include "convolution_forward_specialization_cpu.hpp" #include "convolution_forward_specialization_cpu.hpp"
#include "device_convnd_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk.hpp" #include "device_convnd_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk.hpp"
...@@ -40,86 +41,38 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver ...@@ -40,86 +41,38 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
// clang-format off // clang-format off
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \ #define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>, \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , false, c_local_buf, bias_along_m>, \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
\ \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>, \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , false, c_local_buf, bias_along_m> DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN})
// clang-format on
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, false, false)>;
// clang-format on
// use this in single thread, but gemm_n is not multiple of 8
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_local_c_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)>;
// clang-format on
// use this in multi thread environment (need local C buffer to avoid cache coherence, although some
// time no local c is better...)
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_mt_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 24, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 32, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 40, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 48, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 48, 48, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 56, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 256, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 256, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 96, 32, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 96, 64, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 120, 32, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 120, 64, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)>;
// clang-format on // clang-format on
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk( void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances) std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_instances{}); instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, false, false)
// clang-format on
));
} }
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c( void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c(
...@@ -127,14 +80,57 @@ void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c( ...@@ -127,14 +80,57 @@ void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c(
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, instances,
device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_local_c_instances{}); std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)
// clang-format on
));
} }
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_mt( void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances) std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_mt_instances{}); instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 24, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 32, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 40, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 48, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 48, 48, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 56, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 256, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 256, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 96, 32, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 96, 64, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 120, 32, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 120, 64, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)
// clang-format on
));
} }
} // namespace device_conv2d_fwd_bias_activation_add_avx2_instance } // namespace device_conv2d_fwd_bias_activation_add_avx2_instance
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment