Commit afc7d431 authored by carlushuang

avx2 gemm now works for a single thread

parent 07af8343
......@@ -13,21 +13,10 @@ namespace cpu {
template <typename FloatA,
typename FloatB,
typename FloatC,
typename AccDataType,
typename ABlockDesc,
typename BBlockDesc,
typename CBlockDesc,
typename ABlockSliceLengths,
typename BBlockSliceLengths,
typename CBlockSliceLengths,
typename AThreadSliceLength,
typename BThreadSliceLength,
ck::index_t AThreadLoopOverDim, // dimension along which the thread slice loops over the block slice; 1D is enough for now
ck::index_t BThreadLoopOverDim,
typename CDesc,
ck::index_t KPerBlock,
......@@ -47,24 +36,14 @@ struct BlockwiseGemmAvx2_MxN
static constexpr index_t nDimA = ABlockDesc::GetNumOfDimension();
static constexpr index_t nDimB = BBlockDesc::GetNumOfDimension();
static constexpr index_t nDimC = CBlockDesc::GetNumOfDimension();
static constexpr index_t nDimC = CDesc::GetNumOfDimension();
using IndexA = MultiIndex<nDimA>;
using IndexB = MultiIndex<nDimB>;
using IndexC = MultiIndex<nDimC>;
using ACoord = decltype(make_tensor_coordinate(ABlockDesc{}, IndexA{}));
using BCoord = decltype(make_tensor_coordinate(BBlockDesc{}, IndexB{}));
using CCoord = decltype(make_tensor_coordinate(CBlockDesc{}, IndexC{}));
#if 0
constexpr BlockwiseGemmAvx2_MxN(const ABlockDesc & a_block_desc, const IndexA& a_thread_origin,
const BBlockDesc & b_block_desc, const IndexB& b_thread_origin)
: a_thread_coord_(make_tensor_coordinate(a_block_desc, a_thread_origin)),
b_thread_coord_(make_tensor_coordinate(b_block_desc, b_thread_origin)),
{
}
#endif
using CCoord = decltype(make_tensor_coordinate(CDesc{}, IndexC{}));
template <typename TensorDesc>
constexpr auto GetLeadingElement(const TensorDesc& desc)
......@@ -84,79 +63,175 @@ struct BlockwiseGemmAvx2_MxN
}
}
template <typename ABlockBuffer, typename BBlockBuffer, typename CBlockBuffer>
ck::index_t GetALeadingElement(const ABlockDesc& a_block_desc) const
{
return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
ck::index_t GetBLeadingElement(const BBlockDesc& b_block_desc) const
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// K * N
return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
else
{
// N/8 * K * 8
return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}] *
b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
}
}
ck::index_t GetCLeadingElement(const CDesc& c_desc) const
{
return c_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
ck::index_t GetMPerBlock(const ABlockDesc& a_block_desc) const
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// M * K
return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
}
else
{
// K * M
return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
}
ck::index_t GetKPerBlock(const ABlockDesc& a_block_desc) const
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// M * K
return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
else
{
// K * M
return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
}
}
ck::index_t GetNPerBlock(const BBlockDesc& b_block_desc) const
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// K * N
return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
else
{
// N/8 * K * 8
return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}] *
b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
}
}
ck::index_t
GetABlockStartOffset(const ABlockDesc& a_block_desc, const index_t i_m, const index_t) const
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
return i_m * a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
else
{
return i_m;
}
}
ck::index_t
GetBBlockStartOffset(const BBlockDesc& b_block_desc, const index_t, const index_t i_n) const
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// K * N
return i_n;
}
else
{
// N/8 * K * 8
return i_n * b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
}
ck::index_t
GetCBlockStartOffset(const CDesc& c_desc, const index_t i_m, const index_t i_n) const
{
return i_m * c_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}] + i_n;
}
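// Illustrative offset arithmetic (example values assumed, not from this commit): for the
// packed ColumnMajor B layout N/8 x K x 8, GetBBlockStartOffset(b_desc, 0, i_n) returns
// i_n * K_extent of the block (KPerBlock for the blocks built here); since i_n advances in
// 8-wide vector steps, e.g. i_n = 16 with KPerBlock = 64 lands at panel 2, element offset
// 16 * 64 = 1024. For row-major C, GetCBlockStartOffset(c_desc, i_m, i_n) is simply
// i_m * leading_elements + i_n.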
template <typename ABlockBuffer, typename BBlockBuffer, typename CBuffer>
void Run(const ABlockDesc& a_block_desc,
const ABlockBuffer& a_block_buf,
const IndexA& a_origin,
const IndexA& /* a_origin */,
const BBlockDesc& b_block_desc,
const BBlockBuffer& b_block_buf,
const IndexB& b_origin,
const CBlockDesc& c_block_desc,
CBlockBuffer& c_block_buf,
const IndexC& c_origin) const
{
constexpr auto m_n_block_length =
ck::Sequence<ABlockSliceLengths::At(AThreadLoopOverDim),
BBlockSliceLengths::At(BThreadLoopOverDim)>{};
constexpr auto m_n_thread_length =
ck::Sequence<AThreadSliceLength::At(AThreadLoopOverDim),
BThreadSliceLength::At(BThreadLoopOverDim)>{};
const IndexB& /* b_origin */,
constexpr auto m_n_access_length = m_n_block_length / m_n_thread_length;
const CDesc& c_desc,
CBuffer& c_buf,
const IndexC& /* c_origin */,
constexpr auto ordered_m_n_access_length =
container_reorder_given_new2old(m_n_access_length, ThreadMNAccessOrder{});
bool is_accumulate_c = true) const
{
auto lda = GetALeadingElement(a_block_desc) * sizeof(FloatA);
auto ldb = GetBLeadingElement(b_block_desc) * sizeof(FloatB);
auto ldc = GetCLeadingElement(c_desc) * sizeof(FloatC);
constexpr auto a_block_idx_zeros =
typename uniform_sequence_gen<nDimA, 0>::type{}; // starting point of the block
constexpr auto b_block_idx_zeros = typename uniform_sequence_gen<nDimB, 0>::type{};
// printf("lda:%d, ldb:%d, ldc:%d\n", lda, ldb, ldc);
constexpr auto lda = GetLeadingElement(a_block_desc) * sizeof(FloatA);
constexpr auto ldb = GetLeadingElement(b_block_desc) * sizeof(FloatB);
constexpr auto ldc = GetLeadingElement(c_block_desc) * sizeof(FloatC);
const auto k_per_block = GetKPerBlock(a_block_desc);
const auto m_per_block = GetMPerBlock(a_block_desc);
const auto n_per_block = GetNPerBlock(b_block_desc);
const auto m_per_thread = ThreadwiseGemm_Dispatch::ThreadMaxMr;
const auto n_per_thread = ThreadwiseGemm_Dispatch::ThreadMaxNr;
ck::cpu::ThreadwiseGemmParam param;
param.Kr = KPerBlock;
param.Kr = k_per_block;
param.lda = lda;
param.ldb = ldb;
param.ldc = ldc;
param.alpha = 1.0f; // TODO
param.accmulate_c = is_accumulate_c ? 1 : 0;
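// Presumably the gridwise caller passes is_accumulate_c = false on the first KPerBlock chunk
// (the micro-kernel then zero-initializes the C tile before storing) and true on later chunks
// (load C and accumulate); the flag reaches the micro-kernel through param.accmulate_c.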
static_ford<decltype(ordered_m_n_access_length)>{}([&](auto ordered_idx) {
constexpr auto origin_m_n_idx = ordered_idx.ReorderGivenOld2New(ThreadMNAccessOrder{});
constexpr auto current_m_idx =
origin_m_n_idx.At(0) * AThreadSliceLength::At(AThreadLoopOverDim);
constexpr auto current_n_idx =
origin_m_n_idx.At(1) * BThreadSliceLength::At(BThreadLoopOverDim);
constexpr auto current_mr =
ck::math::min(m_n_block_length.At(0) - current_m_idx, m_n_thread_length.At(0));
constexpr auto current_nr =
ck::math::min(m_n_block_length.At(1) - current_n_idx, m_n_thread_length.At(1));
if constexpr(std::is_same<ThreadMNAccessOrder, ck::Sequence<0, 1>>::value)
{
for(ck::index_t i_m = 0; i_m < m_per_block; i_m += m_per_thread)
{
auto current_mr = ck::math::min(m_per_block - i_m, m_per_thread);
param.p_a = &a_block_buf.p_data_[GetABlockStartOffset(a_block_desc, i_m, 0)];
constexpr auto a_block_idx =
a_block_idx_zeros.Modify(AThreadLoopOverDim, current_m_idx);
constexpr auto a_block_coord =
make_tensor_coordinate(a_block_desc, to_multi_index(a_origin + a_block_idx));
// printf("YYYY: %d, i_m:%d, current_mr:%d, %d, %p\n",__LINE__, i_m, current_mr,
// GetABlockStartOffset(a_block_desc, i_m, 0), param.p_a);fflush(stdout);
constexpr auto b_block_idx =
b_block_idx_zeros.Modify(BThreadLoopOverDim, current_n_idx);
constexpr auto b_block_coord =
make_tensor_coordinate(b_block_desc, to_multi_index(b_origin + b_block_idx));
for(ck::index_t i_n = 0; i_n < n_per_block; i_n += n_per_thread)
{
auto current_nr = ck::math::min(n_per_block - i_n, n_per_thread);
constexpr auto c_block_coord =
make_tensor_coordinate(c_block_desc, to_multi_index(c_origin + origin_m_n_idx));
param.p_b = &b_block_buf.p_data_[GetBBlockStartOffset(b_block_desc, 0, i_n)];
param.p_c = &c_buf.p_data_[GetCBlockStartOffset(c_desc, i_m, i_n)];
param.p_a = &a_block_buf.p_data_[a_block_coord.GetOffset()];
param.p_b = &b_block_buf.p_data_[b_block_coord.GetOffset()];
param.p_c = &c_block_buf.p_data_[c_block_coord.GetOffset()];
// printf("YYYY: %d, i_n:%d, current_nr:%d, %d, %p, C:%d, %p\n",__LINE__, i_n,
// current_nr, GetBBlockStartOffset(b_block_desc, 0, i_n), param.p_b,
// GetCBlockStartOffset(c_desc, i_m, i_n),
// param.p_c);fflush(stdout);
ThreadwiseGemm_Dispatch::Run(&param, current_mr, current_nr);
});
}
}
}
}
};
......
......@@ -14,6 +14,19 @@ enum ConvolutionForwardSpecialization_t
OddC,
};
enum ConvolutionForwardGemmKSpecialization_t
{
DefaultGemmKLoop,
NHWC_GemmKLoopOverC, // do not merge c*y*x; requires c % k_per_block == 0
};
enum ConvolutionForwardBlockLoopOverSpecialization_t
{
DefaultBlockLoopOver,
LoopOver_MNK,
LoopOver_MKN,
};
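// Usage note (illustrative): NHWC_GemmKLoopOverC loops the GEMM K dimension over C alone
// (without merging c*y*x), so IsSupportedArgument() in the device op rejects cases where
// Conv_C_ % KPerBlock != 0; DefaultGemmKLoop keeps the merged c*y*x reduction.
// LoopOver_MNK / LoopOver_MKN select the block-level loop order (see GetBlockMNKAccessOrder()).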
} // namespace device
} // namespace cpu
} // namespace tensor_operation
......
......@@ -13,6 +13,8 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_avx2.hpp"
#include "threadwise_gemm_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"
namespace ck {
namespace tensor_operation {
......@@ -23,20 +25,21 @@ namespace device {
template <typename InDataType,
typename WeiDataType,
typename OutDataType,
typename AccDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
ck::index_t NumDimSpatial,
ck::index_t MPerBlock, // "block" means the data tile is sized to fit in cache (L1/L2/L3)
ck::index_t NPerBlock,
ck::index_t KPerBlock,
typename ThreadwiseGemm_Dispatch>
// bool IsGemmMPadded,
// bool IsGemmNPadded,
// bool IsGemmKPadded>
ck::index_t MPerThread,
ck::index_t NPerThread,
bool UseALocalBuffer,
bool UseBLocalBuffer,
bool UseCLocalBuffer>
struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
: public DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
{
......@@ -60,18 +63,89 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr bool NonTemporalStore = false;
static constexpr auto GetBlockMNKAccessOrder()
{
if constexpr(BlockLoopOverSpecialization == DefaultBlockLoopOver ||
BlockLoopOverSpecialization == LoopOver_MNK)
return ck::Sequence<0, 1, 2>{};
else if constexpr(BlockLoopOverSpecialization == LoopOver_MKN)
return ck::Sequence<0, 2, 1>{};
}
using BlockMNKAccessOrder = decltype(GetBlockMNKAccessOrder());
static constexpr auto GetThreadwiseGemm_Dispatch()
{
if constexpr(MPerThread == 4 && NPerThread == 24)
{
return ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<
InDataType,
WeiDataType,
OutDataType,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
NonTemporalStore>{};
}
else if constexpr(MPerThread == 6 && NPerThread == 16)
{
return ck::cpu::ThreadwiseGemmAvx2_MxN_6x16_Dispatch<
InDataType,
WeiDataType,
OutDataType,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
NonTemporalStore>{};
}
else
{
// static_assert(false, "invalid Mr/Nr");
}
}
using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());
static constexpr auto GetInputBlockDescriptor()
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
}
static constexpr auto GetWeightBlockDescriptor()
{
return make_naive_tensor_descriptor_packed(make_tuple(
math::integer_divide_ceil(NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
KPerBlock,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
}
static constexpr auto GetOutputBlockDescriptor()
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
}
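// Rough cache-footprint sketch (f32 assumed, sizes from the 256x120x64 instance added further
// below): A block 256*64*4 B = 64 KiB, packed weight block ceil(120/8)*64*8*4 B = 30 KiB,
// output block 256*120*4 B = 120 KiB, which is what the "fit in cache (L1/L2/L3)" comment on
// MPerBlock refers to.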
static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
{
ck::index_t gemm_n_padded =
math::integer_least_multiple(gemm_n, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
const auto wei_gemm_n_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_n, gemm_k));
const auto wei_gemm_n0_k_n1_grid_desc = transform_tensor_descriptor(
const auto wei_gemm_padn_k_grid_desc = transform_tensor_descriptor(
wei_gemm_n_k_grid_desc,
ck::make_tuple(ck::make_unmerge_transform(
ck::make_tuple(wei_gemm_n_k_grid_desc.GetLength(I0) /
make_tuple(make_right_pad_transform(gemm_n, gemm_n_padded - gemm_n),
make_pass_through_transform(gemm_k)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
const auto wei_gemm_n0_k_n1_grid_desc = transform_tensor_descriptor(
wei_gemm_padn_k_grid_desc,
ck::make_tuple(
ck::make_unmerge_transform(
ck::make_tuple(wei_gemm_padn_k_grid_desc.GetLength(I0) /
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize)),
ck::make_pass_through_transform(wei_gemm_n_k_grid_desc.GetLength(I1))),
ck::make_pass_through_transform(wei_gemm_padn_k_grid_desc.GetLength(I1))),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
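// Worked example (illustrative values): with gemm_n = 20, gemm_k = 64 and
// MatrixBMinVectorSize = 8, gemm_n_padded = 24, so the N x K weight view is right-padded to
// 24 x 64 and then unmerged into the packed N0 x K x N1 = 3 x 64 x 8 layout that matches
// GetWeightBlockDescriptor() above.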
......@@ -409,6 +483,13 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::multiplies<ck::index_t>());
}
static index_t GetGemmN(ck::index_t K)
{
// return ck::math::integer_least_multiple(K,
// ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
return K;
}
static auto MakeABCGridDescriptor(ck::index_t N,
ck::index_t K,
ck::index_t C,
......@@ -423,7 +504,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
using namespace ck;
const index_t GemmM = GetGemmM(N, output_spatial_lengths);
const index_t GemmN = K;
const index_t GemmN = GetGemmN(K);
const index_t GemmK = GetGemmK(C, filter_spatial_lengths);
// A:
......@@ -474,13 +555,44 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
using BGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
static constexpr bool UseCLocalBuffer = true;
// static constexpr bool UseCLocalBuffer = false;
using AThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC<
InDataType,
InDataType,
AGridDesc,
decltype(GetInputBlockDescriptor()),
InElementwiseOperation,
false,
ConvForwardSpecialization,
GemmKSpecialization>;
using BThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC<
WeiDataType,
WeiDataType,
BGridDesc,
decltype(GetWeightBlockDescriptor()),
WeiElementwiseOperation,
false,
ConvForwardSpecialization,
GemmKSpecialization>;
using CThreadwiseCopy = ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN<
OutDataType,
OutDataType,
CGridDesc,
decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation,
!UseCLocalBuffer,
ConvForwardSpecialization,
GemmKSpecialization>;
using GridwiseGemm =
ck::cpu::GridwiseGemmAvx2_MxN<InDataType, // InDataType,
WeiDataType, // WeiDataType,
OutDataType, // OutDataType,
AccDataType, // AccDataType,
AGridDesc, // AGridDesc,
BGridDesc, // BGridDesc,
CGridDesc, // CGridDesc,
......@@ -491,8 +603,13 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
NPerBlock, // NPerBlock,
KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
ck::Sequence<0, 1, 2>, // BlockMNKAccessOrder,
AThreadwiseCopy, // AThreadwiseCopy
BThreadwiseCopy, // BThreadwiseCopy
CThreadwiseCopy, // CThreadwiseCopy
BlockMNKAccessOrder, // BlockMNKAccessOrder,
ck::Sequence<0, 1>, // ThreadMNAccessOrder
UseALocalBuffer, // UseALocalBuffer
UseBLocalBuffer, // UseBLocalBuffer
UseCLocalBuffer // UseCLocalBuffer
>;
......@@ -580,6 +697,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
}
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
const auto kernel = ck::cpu::kernel_gemm_avx_mxn<GridwiseGemm,
InDataType,
WeiDataType,
......@@ -591,7 +710,10 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
BElementwiseOperation,
CElementwiseOperation>;
float ave_time = launch_and_time_cpu_kernel(kernel,
float ave_time = 0;
if(nrepeat != 1)
ave_time = launch_and_time_cpu_kernel(kernel,
nrepeat,
arg.p_a_grid_,
arg.p_b_grid_,
......@@ -605,7 +727,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
// TODO: this is for benchmarking purposes, so on the last run we clear the C buffer and
// compute the result
memset(arg.p_c_grid_, 0, arg.a_grid_desc_.GetElementSpaceSize());
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
launch_cpu_kernel(kernel,
arg.p_a_grid_,
......@@ -659,6 +781,13 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
}
}
if constexpr(GemmKSpecialization ==
ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
{
if(!(arg.Conv_C_ % KPerBlock == 0))
return false;
}
// Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
}
......@@ -749,15 +878,24 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::string GetTypeString() const override
{
auto str = std::stringstream();
auto string_local_buffer = [](bool is_local_buffer) {
if(is_local_buffer)
return "L";
else
return "G";
};
// clang-format off
str << "DeviceConv" << std::to_string(NumDimSpatial)
<< "DFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
<< "<"
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock
<< ">";
<< "DFwdAvx2_NHWC_KYXC"
<<"_FS"<< static_cast<int>(ConvForwardSpecialization)
<<"_KS"<< static_cast<int>(GemmKSpecialization)
<<"_BS"<< static_cast<int>(BlockLoopOverSpecialization)
<< "_BT" << MPerBlock << "x" << NPerBlock << "x" << KPerBlock
<< "_TT" << MPerThread << "x" << NPerThread
<< "_A" << string_local_buffer(UseALocalBuffer)
<< "_B" << string_local_buffer(UseBLocalBuffer)
<< "_C" << string_local_buffer(UseCLocalBuffer)
;
// clang-format on
return str.str();
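// For example (assuming ConvFwdDefault == 0 and the other enums follow declaration order),
// the first instance added below prints:
// DeviceConv2DFwdAvx2_NHWC_KYXC_FS0_KS1_BS1_BT256x120x64_TT4x24_AL_BL_CG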
......
......@@ -7,7 +7,7 @@
#include "common_header.hpp"
#include "../../gpu/device/tensor_layout.hpp"
#include "math.hpp"
#include "threadwise_param.hpp"
#include "threadwise_gemm_param.hpp"
namespace ck {
namespace cpu {
......@@ -294,6 +294,9 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".if (m_Mr > 4)\n lea (%%rdx, %%rdi, 1), %%r8 \n .endif\n"
".if (m_Mr > 5)\n lea (%%r8, %%rdi, 1), %%r9 \n .endif\n"
"mov 60(%[m_param]), %%edi\n" // accmulate_c
"test %%edi, %%edi\n"
"je L_GemmAvx2_MxN_6x16_Store_C%=\n"
" vaddps (%%rax), %%ymm0, %%ymm0 \n"
".if (m_Nr > 8)\n vaddps 32(%%rax), %%ymm1, %%ymm1 \n .endif\n"
".if (m_Mr > 1) \n vaddps (%%rbx), %%ymm2, %%ymm2 \n .endif\n"
......@@ -307,6 +310,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".if (m_Mr > 5) \n vaddps (%%r9), %%ymm10, %%ymm10\n .endif\n"
".if (m_Mr > 5) && (m_Nr > 8)\n vaddps 32(%%r9), %%ymm11, %%ymm11\n .endif\n"
"L_GemmAvx2_MxN_6x16_Store_C%=:\n"
".if m_NTStore == 0\n"
" vmovups %%ymm0, (%%rax) \n"
".if (m_Nr > 8)\n vmovups %%ymm1, 32(%%rax)\n .endif\n"
......@@ -424,6 +428,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
};
// clang-format off
if(param->accmulate_c){
ymm0 = _mm256_loadu_ps(p_c + 0 * ldc + 0 * 8);
if constexpr ( Nr > 8) ymm1 = _mm256_loadu_ps(p_c + 0 * ldc + 1 * 8);
if constexpr (Mr > 1 ) ymm2 = _mm256_loadu_ps(p_c + 1 * ldc + 0 * 8);
......@@ -436,6 +441,20 @@ struct ThreadwiseGemmAvx2_MxN_6x16
if constexpr (Mr > 4 && Nr > 8) ymm9 = _mm256_loadu_ps(p_c + 4 * ldc + 1 * 8);
if constexpr (Mr > 5 ) ymm10 = _mm256_loadu_ps(p_c + 5 * ldc + 0 * 8);
if constexpr (Mr > 5 && Nr > 8) ymm11 = _mm256_loadu_ps(p_c + 5 * ldc + 1 * 8);
} else {
ymm0 = _mm256_xor_ps(ymm0, ymm0);
if constexpr ( Nr > 8) ymm1 = _mm256_xor_ps(ymm1, ymm1);
if constexpr (Mr > 1 ) ymm2 = _mm256_xor_ps(ymm2, ymm2);
if constexpr (Mr > 1 && Nr > 8) ymm3 = _mm256_xor_ps(ymm3, ymm3);
if constexpr (Mr > 2 ) ymm4 = _mm256_xor_ps(ymm4, ymm4);
if constexpr (Mr > 2 && Nr > 8) ymm5 = _mm256_xor_ps(ymm5, ymm5);
if constexpr (Mr > 3 ) ymm6 = _mm256_xor_ps(ymm6, ymm6);
if constexpr (Mr > 3 && Nr > 8) ymm7 = _mm256_xor_ps(ymm7, ymm7);
if constexpr (Mr > 4 ) ymm8 = _mm256_xor_ps(ymm8, ymm8);
if constexpr (Mr > 4 && Nr > 8) ymm9 = _mm256_xor_ps(ymm9, ymm9);
if constexpr (Mr > 5 ) ymm10 = _mm256_xor_ps(ymm10, ymm10);
if constexpr (Mr > 5 && Nr > 8) ymm11 = _mm256_xor_ps(ymm11, ymm11);
}
while (Kr > 4){
#pragma unroll
......@@ -532,6 +551,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
}
if constexpr (NonTemporalStore) {
_mm256_stream_ps(p_c + 0 * ldc + 0 * 8, ymm0);
if constexpr ( Nr > 8) _mm256_stream_ps(p_c + 0 * ldc + 1 * 8, ymm1);
if constexpr (Mr > 1 ) _mm256_stream_ps(p_c + 1 * ldc + 0 * 8, ymm2);
if constexpr (Mr > 1 && Nr > 8) _mm256_stream_ps(p_c + 1 * ldc + 1 * 8, ymm3);
......@@ -830,19 +850,23 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".if (m_Mr > 2)\n lea (%%rbx, %%rdi, 1), %%rcx\n .endif\n"
".if (m_Mr > 3)\n lea (%%rcx, %%rdi, 1), %%rdx\n .endif\n"
// " vaddps (%%rax), %%ymm0, %%ymm0 \n"
// ".if (m_Nr > 8)\n vaddps 32(%%rax), %%ymm1, %%ymm1 \n .endif\n"
// ".if (m_Nr >16)\n vaddps 64(%%rax), %%ymm2, %%ymm2 \n .endif\n"
// ".if (m_Mr > 1) \n vaddps (%%rbx), %%ymm3, %%ymm3 \n .endif\n"
// ".if (m_Mr > 1) && (m_Nr > 8)\n vaddps 32(%%rbx), %%ymm4, %%ymm4 \n .endif\n"
// ".if (m_Mr > 1) && (m_Nr >16)\n vaddps 64(%%rbx), %%ymm5, %%ymm5 \n .endif\n"
// ".if (m_Mr > 2) \n vaddps (%%rcx), %%ymm6, %%ymm6 \n .endif\n"
// ".if (m_Mr > 2) && (m_Nr > 8)\n vaddps 32(%%rcx), %%ymm7, %%ymm7 \n .endif\n"
// ".if (m_Mr > 2) && (m_Nr >16)\n vaddps 64(%%rcx), %%ymm8, %%ymm8 \n .endif\n"
// ".if (m_Mr > 3) \n vaddps (%%rdx), %%ymm9, %%ymm9 \n .endif\n"
// ".if (m_Mr > 3) && (m_Nr > 8)\n vaddps 32(%%rdx), %%ymm10, %%ymm10\n .endif\n"
// ".if (m_Mr > 3) && (m_Nr >16)\n vaddps 64(%%rdx), %%ymm11, %%ymm11\n .endif\n"
"mov 60(%[m_param]), %%edi\n" // accmulate_c
"test %%edi, %%edi\n"
"je L_GemmAvx2_MxN_4x24_Store_C%=\n"
" vaddps (%%rax), %%ymm0, %%ymm0 \n"
".if (m_Nr > 8)\n vaddps 32(%%rax), %%ymm1, %%ymm1 \n .endif\n"
".if (m_Nr >16)\n vaddps 64(%%rax), %%ymm2, %%ymm2 \n .endif\n"
".if (m_Mr > 1) \n vaddps (%%rbx), %%ymm3, %%ymm3 \n .endif\n"
".if (m_Mr > 1) && (m_Nr > 8)\n vaddps 32(%%rbx), %%ymm4, %%ymm4 \n .endif\n"
".if (m_Mr > 1) && (m_Nr >16)\n vaddps 64(%%rbx), %%ymm5, %%ymm5 \n .endif\n"
".if (m_Mr > 2) \n vaddps (%%rcx), %%ymm6, %%ymm6 \n .endif\n"
".if (m_Mr > 2) && (m_Nr > 8)\n vaddps 32(%%rcx), %%ymm7, %%ymm7 \n .endif\n"
".if (m_Mr > 2) && (m_Nr >16)\n vaddps 64(%%rcx), %%ymm8, %%ymm8 \n .endif\n"
".if (m_Mr > 3) \n vaddps (%%rdx), %%ymm9, %%ymm9 \n .endif\n"
".if (m_Mr > 3) && (m_Nr > 8)\n vaddps 32(%%rdx), %%ymm10, %%ymm10\n .endif\n"
".if (m_Mr > 3) && (m_Nr >16)\n vaddps 64(%%rdx), %%ymm11, %%ymm11\n .endif\n"
"L_GemmAvx2_MxN_4x24_Store_C%=:\n"
".if m_NTStore == 0\n"
" vmovups %%ymm0, (%%rax) \n"
".if (m_Nr > 8)\n vmovups %%ymm1, 32(%%rax)\n .endif\n"
......@@ -960,6 +984,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
};
// clang-format off
if(param->accmulate_c) {
ymm0 = _mm256_loadu_ps(p_c + 0 * ldc + 0 * 8);
if constexpr ( Nr > 8) ymm1 = _mm256_loadu_ps(p_c + 0 * ldc + 1 * 8);
if constexpr ( Nr >16) ymm2 = _mm256_loadu_ps(p_c + 0 * ldc + 2 * 8);
......@@ -972,6 +997,20 @@ struct ThreadwiseGemmAvx2_MxN_4x24
if constexpr (Mr > 3 ) ymm9 = _mm256_loadu_ps(p_c + 3 * ldc + 0 * 8);
if constexpr (Mr > 3 && Nr > 8) ymm10 = _mm256_loadu_ps(p_c + 3 * ldc + 1 * 8);
if constexpr (Mr > 3 && Nr >16) ymm11 = _mm256_loadu_ps(p_c + 3 * ldc + 2 * 8);
} else {
ymm0 = _mm256_xor_ps(ymm0, ymm0);
if constexpr ( Nr > 8) ymm1 = _mm256_xor_ps(ymm1, ymm1);
if constexpr ( Nr >16) ymm2 = _mm256_xor_ps(ymm2, ymm2);
if constexpr (Mr > 1 ) ymm3 = _mm256_xor_ps(ymm3, ymm3);
if constexpr (Mr > 1 && Nr > 8) ymm4 = _mm256_xor_ps(ymm4, ymm4);
if constexpr (Mr > 1 && Nr >16) ymm5 = _mm256_xor_ps(ymm5, ymm5);
if constexpr (Mr > 2 ) ymm6 = _mm256_xor_ps(ymm6, ymm6);
if constexpr (Mr > 2 && Nr > 8) ymm7 = _mm256_xor_ps(ymm7, ymm7);
if constexpr (Mr > 2 && Nr >16) ymm8 = _mm256_xor_ps(ymm8, ymm8);
if constexpr (Mr > 3 ) ymm9 = _mm256_xor_ps(ymm9, ymm9);
if constexpr (Mr > 3 && Nr > 8) ymm10 = _mm256_xor_ps(ymm10, ymm10);
if constexpr (Mr > 3 && Nr >16) ymm11 = _mm256_xor_ps(ymm11, ymm11);
}
while (Kr > 4){
#pragma unroll
......@@ -1221,33 +1260,36 @@ struct ThreadwiseGemmAvx2_MxN_6x16_Dispatch
static constexpr pThreadwiseGemmAvx2Run dispatch_table[6][2] = {
{
ThreadwiseGemm_6x16_t::Run,
ThreadwiseGemm_6x8_t::Run,
ThreadwiseGemm_1x8_t::Run,
ThreadwiseGemm_1x16_t::Run,
},
{
ThreadwiseGemm_5x16_t::Run,
ThreadwiseGemm_5x8_t::Run,
ThreadwiseGemm_2x8_t::Run,
ThreadwiseGemm_2x16_t::Run,
},
{
ThreadwiseGemm_4x16_t::Run,
ThreadwiseGemm_4x8_t::Run,
ThreadwiseGemm_3x8_t::Run,
ThreadwiseGemm_3x16_t::Run,
},
{
ThreadwiseGemm_3x16_t::Run,
ThreadwiseGemm_3x8_t::Run,
ThreadwiseGemm_4x8_t::Run,
ThreadwiseGemm_4x16_t::Run,
},
{
ThreadwiseGemm_2x16_t::Run,
ThreadwiseGemm_2x8_t::Run,
ThreadwiseGemm_5x8_t::Run,
ThreadwiseGemm_5x16_t::Run,
},
{
ThreadwiseGemm_1x16_t::Run,
ThreadwiseGemm_1x8_t::Run,
ThreadwiseGemm_6x8_t::Run,
ThreadwiseGemm_6x16_t::Run,
},
};
static void Run(ThreadwiseGemmParam* param, index_t mr, index_t nr)
{
index_t im = mr - 1;
index_t in = (nr >> 3) - 1;
assert(im >= 0 && im <= 5 && in >= 0 && in <= 1);
return dispatch_table[im][in](param);
}
};
......@@ -1371,30 +1413,33 @@ struct ThreadwiseGemmAvx2_MxN_4x24_Dispatch
static constexpr pThreadwiseGemmAvx2Run dispatch_table[4][3] = {
{
ThreadwiseGemm_4x24_t::Run,
ThreadwiseGemm_4x16_t::Run,
ThreadwiseGemm_4x8_t::Run,
ThreadwiseGemm_1x8_t::Run,
ThreadwiseGemm_1x16_t::Run,
ThreadwiseGemm_1x24_t::Run,
},
{
ThreadwiseGemm_3x24_t::Run,
ThreadwiseGemm_3x16_t::Run,
ThreadwiseGemm_3x8_t::Run,
ThreadwiseGemm_2x8_t::Run,
ThreadwiseGemm_2x16_t::Run,
ThreadwiseGemm_2x24_t::Run,
},
{
ThreadwiseGemm_2x24_t::Run,
ThreadwiseGemm_2x16_t::Run,
ThreadwiseGemm_2x8_t::Run,
ThreadwiseGemm_3x8_t::Run,
ThreadwiseGemm_3x16_t::Run,
ThreadwiseGemm_3x24_t::Run,
},
{
ThreadwiseGemm_1x24_t::Run,
ThreadwiseGemm_1x16_t::Run,
ThreadwiseGemm_1x8_t::Run,
ThreadwiseGemm_4x8_t::Run,
ThreadwiseGemm_4x16_t::Run,
ThreadwiseGemm_4x24_t::Run,
},
};
static void Run(ThreadwiseGemmParam* param, index_t mr, index_t nr)
{
return dispatch_table[mr][nr](param);
index_t im = mr - 1;
index_t in = (nr >> 3) - 1;
assert(im >= 0 && im <= 3 && in >= 0 && in <= 2);
return dispatch_table[im][in](param);
}
};
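// Illustrative dispatch (hypothetical tile sizes, not from this commit): Run(&param, 4, 24)
// maps to im = 3, in = 2 and calls the 4x24 micro-kernel, while an N-tail tile
// Run(&param, 2, 16) maps to im = 1, in = 1 and calls the 2x16 micro-kernel; nr is expected
// to be a positive multiple of 8.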
......
#ifndef CK_THREADWISE_PARAM_HPP
#define CK_THREADWISE_PARAM_HPP
#ifndef CK_THREADWISE_GEMM_PARAM_HPP
#define CK_THREADWISE_GEMM_PARAM_HPP
#include "common_header.hpp"
#include "math.hpp"
......@@ -17,7 +17,7 @@ struct ThreadwiseGemmParam
uint64_t ldb; // in unit of byte
uint64_t ldc; // in unit of byte
float alpha;
uint32_t _pack0;
int accmulate_c; // if 1, load C and accumulate the FMA result into it; if 0, store the result directly
} __attribute__((packed));
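// Layout sketch (assuming the fields preceding ldb are three pointers plus Kr and lda, each
// 8 bytes, as used above): with the packed attribute, alpha sits at byte offset 56 and
// accmulate_c at offset 60, which is why the inline-asm micro-kernels read it with
// "mov 60(%[m_param]), %%edi". A guard such as
// static_assert(offsetof(ThreadwiseGemmParam, accmulate_c) == 60, "asm expects offset 60");
// would document this dependency (not part of the original commit).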
} // namespace cpu
......
......@@ -53,6 +53,42 @@ struct ThreadwiseTensorSliceTransferAvx2
{
static_assert(SliceLengths::At(Number<VectorDim>{}) % ScalarPerVector == 0,
"wrong! cannot evenly divide");
int N = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
int Hi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
int Wi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
int C = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<3>{}];
int Ho = src_desc.GetTransforms()[Number<9>{}].low_lengths_[Number<1>{}];
int Wo = src_desc.GetTransforms()[Number<9>{}].low_lengths_[Number<2>{}];
int Fy = src_desc.GetTransforms()[Number<10>{}].low_lengths_[Number<0>{}];
int Fx = src_desc.GetTransforms()[Number<10>{}].low_lengths_[Number<1>{}];
int Dy = src_desc.GetTransforms()[Number<6>{}].coefficients_[Number<0>{}];
int Sy = src_desc.GetTransforms()[Number<6>{}].coefficients_[Number<1>{}];
int Dx = src_desc.GetTransforms()[Number<7>{}].coefficients_[Number<0>{}];
int Sx = src_desc.GetTransforms()[Number<7>{}].coefficients_[Number<1>{}];
int Py = src_desc.GetTransforms()[Number<2>{}].left_pad_length_;
int Px = src_desc.GetTransforms()[Number<3>{}].left_pad_length_;
printf("N:%d, Hi:%d, Wi:%d, C:%d, Ho:%d, Wo:%d, Fy:%d, Fx:%d, Dy:%d, Sy:%d, Dx:%d, Sx:%d, "
"Py:%d, Px:%d\n",
N,
Hi,
Wi,
C,
Ho,
Wo,
Fy,
Fx,
Dy,
Sy,
Dx,
Sx,
Py,
Px);
}
void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
......@@ -87,6 +123,10 @@ struct ThreadwiseTensorSliceTransferAvx2
// std::cout<<"num_access:"<<num_access<<std::endl;
std::cout << "src hidden:" << SrcDesc::GetNumOfHiddenDimension() << std::endl;
std::cout << "dst hidden:" << DstDesc::GetNumOfHiddenDimension() << std::endl;
#if 0
static_for<0, num_access, 1>{}([&](auto idx_1d) {
using src_vector_type = ck::cpu::vector_type_maker_t<SrcData, ScalarPerVector>;
using src_vector_t = typename src_vector_type::type;
......@@ -148,6 +188,75 @@ struct ThreadwiseTensorSliceTransferAvx2
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
}
});
#endif
const auto src_slice_idx_zeros = typename uniform_sequence_gen<nDim, 0>::type{};
const auto src_slice_step = make_tensor_coordinate_step(
src_desc, to_multi_index(src_slice_idx_zeros.Modify(Number<nDim - 1>{}, Number<1>{})));
const auto dst_slice_idx_zeros = typename uniform_sequence_gen<nDim, 0>::type{};
const auto dst_slice_step = make_tensor_coordinate_step(
dst_desc, to_multi_index(dst_slice_idx_zeros.Modify(Number<nDim - 1>{}, Number<1>{})));
for(auto idx_id = 0; idx_id < num_access; idx_id++)
{
using src_vector_type = ck::cpu::vector_type_maker_t<SrcData, ScalarPerVector>;
using src_vector_t = typename src_vector_type::type;
using dst_vector_type = ck::cpu::vector_type_maker_t<DstData, ScalarPerVector>;
using dst_vector_t = typename dst_vector_type::type;
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
printf("[%s] ", is_src_valid ? "y" : "n");
print_multi_index(src_coord_.GetIndex());
printf("----");
// print_multi_index(src_coord_.GetHiddenIndex());
// printf(":%d", src_coord_.GetOffset());
// printf("\n");
// copy data from src_buf into src_vector_container
auto src_vector_container = src_vector_type{
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
auto dst_vector_container = dst_vector_type{};
// apply pointwise operation
// static_for<0, ScalarPerVector, 1>{}([&](auto i) {
// element_op_(dst_vector_container.template AsType<DstData>()(i),
// src_vector_container.template AsType<SrcData>()[i]);
// });
element_op_(dst_vector_container.template AsType<dst_vector_t>(),
src_vector_container.template AsType<src_vector_t>());
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
printf(" -> ");
print_multi_index(dst_coord_.GetIndex());
// printf(":%d", dst_coord_.GetOffset());
// printf(", src:0x%x, dst:0x%x",
// *reinterpret_cast<uint32_t*>(&src_vector_container.template AsType<src_vector_t>()),
// *reinterpret_cast<uint32_t*>(&dst_vector_container.template
// AsType<dst_vector_t>()));
printf("\n");
// copy data from dst_vector into dst_buf
dst_buf.template Update<DstInMemOp, dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>());
// move coordinate
if(idx_id != num_access - 1)
{
// constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
move_tensor_coordinate(src_desc, src_coord_, src_slice_step);
move_tensor_coordinate(dst_desc, dst_coord_, dst_slice_step);
}
}
// move coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun)
......
......@@ -121,6 +121,10 @@ template <typename... Args, typename F>
float launch_and_time_cpu_kernel(F kernel, int nrepeat, Args... args)
{
WallTimer timer;
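// a few untimed warm-up iterations so the timed loop below measures steady-state performance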
int nwarmup = 3;
for(int i = 0; i < nwarmup; i++)
kernel(args...);
timer.Start();
......
......@@ -19,7 +19,7 @@ using InLayout = ck::tensor_layout::gemm::RowMajor; /
using WeiLayout = ck::tensor_layout::gemm::ColumnMajor; // KYXC
static constexpr bool NonTemporalStore = false;
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
using PT = ck::tensor_operation::cpu::element_wise::PassThrough;
using ThreadwiseGemmAvx2_MxN_4x24_Dispatch =
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<InType,
WeiType,
......@@ -37,53 +37,37 @@ static constexpr auto ConvFwd1x1P0 =
static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;
static constexpr auto DefaultGemmKLoop =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop;
static constexpr auto GemmKLoopOverC =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC;
static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK;
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
// clang-format off
#define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf) \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>
// clang-format on
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances = std::tuple<
//#################################################################|InDataType|WeiDataType|OutDataType|AccDataType|InElementwiseOp|WeiElementwiseOp|OutElementwiseOp|ConvForwardSp|NumDimSpatial|MPerBlock|NPerBlock|KPerBlock|ThreadwiseGemm_Dispatch
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
float,
float,
float,
float,
PassThrough,
PassThrough,
PassThrough,
ConvFwdDefault,
2,
256,
128,
64,
ThreadwiseGemmAvx2_MxN_4x24_Dispatch>,
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
float,
float,
float,
float,
PassThrough,
PassThrough,
PassThrough,
ConvFwdDefault,
2,
512,
256,
128,
ThreadwiseGemmAvx2_MxN_4x24_Dispatch>,
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
float,
float,
float,
float,
PassThrough,
PassThrough,
PassThrough,
ConvFwdDefault,
2,
1024,
144,
128,
ThreadwiseGemmAvx2_MxN_4x24_Dispatch>>;
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 120, 64, 4, 24, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 144, 128, 4, 24, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true, true, false),
// DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 192, 128, 4, 24, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 288, 128, 4, 24, true, true, false)>;
// clang-format on
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances)
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances{});
......
......@@ -9,8 +9,11 @@
#include "reference_conv_fwd.hpp"
#include "element_wise_operation_cpu.hpp"
#include "dynamic_buffer_cpu.hpp"
#include <omp.h>
#define AVX2_DATA_ALIGNMENT 32
using F32 = float;
using F16 = ck::half_t;
namespace ck {
namespace tensor_operation {
......@@ -18,6 +21,8 @@ namespace cpu {
namespace device {
namespace device_conv2d_fwd_avx2_instance {
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
......@@ -34,6 +39,7 @@ using OutElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
template <typename T>
static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
{
int error_count = 0;
float max_diff = 1e-6;
for(int i = 0; i < ref.mData.size(); ++i)
......@@ -41,28 +47,35 @@ static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
if(max_diff < diff)
{
return false;
error_count++;
printf("idx:%3d, ref:%f, res:%f (diff:%f)\n",
i,
double(ref.mData[i]),
double(result.mData[i]),
diff);
}
}
return true;
return error_count == 0;
}
float calculate_gflops() {}
int main(int argc, char* argv[])
{
int data_type = 0;
int init_method = 0;
// Conv shape
ck::index_t N = 128;
ck::index_t N = 2;
ck::index_t K = 256;
ck::index_t C = 192;
ck::index_t Y = 3;
ck::index_t X = 3;
ck::index_t Hi = 71;
ck::index_t Wi = 71;
ck::index_t conv_stride_h = 2;
ck::index_t conv_stride_w = 2;
ck::index_t conv_stride_h = 1;
ck::index_t conv_stride_w = 1;
ck::index_t conv_dilation_h = 1;
ck::index_t conv_dilation_w = 1;
ck::index_t in_left_pad_h = 1;
......@@ -72,7 +85,7 @@ int main(int argc, char* argv[])
if(argc == 1)
{
data_type = 1;
data_type = 0;
init_method = 1;
}
else if(argc == 3)
......@@ -110,17 +123,14 @@ int main(int argc, char* argv[])
exit(1);
}
auto Run = [&](auto input_type, auto wei_type, auto out_type, auto acc_type) {
auto Run = [&](auto input_type, auto wei_type, auto out_type) {
using InDataType = decltype(input_type);
using WeiDataType = decltype(wei_type);
using OutDataType = decltype(out_type);
using AccDataType = decltype(acc_type);
using ReferenceConvBwdInstance =
ck::tensor_operation::host::ReferenceConvBwdData<InDataType,
using ReferenceConvFwdInstance = ck::tensor_operation::host::ReferenceConvFwd<InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementOp,
WeiElementOp,
OutElementOp>;
......@@ -147,53 +157,93 @@ int main(int argc, char* argv[])
std::vector<std::size_t>({C_ * H_ * W_, 1, W_ * C_, C_}));
};
Tensor<OutDataType> out_n_ho_wo_k(f_host_tensor_descriptor(N, K, Ho, Wo));
Tensor<WeiDataType> wei_k_y_x_c(f_host_tensor_descriptor(K, C, Y, X));
Tensor<InDataType> in_n_hi_wi_c_host_result(f_host_tensor_descriptor(N, C, Hi, Wi));
Tensor<InDataType> in_n_hi_wi_c_device_result(f_host_tensor_descriptor(N, C, Hi, Wi));
Tensor<InDataType> in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi));
Tensor<WeiDataType> wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X));
Tensor<OutDataType> out_n_k_ho_wo_host_result(f_host_tensor_descriptor(N, K, Ho, Wo));
Tensor<OutDataType> out_n_k_ho_wo_device_result(f_host_tensor_descriptor(N, K, Ho, Wo));
std::cout << "in (N, C, Hi, Wi): " << in_n_hi_wi_c_host_result.mDesc << std::endl;
std::cout << "wei(K, C, Y, X): " << wei_k_y_x_c.mDesc << std::endl;
std::cout << "out(N, K, Ho, Wo): " << out_n_ho_wo_k.mDesc << std::endl;
std::cout << "in (N, C, Hi, Wi): " << in_n_c_hi_wi.mDesc << std::endl;
std::cout << "wei(K, C, Y, X): " << wei_k_c_y_x.mDesc << std::endl;
std::cout << "out(N, K, Ho, Wo): " << out_n_k_ho_wo_host_result.mDesc << std::endl;
std::cout << "LPad(H, W):" << in_left_pad_h << "," << in_left_pad_w
<< ", RPad(H, W):" << in_right_pad_h << "," << in_right_pad_w
<< ", Stride(H, W):" << conv_stride_h << ", " << conv_stride_w
<< ", Dilation(H, W):" << conv_dilation_h << ", " << conv_dilation_w
<< ", Threads:" << omp_get_max_threads() << std::endl;
switch(init_method)
{
case 0: break;
case 1:
out_n_ho_wo_k.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
wei_k_y_x_c.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
// in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{});
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
// wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{});
break;
case 2:
out_n_ho_wo_k.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
wei_k_y_x_c.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{});
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{});
break;
case 3:
#define PACK_32(v24, v16, v8, v0) \
(((v24 & 0xff) << 24) | ((v16 & 0xff) << 16) | ((v8 & 0xff) << 8) | ((v0 & 0xff) << 0))
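// PACK_32 encodes four small indices into one 32-bit pattern that is reinterpreted as a float,
// so every input/weight element carries its own (n, c, hi, wi) / (k, c, y, x) coordinates;
// presumably this makes data-movement bugs visible when dumping buffers.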
for(auto i_n = 0; i_n < N; i_n++)
{
for(auto i_c = 0; i_c < C; i_c++)
{
for(auto i_hi = 0; i_hi < Hi; i_hi++)
{
for(auto i_wi = 0; i_wi < Wi; i_wi++)
{
uint32_t v = PACK_32(i_n, i_c, i_hi, i_wi);
in_n_c_hi_wi(i_n, i_c, i_hi, i_wi) = *reinterpret_cast<float*>(&v);
}
}
}
}
for(auto i_k = 0; i_k < K; i_k++)
{
for(auto i_c = 0; i_c < C; i_c++)
{
for(auto i_y = 0; i_y < Y; i_y++)
{
for(auto i_x = 0; i_x < X; i_x++)
{
uint32_t v = PACK_32(i_k, i_c, i_y, i_x);
wei_k_c_y_x(i_k, i_c, i_y, i_x) = *reinterpret_cast<float*>(&v);
}
}
}
}
break;
default:
out_n_ho_wo_k.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
wei_k_y_x_c.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{0, 1});
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-1, 1});
}
DeviceAlignedMemCPU in_device_buf(sizeof(InDataType) *
in_n_hi_wi_c_device_result.mDesc.GetElementSpace(),
DeviceAlignedMemCPU in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace(),
AVX2_DATA_ALIGNMENT);
DeviceAlignedMemCPU wei_device_buf(
sizeof(WeiDataType) * wei_k_y_x_c.mDesc.GetElementSpace(), AVX2_DATA_ALIGNMENT);
DeviceAlignedMemCPU out_device_buf(
sizeof(OutDataType) * out_n_ho_wo_k.mDesc.GetElementSpace(), AVX2_DATA_ALIGNMENT);
sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace(), AVX2_DATA_ALIGNMENT);
DeviceAlignedMemCPU out_device_buf(sizeof(OutDataType) *
out_n_k_ho_wo_host_result.mDesc.GetElementSpace(),
AVX2_DATA_ALIGNMENT);
out_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
wei_device_buf.ToDevice(wei_k_y_x_c.mData.data());
// reset input to zero
in_n_hi_wi_c_device_result.GenerateTensorValue(GeneratorTensor_1<InDataType>{0});
in_device_buf.ToDevice(in_n_hi_wi_c_device_result.mData.data());
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
// get host result
{
auto ref_conv = ReferenceConvFwdInstance{};
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in_n_hi_wi_c_host_result,
wei_k_y_x_c,
out_n_ho_wo_k,
auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi,
wei_k_c_y_x,
out_n_k_ho_wo_host_result,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
......@@ -205,8 +255,8 @@ int main(int argc, char* argv[])
}
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
using DeviceConvFwdNoOpPtr =
ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>;
using DeviceConvFwdNoOpPtr = ck::tensor_operation::cpu::device::
DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>;
// add device Conv instances
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
......@@ -226,6 +276,9 @@ int main(int argc, char* argv[])
// profile device Conv instances
bool success = true;
double fastest_kernel_time = std::numeric_limits<double>::max();
std::string fastest_kernel_name = "";
double fastest_kernel_gflops = 0;
for(auto& conv_ptr : conv_ptrs)
{
auto argument_ptr = conv_ptr->MakeArgumentPointer(
......@@ -249,18 +302,30 @@ int main(int argc, char* argv[])
if(conv_ptr->IsSupportedArgument(argument_ptr.get()))
{
auto invoker_ptr = conv_ptr->MakeInvokerPointer();
invoker_ptr->Run(argument_ptr.get(), 1);
double time = invoker_ptr->Run(argument_ptr.get(), 10);
double total_flop = static_cast<double>(2) * N * C * Ho * Wo * K * Y * X;
in_device_buf.FromDevice(in_n_hi_wi_c_device_result.mData.data());
double gflops = (total_flop * 1e-6) / time;
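// time is reported in milliseconds, so total_flop * 1e-6 / time is GFLOP/s
// (e.g. 2e9 flop in 1 ms -> 2000 GFLOP/s)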
if(!check_out(in_n_hi_wi_c_host_result, in_n_hi_wi_c_device_result))
out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data());
if(!check_out(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result))
{
std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl;
success = false;
}
else
{
std::cout << "Pass Info: " << conv_ptr->GetTypeString() << std::endl;
std::cout << "Pass Info: " << conv_ptr->GetTypeString() << ", Time:" << time
<< "ms, Gflops:" << gflops << std::endl;
if(time < fastest_kernel_time)
{
fastest_kernel_time = time;
fastest_kernel_name = conv_ptr->GetTypeString();
fastest_kernel_gflops = gflops;
}
}
}
else
......@@ -269,25 +334,27 @@ int main(int argc, char* argv[])
}
}
if(success)
if(fastest_kernel_time != std::numeric_limits<double>::max())
{
std::cout << "test conv2d fwd cpu : Pass" << std::endl;
return 0;
}
else
{
std::cout << "test conv2d fwd cpu: Fail " << std::endl;
return -1;
std::cout << " fastest:" << fastest_kernel_name << ", time:" << fastest_kernel_time
<< "ms, Gflops:" << fastest_kernel_gflops << std::endl;
}
return 0;
// if(success)
// {
// std::cout << "test conv2d fwd cpu : Pass" << std::endl;
// return 0;
// }
// else
// {
// std::cout << "test conv2d fwd cpu: Fail " << std::endl;
// return -1;
// }
};
if(data_type == 0)
{
return Run(F32(), F32(), F32(), F32());
}
else if(data_type == 1)
{
return Run(F16(), F16(), F16(), F32());
return Run(F32(), F32(), F32());
}
else
{
......
......@@ -226,6 +226,8 @@ int main(int argc, char** argv)
static constexpr ck::index_t nDim =
ck::remove_reference_t<decltype(input_desc)>::GetNumOfDimension();
input_desc.Print();
auto threadwise_transfer = threadwise_transfer_t{input_desc,
ck::make_zero_multi_index<nDim>(),
input_cblock_desc,
......
......@@ -321,6 +321,7 @@ void test_ukernel(ukenrel_t uk,
param.ldb = (std::is_same<Row, BLayout>::value ? n : k * 8) * sizeof(FloatB);
param.ldc = n * sizeof(float);
param.alpha = alpha;
param.accmulate_c = 0;
memset(private_c, 0, m * n * sizeof(float));
......