"include/ck/utility/get_id.hpp" did not exist on "d6d9a8e4cee89feef6758f825cfea1588fec16da"
Commit 9cefc261 authored by carlushuang

refactor device instances to use fewer templates, more dynamic tunables

parent 6dfb4e78
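The gist of the refactor, before the per-file hunks: the block tile sizes (MPerBlock/NPerBlock/KPerBlock) and the block loop order (BlockLoopOverSpecialization) stop being template parameters and become fields of a runtime DeviceConvFwdDynamicTunable that each device op forwards to its gridwise GEMM. A minimal self-contained sketch of the pattern, with simplified stand-in types rather than the full CK signatures:

```cpp
#include <cstdint>

// Stand-in for ConvolutionForwardBlockLoopOverSpecialization_t.
enum class LoopOver : std::int32_t { MNK, MKN };

// After this commit: tile sizes and loop order are plain data, so one template
// instantiation can serve many tuning points.
struct DynamicTunable
{
    std::int32_t m_per_block; // previously the MPerBlock template parameter
    std::int32_t n_per_block; // previously NPerBlock
    std::int32_t k_per_block; // previously KPerBlock
    LoopOver loop_over_spec;  // previously BlockLoopOverSpecialization
};

struct GridwiseGemm
{
    DynamicTunable dynamic_tunable;
    explicit GridwiseGemm(const DynamicTunable& t) : dynamic_tunable(t) {}

    // Anything that reads the tunable can no longer be static
    // (cf. Run and CheckValidity losing `static` throughout the diff).
    void Run() const
    {
        const auto k_per_block = dynamic_tunable.k_per_block; // runtime, not constexpr
        (void)k_per_block; // blocked GEMM loops elided
    }
};

int main() { GridwiseGemm{DynamicTunable{256, 128, 64, LoopOver::MNK}}.Run(); }
```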
......@@ -21,7 +21,7 @@
#define TEST_LAYOUT_NHWC_KYXC_NHWK 0
#define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1
#define TEST_LAYOUT_NHWC_YXCK_NHWK 2
#define TEST_LAYOUT TEST_LAYOUT_NHWC_YXCK_NHWK
#define TEST_LAYOUT TEST_LAYOUT_NHWC_KYXCK8_NHWK
using F32 = float;
using F16 = ck::half_t;
......
......@@ -17,7 +17,7 @@
#define TEST_LAYOUT_NHWC_KYXC_NHWK 0
#define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1
#define TEST_LAYOUT_NHWC_YXCK_NHWK 2
#define TEST_LAYOUT TEST_LAYOUT_NHWC_YXCK_NHWK
#define TEST_LAYOUT TEST_LAYOUT_NHWC_KYXCK8_NHWK
using F32 = float;
using F16 = ck::half_t;
......
......@@ -18,8 +18,6 @@ template <typename FloatA,
typename BBlockDesc,
typename CDesc,
ck::index_t KPerBlock,
typename ThreadwiseGemm_Dispatch,
typename ThreadMNAccessOrder // how we access gemm MN to utilize the micro kernel
>
......@@ -83,8 +81,11 @@ struct BlockwiseGemmAvx2_MxN
else
{
// N/8 * K * 8
return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}] *
b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
// return b_block_desc.GetTransforms()[Number<BBlockDesc::GetNumOfTransform() -
// 1>{}].GetUpperLengths()[Number<1>{}] *
// b_block_desc.GetTransforms()[Number<BBlockDesc::GetNumOfTransform() -
// 1>{}].GetUpperLengths()[Number<2>{}];
return b_block_desc.GetLength(Number<1>{}) * b_block_desc.GetLength(Number<2>{});
}
}
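For context on the hunk above: the packed B block is laid out as (N/8, K, 8), and the rewritten code reads the per-slice element count directly from the descriptor's lengths instead of reaching into its first transform. A hedged sketch of the arithmetic, with plain integers standing in for the descriptor:

```cpp
#include <cassert>

// For a packed B block with lengths (N/8, K, 8), the new return value
// GetLength(Number<1>{}) * GetLength(Number<2>{}) is K * 8, i.e. the element
// count of one N/8 slice.
int BBlockSliceSize(int k_len, int vec_len) { return k_len * vec_len; }

int main()
{
    assert(BBlockSliceSize(/*K=*/32, /*vector=*/8) == 256);
    return 0;
}
```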
......
......@@ -3,6 +3,7 @@
#include <iostream>
#include "device_base_cpu.hpp"
#include "convolution_forward_specialization_cpu.hpp"
namespace ck {
namespace tensor_operation {
......@@ -77,6 +78,24 @@ using DeviceConvFwdBiasActivationAddPtr =
WeiElementwiseOperation,
OutElementwiseOperation>>;
struct DeviceConvFwdDynamicTunable
{
ck::index_t m_per_block;
ck::index_t n_per_block;
ck::index_t k_per_block;
// ck::index_t m_per_thread;
// ck::index_t n_per_thread;
// bool use_a_local_buffer;
// bool use_b_local_buffer;
// bool use_c_local_buffer;
// ConvolutionForwardSpecialization_t forward_spec;
// ConvolutionForwardGemmKSpecialization_t gemm_k_spec;
ConvolutionForwardBlockLoopOverSpecialization_t loop_over_spec;
};
} // namespace device
} // namespace cpu
} // namespace tensor_operation
......
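A minimal usage sketch of the new struct (a stand-in mirror, not the CK headers; the tile values are illustrative, not defaults from the source): the tunable is brace-initialized in field order and handed to the device-op constructor, exactly as the instance macros at the end of this diff do.

```cpp
#include <cstdint>

using index_t = std::int32_t;
enum class ConvolutionForwardBlockLoopOverSpecialization_t { LoopOver_MNK, LoopOver_MKN };

struct DeviceConvFwdDynamicTunable
{
    index_t m_per_block;
    index_t n_per_block;
    index_t k_per_block;
    ConvolutionForwardBlockLoopOverSpecialization_t loop_over_spec;
};

int main()
{
    // Brace-initialized in field order, matching the instance macros below.
    DeviceConvFwdDynamicTunable dtune{
        256, 128, 64, ConvolutionForwardBlockLoopOverSpecialization_t::LoopOver_MNK};
    (void)dtune; // handed to a device-op constructor: Op op(dtune);
}
```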
......@@ -30,11 +30,7 @@ template <typename InDataType,
typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
ck::index_t NumDimSpatial,
ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t MPerThread,
ck::index_t NPerThread,
bool UseALocalBuffer,
......@@ -65,17 +61,12 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
static constexpr bool NonTemporalStore = false;
static constexpr auto GetBlockMNKAccessOrder()
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K(
const DeviceConvFwdDynamicTunable& dtune)
: gridwise_gemm(dtune)
{
if constexpr(BlockLoopOverSpecialization == DefaultBlockLoopOver ||
BlockLoopOverSpecialization == LoopOver_MNK)
return ck::Sequence<0, 1, 2>{};
else if constexpr(BlockLoopOverSpecialization == LoopOver_MKN)
return ck::Sequence<0, 2, 1>{};
}
using BlockMNKAccessOrder = decltype(GetBlockMNKAccessOrder());
static constexpr auto GetThreadwiseGemm_Dispatch()
{
if constexpr(MPerThread == 4 && NPerThread == 24)
......@@ -106,45 +97,6 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());
static constexpr auto GetInputBlockDescriptor()
{
if constexpr(UseALocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
}
else
{
return AGridDesc{};
}
}
static constexpr auto GetWeightBlockDescriptor()
{
if constexpr(UseBLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(
math::integer_divide_ceil(NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
KPerBlock,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
}
else
{
return BGridDesc{};
}
}
static constexpr auto GetOutputBlockDescriptor()
{
if constexpr(UseCLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
}
else
{
return CGridDesc{};
}
}
static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
{
ck::index_t gemm_n_padded =
......@@ -576,6 +528,48 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
using BGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
static constexpr auto GetInputBlockDescriptor()
{
if constexpr(UseALocalBuffer)
{
// return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
}
else
{
return AGridDesc{};
}
}
static constexpr auto GetWeightBlockDescriptor()
{
if constexpr(UseBLocalBuffer)
{
// return make_naive_tensor_descriptor_packed(make_tuple(
// math::integer_divide_ceil(NPerBlock,
// ThreadwiseGemm_Dispatch::MatrixBMinVectorSize), KPerBlock,
// ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
return make_naive_tensor_descriptor_packed(make_tuple(0, 0, 0));
}
else
{
return BGridDesc{};
}
}
static constexpr auto GetOutputBlockDescriptor()
{
if constexpr(UseCLocalBuffer)
{
// return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
}
else
{
return CGridDesc{};
}
}
// static constexpr bool UseCLocalBuffer = false;
using AThreadwiseCopy =
......@@ -620,20 +614,18 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
AElementwiseOperation, // AElementwiseOperation,
BElementwiseOperation, // BElementwiseOperation,
CElementwiseOperation, // CElementwiseOperation,
MPerBlock, // MPerBlock,
NPerBlock, // NPerBlock,
KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
AThreadwiseCopy, // AThreadwiseCopy
BThreadwiseCopy, // BThreadwiseCopy
CThreadwiseCopy, // CThreadwiseCopy
BlockMNKAccessOrder, // BlockMNKAccessOrder,
ck::Sequence<0, 1>, // ThreadMNAccessOrder
UseALocalBuffer, // UseALocalBuffer
UseBLocalBuffer, // UseBLocalBuffer
UseCLocalBuffer // UseCLocalBuffer
>;
GridwiseGemm gridwise_gemm;
// Argument
struct Argument : public BaseArgument
{
......@@ -711,11 +703,15 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
{
using Argument = DeviceOp::Argument;
GridwiseGemm gridwise_gemm;
Invoker(const GridwiseGemm& gridwise_gemm_) : gridwise_gemm(gridwise_gemm_) {}
float Run(const Argument& arg,
const StreamConfig& stream_config = StreamConfig{},
int nrepeat = 1)
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
if(!gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
{
throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
}
......@@ -738,6 +734,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
if(nrepeat != 1)
ave_time = launch_and_time_cpu_kernel(kernel,
nrepeat,
gridwise_gemm,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
......@@ -753,6 +750,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
launch_cpu_kernel(kernel,
gridwise_gemm,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
......@@ -780,7 +778,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
return true;
}
static bool IsSupportedArgument(const Argument& arg)
bool IsSupportedArgument(const Argument& arg)
{
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
......@@ -811,7 +809,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
ConvForwardSpecialization !=
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
if(!(arg.Conv_C_ % KPerBlock == 0))
if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0))
return false;
}
......@@ -825,7 +823,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
}
// Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
return gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
......@@ -868,7 +866,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
out_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
auto MakeInvoker() { return Invoker{gridwise_gemm}; }
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_grid,
......@@ -908,7 +906,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
return std::make_unique<Invoker>(Invoker{gridwise_gemm});
}
std::string GetTypeString() const override
......@@ -925,8 +923,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<< "DFwdAvx2_NHWC_KYXC"
<<"_FS"<< static_cast<int>(ConvForwardSpecialization)
<<"_KS"<< static_cast<int>(GemmKSpecialization)
<<"_BS"<< static_cast<int>(BlockLoopOverSpecialization)
<< "_BT" << MPerBlock << "x" << NPerBlock << "x" << KPerBlock
<<"_BS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.loop_over_spec)
<< "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block
<< "_TT" << MPerThread << "x" << NPerThread
<< "_A" << string_local_buffer(UseALocalBuffer)
<< "_B" << string_local_buffer(UseBLocalBuffer)
......
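Summarizing the Invoker rewiring that repeats in every device file in this diff: the device op owns a configured GridwiseGemm, MakeInvoker copies it into the Invoker, and validity checks go through the instance so they can read dynamic_tunable. A self-contained sketch with stand-in types:

```cpp
#include <stdexcept>

struct DynamicTunable { int m_per_block, n_per_block, k_per_block; };

struct GridwiseGemm
{
    DynamicTunable dynamic_tunable;
    // Non-static: validity now depends on the runtime tile sizes.
    bool CheckValidity(int gemm_k) const { return gemm_k % dynamic_tunable.k_per_block == 0; }
};

struct Argument { int gemm_k; };

struct Invoker
{
    GridwiseGemm gridwise_gemm; // copied from the device op
    explicit Invoker(const GridwiseGemm& g) : gridwise_gemm(g) {}

    float Run(const Argument& arg) const
    {
        if(!gridwise_gemm.CheckValidity(arg.gemm_k))
            throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
        return 0.0f; // kernel launch elided; the real Run forwards gridwise_gemm to it
    }
};

struct DeviceOp
{
    GridwiseGemm gridwise_gemm;
    Invoker MakeInvoker() const { return Invoker{gridwise_gemm}; } // was static Invoker{}
};

int main()
{
    DeviceOp op{GridwiseGemm{{256, 128, 64}}};
    op.MakeInvoker().Run(Argument{512}); // 512 % 64 == 0, passes the check
}
```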
......@@ -30,11 +30,8 @@ template <typename InDataType,
typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
// ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
ck::index_t NumDimSpatial,
ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t MPerThread,
ck::index_t NPerThread,
bool UseALocalBuffer,
......@@ -65,17 +62,12 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
static constexpr bool NonTemporalStore = false;
static constexpr auto GetBlockMNKAccessOrder()
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K(
const DeviceConvFwdDynamicTunable& dtune)
: gridwise_gemm(dtune)
{
if constexpr(BlockLoopOverSpecialization == DefaultBlockLoopOver ||
BlockLoopOverSpecialization == LoopOver_MNK)
return ck::Sequence<0, 1, 2>{};
else if constexpr(BlockLoopOverSpecialization == LoopOver_MKN)
return ck::Sequence<0, 2, 1>{};
}
using BlockMNKAccessOrder = decltype(GetBlockMNKAccessOrder());
static constexpr auto GetThreadwiseGemm_Dispatch()
{
if constexpr(MPerThread == 4 && NPerThread == 24)
......@@ -518,7 +510,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
{
if constexpr(UseALocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
// return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
}
else
{
......@@ -530,10 +523,11 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
{
if constexpr(UseBLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(
math::integer_divide_ceil(NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
KPerBlock,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
// return make_naive_tensor_descriptor_packed(make_tuple(
// math::integer_divide_ceil(NPerBlock,
// ThreadwiseGemm_Dispatch::MatrixBMinVectorSize), KPerBlock,
// ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
return make_naive_tensor_descriptor_packed(make_tuple(0, 0, 0));
}
else
{
......@@ -545,7 +539,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
{
if constexpr(UseCLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
// return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
}
else
{
......@@ -597,20 +592,18 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
AElementwiseOperation, // AElementwiseOperation,
BElementwiseOperation, // BElementwiseOperation,
CElementwiseOperation, // CElementwiseOperation,
MPerBlock, // MPerBlock,
NPerBlock, // NPerBlock,
KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
AThreadwiseCopy, // AThreadwiseCopy
BThreadwiseCopy, // BThreadwiseCopy
CThreadwiseCopy, // CThreadwiseCopy
BlockMNKAccessOrder, // BlockMNKAccessOrder,
ck::Sequence<0, 1>, // ThreadMNAccessOrder
UseALocalBuffer, // UseALocalBuffer
UseBLocalBuffer, // UseBLocalBuffer
UseCLocalBuffer // UseCLocalBuffer
>;
GridwiseGemm gridwise_gemm;
// Argument
struct Argument : public BaseArgument
{
......@@ -687,12 +680,15 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
GridwiseGemm gridwise_gemm;
Invoker(const GridwiseGemm& gridwise_gemm_) : gridwise_gemm(gridwise_gemm_) {}
float Run(const Argument& arg,
const StreamConfig& stream_config = StreamConfig{},
int nrepeat = 1)
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
if(!gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
{
throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
}
......@@ -715,6 +711,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
if(nrepeat != 1)
ave_time = launch_and_time_cpu_kernel(kernel,
nrepeat,
gridwise_gemm,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
......@@ -730,6 +727,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
launch_cpu_kernel(kernel,
gridwise_gemm,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
......@@ -757,7 +755,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
return true;
}
static bool IsSupportedArgument(const Argument& arg)
bool IsSupportedArgument(const Argument& arg)
{
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
......@@ -788,7 +786,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
ConvForwardSpecialization !=
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
if(!(arg.Conv_C_ % KPerBlock == 0))
if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0))
return false;
}
......@@ -805,7 +803,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
}
// Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
return gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
......@@ -848,7 +846,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
out_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
auto MakeInvoker() { return Invoker{gridwise_gemm}; }
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_grid,
......@@ -888,7 +886,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
return std::make_unique<Invoker>(Invoker{gridwise_gemm});
}
std::string GetTypeString() const override
......@@ -905,8 +903,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
<< "DFwdAvx2_NHWC_KYXCK8"
<<"_FS"<< static_cast<int>(ConvForwardSpecialization)
<<"_KS"<< static_cast<int>(GemmKSpecialization)
<<"_BS"<< static_cast<int>(BlockLoopOverSpecialization)
<< "_BT" << MPerBlock << "x" << NPerBlock << "x" << KPerBlock
<<"_BS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.loop_over_spec)
<< "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block
<< "_TT" << MPerThread << "x" << NPerThread
<< "_A" << string_local_buffer(UseALocalBuffer)
<< "_B" << string_local_buffer(UseBLocalBuffer)
......
......@@ -29,11 +29,7 @@ template <typename InDataType,
typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
ck::index_t NumDimSpatial,
ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t MPerThread,
ck::index_t NPerThread,
bool UseALocalBuffer,
......@@ -64,17 +60,12 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
static constexpr bool NonTemporalStore = false;
static constexpr auto GetBlockMNKAccessOrder()
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K(
const DeviceConvFwdDynamicTunable& dtune)
: gridwise_gemm(dtune)
{
if constexpr(BlockLoopOverSpecialization == DefaultBlockLoopOver ||
BlockLoopOverSpecialization == LoopOver_MNK)
return ck::Sequence<0, 1, 2>{};
else if constexpr(BlockLoopOverSpecialization == LoopOver_MKN)
return ck::Sequence<0, 2, 1>{};
}
using BlockMNKAccessOrder = decltype(GetBlockMNKAccessOrder());
static constexpr auto GetThreadwiseGemm_Dispatch()
{
if constexpr(MPerThread == 4 && NPerThread == 24)
......@@ -514,7 +505,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
{
if constexpr(UseALocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
// return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
}
else
{
......@@ -526,7 +518,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
{
if constexpr(UseBLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(KPerBlock, NPerBlock));
// return make_naive_tensor_descriptor_packed(make_tuple(KPerBlock, NPerBlock));
return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
}
else
{
......@@ -538,7 +531,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
{
if constexpr(UseCLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
// return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
}
else
{
......@@ -590,20 +584,18 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
AElementwiseOperation, // AElementwiseOperation,
BElementwiseOperation, // BElementwiseOperation,
CElementwiseOperation, // CElementwiseOperation,
MPerBlock, // MPerBlock,
NPerBlock, // NPerBlock,
KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
AThreadwiseCopy, // AThreadwiseCopy
BThreadwiseCopy, // BThreadwiseCopy
CThreadwiseCopy, // CThreadwiseCopy
BlockMNKAccessOrder, // BlockMNKAccessOrder,
ck::Sequence<0, 1>, // ThreadMNAccessOrder
UseALocalBuffer, // UseALocalBuffer
UseBLocalBuffer, // UseBLocalBuffer
UseCLocalBuffer // UseCLocalBuffer
>;
GridwiseGemm gridwise_gemm;
// Argument
struct Argument : public BaseArgument
{
......@@ -680,12 +672,15 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
GridwiseGemm gridwise_gemm;
Invoker(const GridwiseGemm& gridwise_gemm_) : gridwise_gemm(gridwise_gemm_) {}
float Run(const Argument& arg,
const StreamConfig& stream_config = StreamConfig{},
int nrepeat = 1)
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
if(!gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
{
throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
}
......@@ -708,6 +703,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
if(nrepeat != 1)
ave_time = launch_and_time_cpu_kernel(kernel,
nrepeat,
gridwise_gemm,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
......@@ -723,6 +719,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
launch_cpu_kernel(kernel,
gridwise_gemm,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
......@@ -750,7 +747,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
return true;
}
static bool IsSupportedArgument(const Argument& arg)
bool IsSupportedArgument(const Argument& arg)
{
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
......@@ -781,7 +778,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
ConvForwardSpecialization !=
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
if(!(arg.Conv_C_ % KPerBlock == 0))
if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0))
return false;
}
......@@ -801,7 +798,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
}
// Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
return gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
......@@ -844,7 +841,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
out_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
auto MakeInvoker() { return Invoker{gridwise_gemm}; }
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_grid,
......@@ -884,7 +881,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
return std::make_unique<Invoker>(Invoker{gridwise_gemm});
}
std::string GetTypeString() const override
......@@ -901,9 +898,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<< "DFwdAvx2_NHWC_YXCK"
<<"_FS"<< static_cast<int>(ConvForwardSpecialization)
<<"_KS"<< static_cast<int>(GemmKSpecialization)
<<"_BS"<< static_cast<int>(BlockLoopOverSpecialization)
<< "_BT" << MPerBlock << "x" << NPerBlock << "x" << KPerBlock
<< "_TT" << MPerThread << "x" << NPerThread
<<"_BS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.loop_over_spec)
<< "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block
<< "_A" << string_local_buffer(UseALocalBuffer)
<< "_B" << string_local_buffer(UseBLocalBuffer)
<< "_C" << string_local_buffer(UseCLocalBuffer)
......
......@@ -32,11 +32,7 @@ template <typename InDataType,
typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
ck::index_t NumDimSpatial,
ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t MPerThread,
ck::index_t NPerThread,
bool UseALocalBuffer,
......@@ -73,17 +69,12 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
static constexpr bool NonTemporalStore = false;
static constexpr auto GetBlockMNKAccessOrder()
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K(
const DeviceConvFwdDynamicTunable& dtune)
: gridwise_gemm(dtune)
{
if constexpr(BlockLoopOverSpecialization == DefaultBlockLoopOver ||
BlockLoopOverSpecialization == LoopOver_MNK)
return ck::Sequence<0, 1, 2>{};
else if constexpr(BlockLoopOverSpecialization == LoopOver_MKN)
return ck::Sequence<0, 2, 1>{};
}
using BlockMNKAccessOrder = decltype(GetBlockMNKAccessOrder());
static constexpr auto GetThreadwiseGemm_Dispatch()
{
if constexpr(MPerThread == 4 && NPerThread == 24)
......@@ -114,45 +105,6 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());
static constexpr auto GetInputBlockDescriptor()
{
if constexpr(UseALocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
}
else
{
return AGridDesc{};
}
}
static constexpr auto GetWeightBlockDescriptor()
{
if constexpr(UseBLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(
math::integer_divide_ceil(NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
KPerBlock,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
}
else
{
return BGridDesc{};
}
}
static constexpr auto GetOutputBlockDescriptor()
{
if constexpr(UseCLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
}
else
{
return CGridDesc{};
}
}
static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
{
ck::index_t gemm_n_padded =
......@@ -598,6 +550,42 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
using C0GridDesc = remove_cvref_t<decltype(MakeBiasTensorDescriptor(1, 1))>;
using C1GridDesc = CGridDesc;
static constexpr auto GetInputBlockDescriptor()
{
if constexpr(UseALocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
}
else
{
return AGridDesc{};
}
}
static constexpr auto GetWeightBlockDescriptor()
{
if constexpr(UseBLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(0, 0, 0));
}
else
{
return BGridDesc{};
}
}
static constexpr auto GetOutputBlockDescriptor()
{
if constexpr(UseCLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
}
else
{
return CGridDesc{};
}
}
// static constexpr bool UseCLocalBuffer = false;
using AThreadwiseCopy =
......@@ -650,20 +638,18 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
AElementwiseOperation, // AElementwiseOperation,
BElementwiseOperation, // BElementwiseOperation,
CElementwiseOperation, // CElementwiseOperation,
MPerBlock, // MPerBlock,
NPerBlock, // NPerBlock,
KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
AThreadwiseCopy, // AThreadwiseCopy
BThreadwiseCopy, // BThreadwiseCopy
CThreadwiseCopy, // CThreadwiseCopy
BlockMNKAccessOrder, // BlockMNKAccessOrder,
ck::Sequence<0, 1>, // ThreadMNAccessOrder
UseALocalBuffer, // UseALocalBuffer
UseBLocalBuffer, // UseBLocalBuffer
UseCLocalBuffer // UseCLocalBuffer
>;
GridwiseGemm gridwise_gemm;
// Argument
struct Argument : public BaseArgument
{
......@@ -755,11 +741,15 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
{
using Argument = DeviceOp::Argument;
GridwiseGemm gridwise_gemm;
Invoker(const GridwiseGemm& gridwise_gemm_) : gridwise_gemm(gridwise_gemm_) {}
float Run(const Argument& arg,
const StreamConfig& stream_config = StreamConfig{},
int nrepeat = 1)
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
if(!gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
{
throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
}
......@@ -787,6 +777,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
if(nrepeat != 1)
ave_time = launch_and_time_cpu_kernel(kernel,
nrepeat,
gridwise_gemm,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
......@@ -806,6 +797,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
launch_cpu_kernel(kernel,
gridwise_gemm,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
......@@ -837,7 +829,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
return true;
}
static bool IsSupportedArgument(const Argument& arg)
bool IsSupportedArgument(const Argument& arg)
{
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
......@@ -868,7 +860,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
ConvForwardSpecialization !=
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
if(!(arg.Conv_C_ % KPerBlock == 0))
if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0))
return false;
}
......@@ -882,7 +874,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
}
// Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
return gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
......@@ -929,7 +921,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
out_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
auto MakeInvoker() { return Invoker{gridwise_gemm}; }
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_grid,
......@@ -973,7 +965,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
return std::make_unique<Invoker>(Invoker{gridwise_gemm});
}
std::string GetTypeString() const override
......@@ -990,8 +982,8 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
<< "DFwd_BAA_Avx2_NHWC_KYXC"
<<"_FS"<< static_cast<int>(ConvForwardSpecialization)
<<"_KS"<< static_cast<int>(GemmKSpecialization)
<<"_BS"<< static_cast<int>(BlockLoopOverSpecialization)
<< "_BT" << MPerBlock << "x" << NPerBlock << "x" << KPerBlock
<<"_BS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.loop_over_spec)
<< "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block
<< "_TT" << MPerThread << "x" << NPerThread
<< "_A" << string_local_buffer(UseALocalBuffer)
<< "_B" << string_local_buffer(UseBLocalBuffer)
......
......@@ -32,11 +32,7 @@ template <typename InDataType,
typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
ck::index_t NumDimSpatial,
ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t MPerThread,
ck::index_t NPerThread,
bool UseALocalBuffer,
......@@ -73,17 +69,12 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
static constexpr bool NonTemporalStore = false;
static constexpr auto GetBlockMNKAccessOrder()
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K(
const DeviceConvFwdDynamicTunable& dtune)
: gridwise_gemm(dtune)
{
if constexpr(BlockLoopOverSpecialization == DefaultBlockLoopOver ||
BlockLoopOverSpecialization == LoopOver_MNK)
return ck::Sequence<0, 1, 2>{};
else if constexpr(BlockLoopOverSpecialization == LoopOver_MKN)
return ck::Sequence<0, 2, 1>{};
}
using BlockMNKAccessOrder = decltype(GetBlockMNKAccessOrder());
static constexpr auto GetThreadwiseGemm_Dispatch()
{
if constexpr(MPerThread == 4 && NPerThread == 24)
......@@ -114,45 +105,6 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());
static constexpr auto GetInputBlockDescriptor()
{
if constexpr(UseALocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
}
else
{
return AGridDesc{};
}
}
static constexpr auto GetWeightBlockDescriptor()
{
if constexpr(UseBLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(
math::integer_divide_ceil(NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
KPerBlock,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
}
else
{
return BGridDesc{};
}
}
static constexpr auto GetOutputBlockDescriptor()
{
if constexpr(UseCLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
}
else
{
return CGridDesc{};
}
}
static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
{
return make_naive_tensor_descriptor_packed(make_tuple(gemm_n / 8, gemm_k, 8));
......@@ -575,6 +527,42 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
using C0GridDesc = remove_cvref_t<decltype(MakeBiasTensorDescriptor(1, 1))>;
using C1GridDesc = CGridDesc;
static constexpr auto GetInputBlockDescriptor()
{
if constexpr(UseALocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
}
else
{
return AGridDesc{};
}
}
static constexpr auto GetWeightBlockDescriptor()
{
if constexpr(UseBLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(0, 0, 0));
}
else
{
return BGridDesc{};
}
}
static constexpr auto GetOutputBlockDescriptor()
{
if constexpr(UseCLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
}
else
{
return CGridDesc{};
}
}
// static constexpr bool UseCLocalBuffer = false;
using AThreadwiseCopy =
......@@ -627,20 +615,18 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
AElementwiseOperation, // AElementwiseOperation,
BElementwiseOperation, // BElementwiseOperation,
CElementwiseOperation, // CElementwiseOperation,
MPerBlock, // MPerBlock,
NPerBlock, // NPerBlock,
KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
AThreadwiseCopy, // AThreadwiseCopy
BThreadwiseCopy, // BThreadwiseCopy
CThreadwiseCopy, // CThreadwiseCopy
BlockMNKAccessOrder, // BlockMNKAccessOrder,
ck::Sequence<0, 1>, // ThreadMNAccessOrder
UseALocalBuffer, // UseALocalBuffer
UseBLocalBuffer, // UseBLocalBuffer
UseCLocalBuffer // UseCLocalBuffer
>;
GridwiseGemm gridwise_gemm;
// Argument
struct Argument : public BaseArgument
{
......@@ -732,11 +718,15 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
{
using Argument = DeviceOp::Argument;
GridwiseGemm gridwise_gemm;
Invoker(const GridwiseGemm& gridwise_gemm_) : gridwise_gemm(gridwise_gemm_) {}
float Run(const Argument& arg,
const StreamConfig& stream_config = StreamConfig{},
int nrepeat = 1)
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
if(!gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
{
throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
}
......@@ -764,6 +754,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
if(nrepeat != 1)
ave_time = launch_and_time_cpu_kernel(kernel,
nrepeat,
gridwise_gemm,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
......@@ -783,6 +774,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
launch_cpu_kernel(kernel,
gridwise_gemm,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
......@@ -814,7 +806,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
return true;
}
static bool IsSupportedArgument(const Argument& arg)
bool IsSupportedArgument(const Argument& arg)
{
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
......@@ -845,7 +837,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
ConvForwardSpecialization !=
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
if(!(arg.Conv_C_ % KPerBlock == 0))
if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0))
return false;
}
......@@ -862,7 +854,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
}
// Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
return gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
......@@ -909,7 +901,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
out_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
auto MakeInvoker() { return Invoker{gridwise_gemm}; }
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_grid,
......@@ -953,7 +945,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
return std::make_unique<Invoker>(Invoker{gridwise_gemm});
}
std::string GetTypeString() const override
......@@ -970,8 +962,8 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
<< "DFwd_BAA_Avx2_NHWC_KYXCK8"
<<"_FS"<< static_cast<int>(ConvForwardSpecialization)
<<"_KS"<< static_cast<int>(GemmKSpecialization)
<<"_BS"<< static_cast<int>(BlockLoopOverSpecialization)
<< "_BT" << MPerBlock << "x" << NPerBlock << "x" << KPerBlock
<<"_BS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.loop_over_spec)
<< "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block
<< "_TT" << MPerThread << "x" << NPerThread
<< "_A" << string_local_buffer(UseALocalBuffer)
<< "_B" << string_local_buffer(UseBLocalBuffer)
......
......@@ -31,11 +31,7 @@ template <typename InDataType,
typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
ck::index_t NumDimSpatial,
ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t MPerThread,
ck::index_t NPerThread,
bool UseALocalBuffer,
......@@ -72,17 +68,12 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
static constexpr bool NonTemporalStore = false;
static constexpr auto GetBlockMNKAccessOrder()
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K(
const DeviceConvFwdDynamicTunable& dtune)
: gridwise_gemm(dtune)
{
if constexpr(BlockLoopOverSpecialization == DefaultBlockLoopOver ||
BlockLoopOverSpecialization == LoopOver_MNK)
return ck::Sequence<0, 1, 2>{};
else if constexpr(BlockLoopOverSpecialization == LoopOver_MKN)
return ck::Sequence<0, 2, 1>{};
}
using BlockMNKAccessOrder = decltype(GetBlockMNKAccessOrder());
static constexpr auto GetThreadwiseGemm_Dispatch()
{
if constexpr(MPerThread == 4 && NPerThread == 24)
......@@ -111,42 +102,6 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());
static constexpr auto GetInputBlockDescriptor()
{
if constexpr(UseALocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
}
else
{
return AGridDesc{};
}
}
static constexpr auto GetWeightBlockDescriptor()
{
if constexpr(UseBLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(KPerBlock, NPerBlock));
}
else
{
return BGridDesc{};
}
}
static constexpr auto GetOutputBlockDescriptor()
{
if constexpr(UseCLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
}
else
{
return CGridDesc{};
}
}
static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
{
return make_naive_tensor_descriptor_packed(make_tuple(gemm_k, gemm_n));
......@@ -568,6 +523,42 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
using C0GridDesc = remove_cvref_t<decltype(MakeBiasTensorDescriptor(1, 1))>;
using C1GridDesc = CGridDesc;
static constexpr auto GetInputBlockDescriptor()
{
if constexpr(UseALocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
}
else
{
return AGridDesc{};
}
}
static constexpr auto GetWeightBlockDescriptor()
{
if constexpr(UseBLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
}
else
{
return BGridDesc{};
}
}
static constexpr auto GetOutputBlockDescriptor()
{
if constexpr(UseCLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
}
else
{
return CGridDesc{};
}
}
// static constexpr bool UseCLocalBuffer = false;
using AThreadwiseCopy =
......@@ -620,20 +611,18 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
AElementwiseOperation, // AElementwiseOperation,
BElementwiseOperation, // BElementwiseOperation,
CElementwiseOperation, // CElementwiseOperation,
MPerBlock, // MPerBlock,
NPerBlock, // NPerBlock,
KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
AThreadwiseCopy, // AThreadwiseCopy
BThreadwiseCopy, // BThreadwiseCopy
CThreadwiseCopy, // CThreadwiseCopy
BlockMNKAccessOrder, // BlockMNKAccessOrder,
ck::Sequence<0, 1>, // ThreadMNAccessOrder
UseALocalBuffer, // UseALocalBuffer
UseBLocalBuffer, // UseBLocalBuffer
UseCLocalBuffer // UseCLocalBuffer
>;
GridwiseGemm gridwise_gemm;
// Argument
struct Argument : public BaseArgument
{
......@@ -725,11 +714,15 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
{
using Argument = DeviceOp::Argument;
GridwiseGemm gridwise_gemm;
Invoker(const GridwiseGemm& gridwise_gemm_) : gridwise_gemm(gridwise_gemm_) {}
float Run(const Argument& arg,
const StreamConfig& stream_config = StreamConfig{},
int nrepeat = 1)
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
if(!gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
{
throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
}
......@@ -757,6 +750,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
if(nrepeat != 1)
ave_time = launch_and_time_cpu_kernel(kernel,
nrepeat,
gridwise_gemm,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
......@@ -776,6 +770,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
launch_cpu_kernel(kernel,
gridwise_gemm,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
......@@ -807,7 +802,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
return true;
}
static bool IsSupportedArgument(const Argument& arg)
bool IsSupportedArgument(const Argument& arg)
{
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
......@@ -838,7 +833,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
ConvForwardSpecialization !=
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
if(!(arg.Conv_C_ % KPerBlock == 0))
if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0))
return false;
}
......@@ -858,7 +853,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
}
// Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
return gridwise_gemm.CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
......@@ -905,7 +900,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
out_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
auto MakeInvoker() { return Invoker{gridwise_gemm}; }
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_grid,
......@@ -949,7 +944,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
return std::make_unique<Invoker>(Invoker{gridwise_gemm});
}
std::string GetTypeString() const override
......@@ -966,8 +961,8 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
<< "DFwd_BAA_Avx2_NHWC_YXCK"
<<"_FS"<< static_cast<int>(ConvForwardSpecialization)
<<"_KS"<< static_cast<int>(GemmKSpecialization)
<<"_BS"<< static_cast<int>(BlockLoopOverSpecialization)
<< "_BT" << MPerBlock << "x" << NPerBlock << "x" << KPerBlock
<<"_BS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.loop_over_spec)
<< "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block
<< "_TT" << MPerThread << "x" << NPerThread
<< "_A" << string_local_buffer(UseALocalBuffer)
<< "_B" << string_local_buffer(UseBLocalBuffer)
......
......@@ -28,7 +28,8 @@ template <typename GridwiseGemm,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
void kernel_gemm_avx_mxn(const FloatA* __restrict__ p_a_grid,
void kernel_gemm_avx_mxn(const GridwiseGemm& gridwise_gemm,
const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc& a_grid_desc,
......@@ -38,7 +39,7 @@ void kernel_gemm_avx_mxn(const FloatA* __restrict__ p_a_grid,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op)
{
GridwiseGemm::Run(p_a_grid,
gridwise_gemm.Run(p_a_grid,
p_b_grid,
p_c_grid,
a_grid_desc,
......@@ -58,14 +59,10 @@ template <typename FloatA,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
ck::index_t NPerBlock,
ck::index_t KPerBlock,
typename ThreadwiseGemm_Dispatch,
typename AThreadwiseCopy,
typename BThreadwiseCopy,
typename CThreadwiseCopy,
typename BlockMNKAccessOrder, // how we access gemm MNK to better fit in cache
typename ThreadMNAccessOrder, // how we access gemm MN to utilize the micro kernel
bool UseALocalBuffer,
bool UseBLocalBuffer,
......@@ -75,12 +72,19 @@ template <typename FloatA,
>
struct GridwiseGemmAvx2_MxN
{
ck::tensor_operation::cpu::device::DeviceConvFwdDynamicTunable dynamic_tunable;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
// static constexpr auto Avx2RegisterVector = 8; // 8 floats
static constexpr index_t MemAlignmentByte = 32; // 256bit
GridwiseGemmAvx2_MxN(
const ck::tensor_operation::cpu::device::DeviceConvFwdDynamicTunable dynamic_tunable_)
: dynamic_tunable(dynamic_tunable_)
{
}
static auto GetABlockDescriptor(const ck::index_t m_per_blk,
const ck::index_t k_per_blk,
const AGridDesc& a_grid_desc)
......@@ -238,7 +242,7 @@ struct GridwiseGemmAvx2_MxN
return ck::make_multi_index(i_m, i_n);
}
static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
bool CheckValidity(const AGridDesc& a_grid_desc,
const BGridDesc& b_grid_desc,
const CGridDesc& c_grid_desc)
{
......@@ -247,7 +251,12 @@ struct GridwiseGemmAvx2_MxN
const auto GemmN = c_grid_desc.GetLength(I1);
if constexpr(UseCLocalBuffer)
{
if(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value && NPerBlock < GemmN)
// if(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value &&
// dynamic_tunable.gemm_n_per_block < GemmN)
if(dynamic_tunable.loop_over_spec ==
ck::tensor_operation::cpu::device::
ConvolutionForwardBlockLoopOverSpecialization_t::LoopOver_MKN &&
dynamic_tunable.n_per_block < GemmN)
is_valid &= false;
}
else
......@@ -259,7 +268,7 @@ struct GridwiseGemmAvx2_MxN
return is_valid;
}
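A hedged reading of the UseCLocalBuffer branch above: Sequence<0, 2, 1> walks blocks in M, K, N order, so the K (reduction) loop runs outside the N loop, and a local C tile narrower than GemmN would be written out before its K accumulation finished. That is why the runtime check rejects LoopOver_MKN with n_per_block < GemmN. The predicate in isolation, with stand-in types:

```cpp
#include <cassert>

enum class LoopOver { MNK, MKN };
struct Tunable { int m_per_block, n_per_block, k_per_block; LoopOver loop_over_spec; };

// Valid only if the local C buffer spans the whole N extent whenever the
// K loop sits outside the N loop (LoopOver_MKN).
bool CLocalBufferShapeOk(const Tunable& t, int gemm_n)
{
    return !(t.loop_over_spec == LoopOver::MKN && t.n_per_block < gemm_n);
}

int main()
{
    assert(!CLocalBufferShapeOk({256, 128, 64, LoopOver::MKN}, /*GemmN=*/512));
    assert(CLocalBufferShapeOk({256, 512, 64, LoopOver::MKN}, /*GemmN=*/512));
}
```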
static void Run(const FloatA* __restrict__ p_a_grid,
void Run(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc& a_grid_desc,
......@@ -267,11 +276,11 @@ struct GridwiseGemmAvx2_MxN
const CGridDesc& c_grid_desc,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op)
const CElementwiseOperation& c_element_op) const
{
ck::index_t m_per_block = MPerBlock;
ck::index_t n_per_block = NPerBlock;
ck::index_t k_per_block = KPerBlock;
ck::index_t m_per_block = dynamic_tunable.m_per_block;
ck::index_t n_per_block = dynamic_tunable.n_per_block;
ck::index_t k_per_block = dynamic_tunable.k_per_block;
const auto GemmM = c_grid_desc.GetLength(I0);
const auto GemmN = c_grid_desc.GetLength(I1);
......@@ -297,7 +306,6 @@ struct GridwiseGemmAvx2_MxN
decltype(GetABlockDescriptor(m_per_block, k_per_block, a_grid_desc)), // ABlockDesc,
decltype(GetBBlockDescriptor(k_per_block, n_per_block, b_grid_desc)), // BBlockDesc,
decltype(GetCBlockDescriptor(m_per_block, n_per_block, c_grid_desc)), // CBlockDesc,
KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
ThreadMNAccessOrder>{}; // ThreadMNAccessOrder // how we access
// gemm MN to utilize the micro kernel
......@@ -323,7 +331,9 @@ struct GridwiseGemmAvx2_MxN
// TODO: openmp aware ordering
//
if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 1, 2>>::value)
if(dynamic_tunable.loop_over_spec ==
ck::tensor_operation::cpu::device::ConvolutionForwardBlockLoopOverSpecialization_t::
LoopOver_MNK)
{
auto a_move_k_step = GetAIndex(0, k_per_block);
auto b_move_k_step = GetBIndex(k_per_block, 0);
......@@ -467,7 +477,9 @@ struct GridwiseGemmAvx2_MxN
}
}
}
else if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value)
else if(dynamic_tunable.loop_over_spec ==
ck::tensor_operation::cpu::device::ConvolutionForwardBlockLoopOverSpecialization_t::
LoopOver_MKN)
{
auto a_move_k_step = GetAIndex(0, k_per_block);
auto b_move_k_step = GetBIndex(0, n_per_block);
......
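The other recurring gridwise change, visible in the two hunks above: the `if constexpr` dispatch on BlockMNKAccessOrder (Sequence<0, 1, 2> for LoopOver_MNK, Sequence<0, 2, 1> for LoopOver_MKN) becomes a plain runtime branch on dynamic_tunable.loop_over_spec, so both loop nests are always compiled and the order is picked per call. A simplified sketch of the two block orders (tile bookkeeping elided):

```cpp
#include <cstdio>

enum class LoopOver { MNK, MKN };

void RunBlocks(LoopOver order, int m_blocks, int n_blocks, int k_blocks)
{
    if(order == LoopOver::MNK) // was: if constexpr(is_same<Order, Sequence<0, 1, 2>>)
    {
        for(int m = 0; m < m_blocks; ++m)
            for(int n = 0; n < n_blocks; ++n)
                for(int k = 0; k < k_blocks; ++k)
                    std::printf("tile m=%d n=%d k=%d\n", m, n, k);
    }
    else // LoopOver::MKN; was: Sequence<0, 2, 1>
    {
        for(int m = 0; m < m_blocks; ++m)
            for(int k = 0; k < k_blocks; ++k)
                for(int n = 0; n < n_blocks; ++n)
                    std::printf("tile m=%d k=%d n=%d\n", m, k, n);
    }
}

int main() { RunBlocks(LoopOver::MKN, 1, 2, 2); }
```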
......@@ -32,7 +32,8 @@ template <typename GridwiseGemm,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
void kernel_gemm_bias_activation_add_avx_mxn(const FloatA* __restrict__ p_a_grid,
void kernel_gemm_bias_activation_add_avx_mxn(const GridwiseGemm& gridwise_gemm,
const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const FloatC0* __restrict__ p_c0_grid,
......@@ -46,7 +47,7 @@ void kernel_gemm_bias_activation_add_avx_mxn(const FloatA* __restrict__ p_a_grid
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op)
{
GridwiseGemm::Run(p_a_grid,
gridwise_gemm.Run(p_a_grid,
p_b_grid,
p_c_grid,
p_c0_grid,
......@@ -74,14 +75,10 @@ template <typename FloatA,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
ck::index_t NPerBlock,
ck::index_t KPerBlock,
typename ThreadwiseGemm_Dispatch,
typename AThreadwiseCopy,
typename BThreadwiseCopy,
typename CThreadwiseCopy,
typename BlockMNKAccessOrder, // how we access gemm MNK to better fit in cache
typename ThreadMNAccessOrder, // how we access gemm MN to utilize the micro kernel
bool UseALocalBuffer,
bool UseBLocalBuffer,
......@@ -91,12 +88,19 @@ template <typename FloatA,
>
struct GridwiseGemmBiasActivationAddAvx2_MxN
{
ck::tensor_operation::cpu::device::DeviceConvFwdDynamicTunable dynamic_tunable;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
// static constexpr auto Avx2RegisterVector = 8; // 8 floats
static constexpr index_t MemAlignmentByte = 32; // 256bit
GridwiseGemmBiasActivationAddAvx2_MxN(
const ck::tensor_operation::cpu::device::DeviceConvFwdDynamicTunable dynamic_tunable_)
: dynamic_tunable(dynamic_tunable_)
{
}
static auto GetABlockDescriptor(const ck::index_t m_per_blk,
const ck::index_t k_per_blk,
const AGridDesc& a_grid_desc)
......@@ -254,7 +258,7 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
return ck::make_multi_index(i_m, i_n);
}
static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
bool CheckValidity(const AGridDesc& a_grid_desc,
const BGridDesc& b_grid_desc,
const CGridDesc& c_grid_desc)
{
......@@ -263,7 +267,11 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
const auto GemmN = c_grid_desc.GetLength(I1);
if constexpr(UseCLocalBuffer)
{
if(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value && NPerBlock < GemmN)
if(dynamic_tunable.loop_over_spec ==
ck::tensor_operation::cpu::device::
ConvolutionForwardBlockLoopOverSpecialization_t::LoopOver_MKN &&
dynamic_tunable.n_per_block < GemmN)
is_valid &= false;
}
else
......@@ -275,7 +283,7 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
return is_valid;
}
static void Run(const FloatA* __restrict__ p_a_grid,
void Run(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const FloatC0* __restrict__ p_c0_grid,
......@@ -287,11 +295,11 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
const C1GridDesc& c1_grid_desc,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op)
const CElementwiseOperation& c_element_op) const
{
ck::index_t m_per_block = MPerBlock;
ck::index_t n_per_block = NPerBlock;
ck::index_t k_per_block = KPerBlock;
ck::index_t m_per_block = dynamic_tunable.m_per_block;
ck::index_t n_per_block = dynamic_tunable.n_per_block;
ck::index_t k_per_block = dynamic_tunable.k_per_block;
const auto GemmM = c_grid_desc.GetLength(I0);
const auto GemmN = c_grid_desc.GetLength(I1);
......@@ -323,7 +331,6 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
decltype(GetABlockDescriptor(m_per_block, k_per_block, a_grid_desc)), // ABlockDesc,
decltype(GetBBlockDescriptor(k_per_block, n_per_block, b_grid_desc)), // BBlockDesc,
decltype(GetCBlockDescriptor(m_per_block, n_per_block, c_grid_desc)), // CBlockDesc,
KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
ThreadMNAccessOrder>{}; // ThreadMNAccessOrder // how we access
// gemm MN to utilize the micro kernel
......@@ -349,7 +356,10 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
// TODO: openmp aware ordering
//
if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 1, 2>>::value)
if(dynamic_tunable.loop_over_spec ==
ck::tensor_operation::cpu::device::ConvolutionForwardBlockLoopOverSpecialization_t::
LoopOver_MNK)
{
auto a_move_k_step = GetAIndex(0, k_per_block);
auto b_move_k_step = GetBIndex(k_per_block, 0);
......@@ -505,7 +515,9 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
}
}
}
else if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value)
else if(dynamic_tunable.loop_over_spec ==
ck::tensor_operation::cpu::device::ConvolutionForwardBlockLoopOverSpecialization_t::
LoopOver_MKN)
{
auto a_move_k_step = GetAIndex(0, k_per_block);
auto b_move_k_step = GetBIndex(0, n_per_block);
......
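The kernel wrappers at the top of both gridwise files follow the same pattern: they used to invoke the static GridwiseGemm::Run and now receive the configured object as their first parameter, which launch_cpu_kernel / launch_and_time_cpu_kernel forward through. A stand-in sketch (the real wrapper also carries descriptors and element-wise ops):

```cpp
struct GridwiseGemm
{
    int k_per_block;
    void Run(const float* a, const float* b, float* c) const { (void)a; (void)b; (void)c; }
};

template <typename GridwiseGemmT>
void kernel_gemm_avx_mxn(const GridwiseGemmT& gridwise_gemm, // new first parameter
                         const float* p_a_grid,
                         const float* p_b_grid,
                         float* p_c_grid)
{
    gridwise_gemm.Run(p_a_grid, p_b_grid, p_c_grid); // was GridwiseGemmT::Run(...)
}

int main()
{
    GridwiseGemm gg{64};
    float a[1] = {0.f}, b[1] = {0.f}, c[1] = {0.f};
    kernel_gemm_avx_mxn(gg, a, b, c);
}
```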
#include <stdlib.h>
#include <utility>
#include "convolution_forward_specialization_cpu.hpp"
#include "config.hpp"
#include "device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp"
......@@ -48,18 +49,24 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
// clang-format off
#define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf) \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, true, c_local_buf>, \
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, true, c_local_buf>
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN})
// clang-format on
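Note how each entry the macro now expands to is a *constructor call*: the block sizes and loop order travel in a brace-initialized tunable instead of template arguments, so one instantiated type covers every block-size configuration. A minimal sketch of that shift, with hypothetical simplified names:

```cpp
// hypothetical stand-in for DeviceConvFwdDynamicTunable
struct TunableSketch
{
    int m_per_block;
    int n_per_block;
    int k_per_block;
    int loop_over_spec;
};

// only the per-thread tile sizes stay as template parameters
template <int MPerThread, int NPerThread>
struct DeviceConvSketch
{
    TunableSketch tunable;
    explicit DeviceConvSketch(const TunableSketch& t) : tunable(t) {}
};

// before: a distinct type per block size, e.g. DeviceConvSketch<256, 128, 64, 6, 16>
// after:  one type, configured at construction time
auto instance_a = DeviceConvSketch<6, 16>({256, 128, 64, /*LoopOver_MNK*/ 0});
auto instance_b = DeviceConvSketch<6, 16>({256, 128, 128, /*LoopOver_MNK*/ 0});
```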
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances = std::tuple<
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, false),
......@@ -70,11 +77,17 @@ using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances = std::tuple<
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, false)>;
// clang-format on
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, false)
// clang-format on
));
}
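The factory above no longer declares a `std::tuple<...>` of instance *types* to be default-constructed later; it builds the instance *objects* directly and hands them to `add_device_operation_instances` via `std::make_tuple`. A minimal sketch of that registration flow, under hypothetical simplified types:

```cpp
#include <memory>
#include <tuple>
#include <vector>

struct OpSketch { int m_per_block, n_per_block, k_per_block; };
using OpPtr = std::unique_ptr<OpSketch>;

// hypothetical stand-in for ck::tensor_operation::device::add_device_operation_instances
template <typename... Ops>
void add_operation_instances(std::vector<OpPtr>& out, std::tuple<Ops...> ops)
{
    std::apply([&](const auto&... op) { (out.push_back(std::make_unique<OpSketch>(op)), ...); },
               ops);
}

void add_sketch_instances(std::vector<OpPtr>& instances)
{
    add_operation_instances(instances,
                            std::make_tuple(OpSketch{256, 128, 64},    // each entry already
                                            OpSketch{256, 128, 128})); // carries its tunables
}
```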
// use this in a single-threaded environment when gemm_n is not a multiple of 8
// (see the tail-handling sketch after this function)
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_instances = std::tuple<
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c(
std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true),
......@@ -85,12 +98,17 @@ using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_instances = std::tuple<
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true)>;
// clang-format on
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true)
// clang-format on
));
}
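The `gemm_n` remark refers to the AVX2 register width: a ymm register holds 8 floats, so when the N extent of the C tile is not a multiple of 8, a padded local buffer lets the micro-kernel keep issuing full-width stores. A minimal sketch of that tail handling (hypothetical helper, not the actual CK micro-kernel):

```cpp
#include <immintrin.h>

// store one row of accumulators into a local C tile whose width is padded
// up to a multiple of 8, so the tail column never needs a masked store
void store_row_padded(float* c_local_padded, const __m256* acc, int gemm_n)
{
    const int n_vectors = (gemm_n + 7) / 8; // tail rounds up to one extra vector
    for(int j = 0; j < n_vectors; ++j)
        _mm256_storeu_ps(c_local_padded + 8 * j, acc[j]); // safe: buffer is padded
    // the caller later copies only the first gemm_n floats per row into the
    // real (unpadded) C tensor
}
```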
// use this in a multi-threaded environment (a local C buffer is needed to avoid
// cache-coherence traffic, although sometimes no local C buffer is faster...;
// see the false-sharing sketch after this function)
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_instances = std::tuple<
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 24, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 32, 24, 256, 4, 24, false),
......@@ -117,10 +135,17 @@ using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_instances = std::tuple<
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true)>;
// clang-format on
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true)
// clang-format on
));
}
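The cache-coherence note above is about false sharing: with several OpenMP threads writing interleaved tiles of one shared C tensor, stores from different cores can land on the same cache line. Accumulating into a thread-private tile and writing back once keeps the line ping-pong out of the inner loop. A minimal sketch of the idea (hypothetical code, not the actual kernel):

```cpp
#include <vector>

// each thread owns tile `t` of C; `tile_elems` floats per tile
void accumulate_tiles_mt(float* c, int num_tiles, int tile_elems)
{
#pragma omp parallel for
    for(int t = 0; t < num_tiles; ++t)
    {
        std::vector<float> c_local(tile_elems, 0.0f); // thread-private accumulator
        for(int i = 0; i < tile_elems; ++i)
            c_local[i] += 1.0f; // stands in for the K-loop micro-kernel work
        float* c_tile = c + t * tile_elems;
        for(int i = 0; i < tile_elems; ++i)
            c_tile[i] = c_local[i]; // single write-back touches shared lines once
    }
}
```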
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_relu_instances = std::tuple<
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, false),
......@@ -131,11 +156,17 @@ using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_relu_instances = std::tuple<
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, false)>;
// clang-format on
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, false)
// clang-format on
));
}
// use this in a single-threaded environment when gemm_n is not a multiple of 8
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_relu_instances = std::tuple<
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true),
......@@ -146,12 +177,17 @@ using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_relu_instances = std::tu
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)>;
// clang-format on
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)
// clang-format on
));
}
// use this in a multi-threaded environment (a local C buffer is needed to avoid
// cache-coherence traffic, although sometimes no local C buffer is faster...)
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_relu_instances = std::tuple<
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 24, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 32, 24, 256, 4, 24, false),
......@@ -178,48 +214,9 @@ using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_relu_instances = std::tuple<
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)>;
// clang-format on
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c(
std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_relu_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_relu_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_relu_instances{});
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)
// clang-format on
));
}
} // namespace device_conv2d_fwd_avx2_instance
......
#include <stdlib.h>
#include <utility>
#include "config.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "device_convnd_fwd_avx2_nhwc_kyxck8_nhwk.hpp"
......@@ -41,20 +42,25 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
// clang-format off
#define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf) \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, false, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, false, c_local_buf>
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN})
// clang-format on
using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_instances = std::tuple<
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk(
std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, false),
......@@ -65,11 +71,17 @@ using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_instances = std::tuple<
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, false)>;
// clang-format on
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, false)
// clang-format on
));
}
// use this in a single-threaded environment when gemm_n is not a multiple of 8
using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_instances = std::tuple<
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c(
std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true),
......@@ -80,12 +92,17 @@ using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_instances = std::tuple
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true)>;
// clang-format on
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true)
// clang-format on
));
}
// use this in a multi-threaded environment (a local C buffer is needed to avoid
// cache-coherence traffic, although sometimes no local C buffer is faster...)
using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_instances = std::tuple<
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt(
std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 24, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 32, 24, 256, 4, 24, false),
......@@ -104,7 +121,6 @@ using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_instances = std::tuple<
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 120, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 120, 64, 128, 6, 16, false),
// DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true),
......@@ -113,10 +129,17 @@ using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_instances = std::tuple<
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true)>;
// clang-format on
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true)
// clang-format on
));
}
using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_relu_instances = std::tuple<
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, false),
......@@ -127,11 +150,17 @@ using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_relu_instances = std::tuple<
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, false)>;
// clang-format on
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, false)
// clang-format on
));
}
// use this in a single-threaded environment when gemm_n is not a multiple of 8
using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_relu_instances = std::tuple<
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true),
......@@ -142,12 +171,17 @@ using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_relu_instances = std::
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)>;
// clang-format on
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)
// clang-format on
));
}
// use this in a multi-threaded environment (a local C buffer is needed to avoid
// cache-coherence traffic, although sometimes no local C buffer is faster...)
using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_relu_instances = std::tuple<
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 24, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 32, 24, 256, 4, 24, false),
......@@ -170,7 +204,6 @@ using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_relu_instances = std::tuple
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 120, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 120, 64, 128, 6, 16, false),
// DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true),
......@@ -179,49 +212,9 @@ using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_relu_instances = std::tuple
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)>;
// clang-format on
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk(
std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c(
std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt(
std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_relu_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_relu_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_relu_instances{});
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)
// clang-format on
));
}
} // namespace device_conv2d_fwd_avx2_instance
......
#include <stdlib.h>
#include <utility>
#include "config.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "device_convnd_fwd_avx2_nhwc_yxck_nhwk.hpp"
......@@ -40,20 +41,24 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
// clang-format off
#define DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf) \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, false, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, false, c_local_buf>
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN})
// clang-format on
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_instances = std::tuple<
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk(std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, false),
......@@ -64,11 +69,17 @@ using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_instances = std::tuple<
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, false)>;
// clang-format on
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, false)
// clang-format on
));
}
// use this in a single-threaded environment when gemm_n is not a multiple of 8
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_instances = std::tuple<
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c(
std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true),
......@@ -79,12 +90,17 @@ using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_instances = std::tuple<
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true)>;
// clang-format on
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true)
// clang-format on
));
}
// use this in a multi-threaded environment (a local C buffer is needed to avoid
// cache-coherence traffic, although sometimes no local C buffer is faster...)
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_instances = std::tuple<
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt(
std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 24, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 32, 24, 256, 4, 24, false),
......@@ -112,10 +128,17 @@ using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_instances = std::tuple<
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true)>;
// clang-format on
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true)
// clang-format on
));
}
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_relu_instances = std::tuple<
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, false),
......@@ -126,11 +149,17 @@ using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_relu_instances = std::tuple<
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, false)>;
// clang-format on
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, false)
// clang-format on
));
}
// use this in a single-threaded environment when gemm_n is not a multiple of 8
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_relu_instances = std::tuple<
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true),
......@@ -141,12 +170,17 @@ using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_relu_instances = std::tu
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)>;
// clang-format on
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)
// clang-format on
));
}
// use this in a multi-threaded environment (a local C buffer is needed to avoid
// cache-coherence traffic, although sometimes no local C buffer is faster...)
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_relu_instances = std::tuple<
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 24, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 32, 24, 256, 4, 24, false),
......@@ -178,48 +212,9 @@ using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_relu_instances = std::tuple<
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)>;
// clang-format on
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk(std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c(
std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt(
std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_relu_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_relu_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_relu_instances{});
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)
// clang-format on
));
}
} // namespace device_conv2d_fwd_avx2_instance
......
#include <stdlib.h>
#include <utility>
#include "convolution_forward_specialization_cpu.hpp"
#include "config.hpp"
#include "device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk.hpp"
......@@ -41,18 +42,25 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
// clang-format off
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, true , c_local_buf, bias_along_m>, \
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, true , c_local_buf, bias_along_m>
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN})
// clang-format on
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_instances = std::tuple<
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, false, false),
......@@ -63,11 +71,17 @@ using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_instances =
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, false, false)>;
// clang-format on
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, false, false)
// clang-format on
));
}
// use this in a single-threaded environment when gemm_n is not a multiple of 8
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_local_c_instances = std::tuple<
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false),
......@@ -78,12 +92,17 @@ using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_local_c_inst
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)>;
// clang-format on
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)
// clang-format on
));
}
// use this in a multi-threaded environment (a local C buffer is needed to avoid
// cache-coherence traffic, although sometimes no local C buffer is faster...)
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_mt_instances = std::tuple<
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 24, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 32, 24, 256, 4, 24, false, false),
......@@ -110,29 +129,9 @@ using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_mt_instances
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)>;
// clang-format on
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_instances{});
}
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_local_c_instances{});
}
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_mt_instances{});
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)
// clang-format on
));
}
} // namespace device_conv2d_fwd_bias_activation_add_avx2_instance
......
#include <stdlib.h>
#include <utility>
#include "config.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk.hpp"
......@@ -41,21 +42,25 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
// clang-format off
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , false, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , false, c_local_buf, bias_along_m>
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN})
// clang-format on
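The macro above is the heart of this refactor: the tile sizes (m/n/k per block) and the block loop order leave the template parameter list and become runtime values, brace-initialized as {m_per_block, n_per_block, k_per_block, LoopOver_MNK} and handed to each instance's constructor. A minimal sketch of that pattern, assuming the tunable struct is visible through this file's device headers and eliding the long template argument list:

// Brace-initialized exactly as in the macro above:
// {m_per_block, n_per_block, k_per_block, loop-over specialization}
ck::tensor_operation::cpu::device::DeviceConvFwdDynamicTunable dtune{
    256, 128, 64, LoopOver_MNK};

// One template instantiation now serves many tile shapes, since the
// constructor forwards the tunable to its gridwise GEMM at runtime:
// DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<
//     float, float, float, float, float, PT, PT, AddReluAdd, ...> op(dtune);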
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_instances = std::tuple<
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, false, false),
......@@ -66,12 +71,17 @@ using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_instances
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, false, false)>;
// clang-format on
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, false, false)
// clang-format on
));
}
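For orientation, a hedged usage sketch (illustrative, not part of this diff) of populating the instance vector through the factory above; conv_instances is a hypothetical local name:

std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>> conv_instances;
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk(conv_instances);
// conv_instances now holds one constructed op per tile/specialization
// combination registered in the make_tuple above.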
// use these instances in single-threaded runs where gemm_n is not a multiple of 8
// (the local C buffer handles the N tail; see the selection sketch after the mt factory below)
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_local_c_instances =
std::tuple<
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false),
......@@ -82,12 +92,17 @@ using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_local_c_in
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)>;
// clang-format on
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)
// clang-format on
));
}
// use these instances in a multi-threaded environment (a local C buffer is needed to avoid
// cache-coherence traffic between threads, although sometimes no local C is faster...)
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_mt_instances = std::tuple<
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 24, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 32, 24, 256, 4, 24, false, false),
......@@ -114,29 +129,9 @@ using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_mt_instanc
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)>;
// clang-format on
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_instances{});
}
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_local_c_instances{});
}
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_mt_instances{});
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)
// clang-format on
));
}
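Tying the two comments above together, a hedged selection sketch; K (the output channel count, which is gemm_n in this NHWC layout), num_threads, and conv_instances are hypothetical names, and 8 is the AVX2 f32 vector width assumed by the non-local-C path:

// Pick an instance set: mt for multi-threaded runs, local-C for a
// single-threaded run whose gemm_n is not a multiple of 8.
if(num_threads > 1)
    add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_mt(conv_instances);
else if(K % 8 != 0)
    add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_local_c(conv_instances);
else
    add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk(conv_instances);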
} // namespace device_conv2d_fwd_bias_activation_add_avx2_instance
......
#include <stdlib.h>
#include <utility>
#include "config.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "device_convnd_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk.hpp"
......@@ -40,21 +41,25 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
// clang-format off
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , false, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , false, c_local_buf, bias_along_m>
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN})
// clang-format on
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_instances = std::tuple<
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, false, false),
......@@ -65,11 +70,17 @@ using device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_instances =
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, false, false)>;
// clang-format on
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, false, false)
// clang-format on
));
}
// use these instances in single-threaded runs where gemm_n is not a multiple of 8
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_local_c_instances = std::tuple<
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false),
......@@ -80,12 +91,17 @@ using device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_local_c_inst
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)>;
// clang-format on
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)
// clang-format on
));
}
// use these instances in a multi-threaded environment (a local C buffer is needed to avoid
// cache-coherence traffic between threads, although sometimes no local C is faster...)
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_mt_instances = std::tuple<
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 24, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 32, 24, 256, 4, 24, false, false),
......@@ -112,29 +128,9 @@ using device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_mt_instances
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)>;
// clang-format on
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_instances{});
}
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_local_c_instances{});
}
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_mt_instances{});
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)
// clang-format on
));
}
} // namespace device_conv2d_fwd_bias_activation_add_avx2_instance
......