Commit afc7d431 authored by carlushuang

AVX2 GEMM now works for a single thread

parent 07af8343
...@@ -13,21 +13,10 @@ namespace cpu { ...@@ -13,21 +13,10 @@ namespace cpu {
template <typename FloatA, template <typename FloatA,
typename FloatB, typename FloatB,
typename FloatC, typename FloatC,
typename AccDataType,
typename ABlockDesc, typename ABlockDesc,
typename BBlockDesc, typename BBlockDesc,
typename CBlockDesc, typename CDesc,
typename ABlockSliceLengths,
typename BBlockSliceLengths,
typename CBlockSliceLengths,
typename AThreadSliceLength,
typename BThreadSliceLength,
ck::index_t AThreadLoopOverDim, // thread slice loop over on block slice. 1d is enough for
// now
ck::index_t BThreadLoopOverDim,
ck::index_t KPerBlock, ck::index_t KPerBlock,
...@@ -47,24 +36,14 @@ struct BlockwiseGemmAvx2_MxN ...@@ -47,24 +36,14 @@ struct BlockwiseGemmAvx2_MxN
static constexpr index_t nDimA = ABlockDesc::GetNumOfDimension(); static constexpr index_t nDimA = ABlockDesc::GetNumOfDimension();
static constexpr index_t nDimB = BBlockDesc::GetNumOfDimension(); static constexpr index_t nDimB = BBlockDesc::GetNumOfDimension();
static constexpr index_t nDimC = CBlockDesc::GetNumOfDimension(); static constexpr index_t nDimC = CDesc::GetNumOfDimension();
using IndexA = MultiIndex<nDimA>; using IndexA = MultiIndex<nDimA>;
using IndexB = MultiIndex<nDimB>; using IndexB = MultiIndex<nDimB>;
using IndexC = MultiIndex<nDimC>; using IndexC = MultiIndex<nDimC>;
using ACoord = decltype(make_tensor_coordinate(ABlockDesc{}, IndexA{})); using ACoord = decltype(make_tensor_coordinate(ABlockDesc{}, IndexA{}));
using BCoord = decltype(make_tensor_coordinate(BBlockDesc{}, IndexB{})); using BCoord = decltype(make_tensor_coordinate(BBlockDesc{}, IndexB{}));
using CCoord = decltype(make_tensor_coordinate(CBlockDesc{}, IndexC{})); using CCoord = decltype(make_tensor_coordinate(CDesc{}, IndexC{}));
#if 0
constexpr BlockwiseGemmAvx2_MxN(const ABlockDesc & a_block_desc, const IndexA& a_thread_origin,
const BBlockDesc & b_block_desc, const IndexB& b_thread_origin)
: a_thread_coord_(make_tensor_coordinate(a_block_desc, a_thread_origin)),
b_thread_coord_(make_tensor_coordinate(b_block_desc, b_thread_origin)),
{
}
#endif
template <typename TensorDesc> template <typename TensorDesc>
constexpr auto GetLeadingElement(const TensorDesc& desc) constexpr auto GetLeadingElement(const TensorDesc& desc)
...@@ -84,79 +63,175 @@ struct BlockwiseGemmAvx2_MxN ...@@ -84,79 +63,175 @@ struct BlockwiseGemmAvx2_MxN
} }
} }
template <typename ABlockBuffer, typename BBlockBuffer, typename CBlockBuffer> ck::index_t GetALeadingElement(const ABlockDesc& a_block_desc) const
{
return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
ck::index_t GetBLeadingElement(const BBlockDesc& b_block_desc) const
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// K * N
return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
else
{
// N/8 * K * 8
return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}] *
b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
}
}
ck::index_t GetCLeadingElement(const CDesc& c_desc) const
{
return c_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
ck::index_t GetMPerBlock(const ABlockDesc& a_block_desc) const
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// M * K
return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
}
else
{
// K * M
return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
}
ck::index_t GetKPerBlock(const ABlockDesc& a_block_desc) const
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// M * K
return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
else
{
// K * M
return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
}
}
ck::index_t GetNPerBlock(const BBlockDesc& b_block_desc) const
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// K * N
return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
else
{
// N/8 * K * 8
return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}] *
b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
}
}
ck::index_t
GetABlockStartOffset(const ABlockDesc& a_block_desc, const index_t i_m, const index_t) const
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
return i_m * a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
else
{
return i_m;
}
}
ck::index_t
GetBBlockStartOffset(const BBlockDesc& b_block_desc, const index_t, const index_t i_n) const
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// K * N
return i_n;
}
else
{
// N/8 * K * 8
return i_n * b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
}
ck::index_t
GetCBlockStartOffset(const CDesc& c_desc, const index_t i_m, const index_t i_n) const
{
return i_m * c_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}] + i_n;
}
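// Hedged worked example for the address helpers above (editor's numbers,
// matching the 256x120x64 fp32 instance added later in this diff; A is
// row-major M x K, B is packed N/8 x K x 8, C is row-major M x N):
//   GetALeadingElement -> KPerBlock      = 64   (lda = 64  * 4 = 256  bytes)
//   GetBLeadingElement -> KPerBlock * 8  = 512  (ldb = 512 * 4 = 2048 bytes)
//   GetCLeadingElement -> NPerBlock      = 120  (ldc = 120 * 4 = 480  bytes)
//   GetABlockStartOffset(i_m = 4, -)        = 4 * 64       = 256  elements
//   GetBBlockStartOffset(-, i_n = 24)       = 24 * 64      = 1536 elements (start of packed group 24 / 8 = 3)
//   GetCBlockStartOffset(i_m = 4, i_n = 24) = 4 * 120 + 24 = 504  elements
// i_n only ever advances in multiples of 8 (the micro-kernel N widths), so
// i_n * KPerBlock always lands on a packed-group boundary of B.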
template <typename ABlockBuffer, typename BBlockBuffer, typename CBuffer>
void Run(const ABlockDesc& a_block_desc, void Run(const ABlockDesc& a_block_desc,
const ABlockBuffer& a_block_buf, const ABlockBuffer& a_block_buf,
const IndexA& a_origin, const IndexA& /* a_origin */,
const BBlockDesc& b_block_desc, const BBlockDesc& b_block_desc,
const BBlockBuffer& b_block_buf, const BBlockBuffer& b_block_buf,
const IndexB& b_origin, const IndexB& /* b_origin */,
const CBlockDesc& c_block_desc,
CBlockBuffer& c_block_buf,
const IndexC& c_origin) const
{
constexpr auto m_n_block_length =
ck::Sequence<ABlockSliceLengths::At(AThreadLoopOverDim),
BBlockSliceLengths::At(BThreadLoopOverDim)>{};
constexpr auto m_n_thread_length =
ck::Sequence<AThreadSliceLength::At(AThreadLoopOverDim),
BThreadSliceLength::At(BThreadLoopOverDim)>{};
constexpr auto m_n_access_length = m_n_block_length / m_n_thread_length; const CDesc& c_desc,
CBuffer& c_buf,
const IndexC& /* c_origin */,
constexpr auto ordered_m_n_access_length = bool is_accumulate_c = true) const
container_reorder_given_new2old(m_n_access_length, ThreadMNAccessOrder{}); {
auto lda = GetALeadingElement(a_block_desc) * sizeof(FloatA);
auto ldb = GetBLeadingElement(b_block_desc) * sizeof(FloatB);
auto ldc = GetCLeadingElement(c_desc) * sizeof(FloatC);
constexpr auto a_block_idx_zeros = // printf("lda:%d, ldb:%d, ldc:%d\n", lda, ldb, ldc);
typename uniform_sequence_gen<nDimA, 0>::type{}; // starting point of the block
constexpr auto b_block_idx_zeros = typename uniform_sequence_gen<nDimB, 0>::type{};
constexpr auto lda = GetLeadingElement(a_block_desc) * sizeof(FloatA); const auto k_per_block = GetKPerBlock(a_block_desc);
constexpr auto ldb = GetLeadingElement(b_block_desc) * sizeof(FloatB); const auto m_per_block = GetMPerBlock(a_block_desc);
constexpr auto ldc = GetLeadingElement(c_block_desc) * sizeof(FloatC); const auto n_per_block = GetNPerBlock(b_block_desc);
const auto m_per_thread = ThreadwiseGemm_Dispatch::ThreadMaxMr;
const auto n_per_thread = ThreadwiseGemm_Dispatch::ThreadMaxNr;
ck::cpu::ThreadwiseGemmParam param; ck::cpu::ThreadwiseGemmParam param;
param.Kr = KPerBlock; param.Kr = k_per_block;
param.lda = lda; param.lda = lda;
param.ldb = ldb; param.ldb = ldb;
param.ldc = ldc; param.ldc = ldc;
param.alpha = 1.0f; // TODO param.alpha = 1.0f; // TODO
param.accmulate_c = is_accumulate_c ? 1 : 0;
static_ford<decltype(ordered_m_n_access_length)>{}([&](auto ordered_idx) { if constexpr(std::is_same<ThreadMNAccessOrder, ck::Sequence<0, 1>>::value)
constexpr auto origin_m_n_idx = ordered_idx.ReorderGivenOld2New(ThreadMNAccessOrder{}); {
for(ck::index_t i_m = 0; i_m < m_per_block; i_m += m_per_thread)
constexpr auto current_m_idx = {
origin_m_n_idx.At(0) * AThreadSliceLength::At(AThreadLoopOverDim); auto current_mr = ck::math::min(m_per_block - i_m, m_per_thread);
constexpr auto current_n_idx = param.p_a = &a_block_buf.p_data_[GetABlockStartOffset(a_block_desc, i_m, 0)];
origin_m_n_idx.At(1) * BThreadSliceLength::At(BThreadLoopOverDim);
constexpr auto current_mr =
ck::math::min(m_n_block_length.At(0) - current_m_idx, m_n_thread_length.At(0));
constexpr auto current_nr =
ck::math::min(m_n_block_length.At(1) - current_n_idx, m_n_thread_length.At(1));
constexpr auto a_block_idx = // printf("YYYY: %d, i_m:%d, current_mr:%d, %d, %p\n",__LINE__, i_m, current_mr,
a_block_idx_zeros.Modify(AThreadLoopOverDim, current_m_idx); // GetABlockStartOffset(a_block_desc, i_m, 0), param.p_a);fflush(stdout);
constexpr auto a_block_coord =
make_tensor_coordinate(a_block_desc, to_multi_index(a_origin + a_block_idx));
constexpr auto b_block_idx = for(ck::index_t i_n = 0; i_n < n_per_block; i_n += n_per_thread)
b_block_idx_zeros.Modify(BThreadLoopOverDim, current_n_idx); {
constexpr auto b_block_coord = auto current_nr = ck::math::min(n_per_block - i_n, n_per_thread);
make_tensor_coordinate(b_block_desc, to_multi_index(b_origin + b_block_idx));
constexpr auto c_block_coord = param.p_b = &b_block_buf.p_data_[GetBBlockStartOffset(b_block_desc, 0, i_n)];
make_tensor_coordinate(c_block_desc, to_multi_index(c_origin + origin_m_n_idx)); param.p_c = &c_buf.p_data_[GetCBlockStartOffset(c_desc, i_m, i_n)];
param.p_a = &a_block_buf.p_data_[a_block_coord.GetOffset()]; // printf("YYYY: %d, i_n:%d, current_nr:%d, %d, %p, C:%d, %p\n",__LINE__, i_n,
param.p_b = &b_block_buf.p_data_[b_block_coord.GetOffset()]; // current_nr, GetBBlockStartOffset(b_block_desc, 0, i_n), param.p_b,
param.p_c = &c_block_buf.p_data_[c_block_coord.GetOffset()]; // GetCBlockStartOffset(c_desc, i_m, i_n),
// param.p_c);fflush(stdout);
ThreadwiseGemm_Dispatch::Run(&param, current_mr, current_nr); ThreadwiseGemm_Dispatch::Run(&param, current_mr, current_nr);
}); }
}
}
} }
}; };
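// Hedged sketch of the new Run() loop reconstructed from the hunk above for
// the ThreadMNAccessOrder == Sequence<0, 1> path (lda/ldb/ldc setup and the
// debug printfs omitted); m_per_thread / n_per_thread come from
// ThreadwiseGemm_Dispatch::ThreadMaxMr / ThreadMaxNr, and partial tiles are
// handled by the micro-kernel via current_mr / current_nr:
//
//   for(ck::index_t i_m = 0; i_m < m_per_block; i_m += m_per_thread)
//   {
//       auto current_mr = ck::math::min(m_per_block - i_m, m_per_thread);
//       param.p_a = &a_block_buf.p_data_[GetABlockStartOffset(a_block_desc, i_m, 0)];
//       for(ck::index_t i_n = 0; i_n < n_per_block; i_n += n_per_thread)
//       {
//           auto current_nr = ck::math::min(n_per_block - i_n, n_per_thread);
//           param.p_b = &b_block_buf.p_data_[GetBBlockStartOffset(b_block_desc, 0, i_n)];
//           param.p_c = &c_buf.p_data_[GetCBlockStartOffset(c_desc, i_m, i_n)];
//           ThreadwiseGemm_Dispatch::Run(&param, current_mr, current_nr);
//       }
//   }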
......
...@@ -14,6 +14,19 @@ enum ConvolutionForwardSpecialization_t ...@@ -14,6 +14,19 @@ enum ConvolutionForwardSpecialization_t
OddC, OddC,
}; };
enum ConvolutionForwardGemmKSpecialization_t
{
DefaultGemmKLoop,
NHWC_GemmKLoopOverC, // do not merge c*y*x; requires c % k_per_block == 0
};
enum ConvolutionForwardBlockLoopOverSpecialization_t
{
DefaultBlockLoopOver,
LoopOver_MNK,
LoopOver_MKN,
};
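// Editor's note (hedged, not from this commit): NHWC_GemmKLoopOverC keeps the
// GEMM-K loop on the C dimension instead of the merged C*Y*X dimension, which
// is why it requires C % KPerBlock == 0, while the block-loop-over values only
// select the order of the outer block loops in the gridwise GEMM, roughly:
//   LoopOver_MNK: for m_block { for n_block { for k_block { ... } } }
//   LoopOver_MKN: for m_block { for k_block { for n_block { ... } } }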
} // namespace device } // namespace device
} // namespace cpu } // namespace cpu
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_avx2.hpp" #include "gridwise_gemm_avx2.hpp"
#include "threadwise_gemm_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -23,20 +25,21 @@ namespace device { ...@@ -23,20 +25,21 @@ namespace device {
template <typename InDataType, template <typename InDataType,
typename WeiDataType, typename WeiDataType,
typename OutDataType, typename OutDataType,
typename AccDataType,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization, ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
ck::index_t NumDimSpatial, ck::index_t NumDimSpatial,
ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3) ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
ck::index_t NPerBlock, ck::index_t NPerBlock,
ck::index_t KPerBlock, ck::index_t KPerBlock,
typename ThreadwiseGemm_Dispatch> ck::index_t MPerThread,
ck::index_t NPerThread,
// bool IsGemmMPadded, bool UseALocalBuffer,
// bool IsGemmNPadded, bool UseBLocalBuffer,
// bool IsGemmKPadded> bool UseCLocalBuffer>
struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
: public DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation> : public DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
{ {
...@@ -60,18 +63,89 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -60,18 +63,89 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr bool NonTemporalStore = false;
static constexpr auto GetBlockMNKAccessOrder()
{
if constexpr(BlockLoopOverSpecialization == DefaultBlockLoopOver ||
BlockLoopOverSpecialization == LoopOver_MNK)
return ck::Sequence<0, 1, 2>{};
else if constexpr(BlockLoopOverSpecialization == LoopOver_MKN)
return ck::Sequence<0, 2, 1>{};
}
using BlockMNKAccessOrder = decltype(GetBlockMNKAccessOrder());
static constexpr auto GetThreadwiseGemm_Dispatch()
{
if constexpr(MPerThread == 4 && NPerThread == 24)
{
return ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<
InDataType,
WeiDataType,
OutDataType,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
NonTemporalStore>{};
}
else if constexpr(MPerThread == 6 && NPerThread == 16)
{
return ck::cpu::ThreadwiseGemmAvx2_MxN_6x16_Dispatch<
InDataType,
WeiDataType,
OutDataType,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
NonTemporalStore>{};
}
else
{
// static_assert(false, "invalid Mr/Nr");
}
}
using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());
static constexpr auto GetInputBlockDescriptor()
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
}
static constexpr auto GetWeightBlockDescriptor()
{
return make_naive_tensor_descriptor_packed(make_tuple(
math::integer_divide_ceil(NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
KPerBlock,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
}
static constexpr auto GetOutputBlockDescriptor()
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
}
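// Hedged cache-footprint arithmetic for the three block descriptors above
// (editor's numbers, using the 256x120x64 fp32 instance added later in this
// diff): A block 256 * 64 * 4 B = 64 KiB, packed B block (120 / 8) * 64 * 8 * 4 B
// = 30 KiB, C block 256 * 120 * 4 B = 120 KiB, i.e. roughly 214 KiB live per
// block, which is the "fit in cache (L1/L2/L3)" budget the MPerBlock comment
// above refers to.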
static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n) static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
{ {
ck::index_t gemm_n_padded =
math::integer_least_multiple(gemm_n, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
const auto wei_gemm_n_k_grid_desc = const auto wei_gemm_n_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_n, gemm_k)); make_naive_tensor_descriptor_packed(make_tuple(gemm_n, gemm_k));
const auto wei_gemm_n0_k_n1_grid_desc = transform_tensor_descriptor( const auto wei_gemm_padn_k_grid_desc = transform_tensor_descriptor(
wei_gemm_n_k_grid_desc, wei_gemm_n_k_grid_desc,
ck::make_tuple(ck::make_unmerge_transform( make_tuple(make_right_pad_transform(gemm_n, gemm_n_padded - gemm_n),
ck::make_tuple(wei_gemm_n_k_grid_desc.GetLength(I0) / make_pass_through_transform(gemm_k)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
const auto wei_gemm_n0_k_n1_grid_desc = transform_tensor_descriptor(
wei_gemm_padn_k_grid_desc,
ck::make_tuple(
ck::make_unmerge_transform(
ck::make_tuple(wei_gemm_padn_k_grid_desc.GetLength(I0) /
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize)), ThreadwiseGemm_Dispatch::MatrixBMinVectorSize)),
ck::make_pass_through_transform(wei_gemm_n_k_grid_desc.GetLength(I1))), ck::make_pass_through_transform(wei_gemm_padn_k_grid_desc.GetLength(I1))),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}), ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{})); ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
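// Hedged worked example of the padding above (editor's numbers): with
// MatrixBMinVectorSize == 8 and gemm_n == 250, gemm_n_padded becomes
// integer_least_multiple(250, 8) == 256, the right-pad transform appends 6
// columns, and the unmerge then yields a 256 / 8 x gemm_k x 8 = 32 x K x 8
// packed weight descriptor, which is the layout the 4x24 / 6x16 micro-kernels
// read B from.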
...@@ -409,6 +483,13 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -409,6 +483,13 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::multiplies<ck::index_t>()); std::multiplies<ck::index_t>());
} }
static index_t GetGemmN(ck::index_t K)
{
// return ck::math::integer_least_multiple(K,
// ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
return K;
}
static auto MakeABCGridDescriptor(ck::index_t N, static auto MakeABCGridDescriptor(ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t C, ck::index_t C,
...@@ -423,7 +504,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -423,7 +504,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
using namespace ck; using namespace ck;
const index_t GemmM = GetGemmM(N, output_spatial_lengths); const index_t GemmM = GetGemmM(N, output_spatial_lengths);
const index_t GemmN = K; const index_t GemmN = GetGemmN(K);
const index_t GemmK = GetGemmK(C, filter_spatial_lengths); const index_t GemmK = GetGemmK(C, filter_spatial_lengths);
// A: // A:
...@@ -474,13 +555,44 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -474,13 +555,44 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
using BGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I1])>; using BGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I2])>; using CGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
static constexpr bool UseCLocalBuffer = true; // static constexpr bool UseCLocalBuffer = false;
using AThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC<
InDataType,
InDataType,
AGridDesc,
decltype(GetInputBlockDescriptor()),
InElementwiseOperation,
false,
ConvForwardSpecialization,
GemmKSpecialization>;
using BThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC<
WeiDataType,
WeiDataType,
BGridDesc,
decltype(GetWeightBlockDescriptor()),
WeiElementwiseOperation,
false,
ConvForwardSpecialization,
GemmKSpecialization>;
using CThreadwiseCopy = ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN<
OutDataType,
OutDataType,
CGridDesc,
decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation,
!UseCLocalBuffer,
ConvForwardSpecialization,
GemmKSpecialization>;
using GridwiseGemm = using GridwiseGemm =
ck::cpu::GridwiseGemmAvx2_MxN<InDataType, // InDataType, ck::cpu::GridwiseGemmAvx2_MxN<InDataType, // InDataType,
WeiDataType, // WeiDataType, WeiDataType, // WeiDataType,
OutDataType, // OutDataType, OutDataType, // OutDataType,
AccDataType, // AccDataType,
AGridDesc, // AGridDesc, AGridDesc, // AGridDesc,
BGridDesc, // BGridDesc, BGridDesc, // BGridDesc,
CGridDesc, // CGridDesc, CGridDesc, // CGridDesc,
...@@ -491,8 +603,13 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -491,8 +603,13 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
NPerBlock, // NPerBlock, NPerBlock, // NPerBlock,
KPerBlock, // KPerBlock, KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch, ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
ck::Sequence<0, 1, 2>, // BlockMNKAccessOrder, AThreadwiseCopy, // AThreadwiseCopy
BThreadwiseCopy, // BThreadwiseCopy
CThreadwiseCopy, // CThreadwiseCopy
BlockMNKAccessOrder, // BlockMNKAccessOrder,
ck::Sequence<0, 1>, // ThreadMNAccessOrder ck::Sequence<0, 1>, // ThreadMNAccessOrder
UseALocalBuffer, // UseALocalBuffer
UseBLocalBuffer, // UseBLocalBuffer
UseCLocalBuffer // UseCLocalBuffer UseCLocalBuffer // UseCLocalBuffer
>; >;
...@@ -580,6 +697,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -580,6 +697,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting"); throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
} }
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize() * sizeof(OutDataType));
const auto kernel = ck::cpu::kernel_gemm_avx_mxn<GridwiseGemm, const auto kernel = ck::cpu::kernel_gemm_avx_mxn<GridwiseGemm,
InDataType, InDataType,
WeiDataType, WeiDataType,
...@@ -591,7 +710,10 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -591,7 +710,10 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation>; CElementwiseOperation>;
float ave_time = launch_and_time_cpu_kernel(kernel, float ave_time = 0;
if(nrepeat != 1)
ave_time = launch_and_time_cpu_kernel(kernel,
nrepeat, nrepeat,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
...@@ -605,7 +727,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -605,7 +727,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
// TODO: this is for benchmark purpose, so last time we clear c buffer and calculate the // TODO: this is for benchmark purpose, so last time we clear c buffer and calculate the
// result // result
memset(arg.p_c_grid_, 0, arg.a_grid_desc_.GetElementSpaceSize()); memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize() * sizeof(OutDataType));
launch_cpu_kernel(kernel, launch_cpu_kernel(kernel,
arg.p_a_grid_, arg.p_a_grid_,
...@@ -659,6 +781,13 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -659,6 +781,13 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
} }
} }
if constexpr(GemmKSpecialization ==
ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
{
if(!(arg.Conv_C_ % KPerBlock == 0))
return false;
}
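// Hedged example of the check above: with NHWC_GemmKLoopOverC the GEMM-K loop
// steps through C in chunks of KPerBlock without crossing into the next Y/X
// position, so every chunk must fit exactly. E.g. C = 192 passes with
// KPerBlock = 64 (192 % 64 == 0) but would be rejected with KPerBlock = 128.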
// Gridwise GEMM size // Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_); return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
} }
...@@ -749,15 +878,24 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -749,15 +878,24 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::string GetTypeString() const override std::string GetTypeString() const override
{ {
auto str = std::stringstream(); auto str = std::stringstream();
auto string_local_buffer = [](bool is_local_buffer) {
if(is_local_buffer)
return "L";
else
return "G";
};
// clang-format off // clang-format off
str << "DeviceConv" << std::to_string(NumDimSpatial) str << "DeviceConv" << std::to_string(NumDimSpatial)
<< "DFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" << "DFwdAvx2_NHWC_KYXC"
<< "<" <<"_FS"<< static_cast<int>(ConvForwardSpecialization)
<< MPerBlock << ", " <<"_KS"<< static_cast<int>(GemmKSpecialization)
<< NPerBlock << ", " <<"_BS"<< static_cast<int>(BlockLoopOverSpecialization)
<< KPerBlock << "_BT" << MPerBlock << "x" << NPerBlock << "x" << KPerBlock
<< ">"; << "_TT" << MPerThread << "x" << NPerThread
<< "_A" << string_local_buffer(UseALocalBuffer)
<< "_B" << string_local_buffer(UseBLocalBuffer)
<< "_C" << string_local_buffer(UseCLocalBuffer)
;
// clang-format on // clang-format on
return str.str(); return str.str();
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include "common_header.hpp" #include "common_header.hpp"
#include "../../gpu/device/tensor_layout.hpp" #include "../../gpu/device/tensor_layout.hpp"
#include "math.hpp" #include "math.hpp"
#include "threadwise_param.hpp" #include "threadwise_gemm_param.hpp"
namespace ck { namespace ck {
namespace cpu { namespace cpu {
...@@ -294,6 +294,9 @@ struct ThreadwiseGemmAvx2_MxN_6x16 ...@@ -294,6 +294,9 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".if (m_Mr > 4)\n lea (%%rdx, %%rdi, 1), %%r8 \n .endif\n" ".if (m_Mr > 4)\n lea (%%rdx, %%rdi, 1), %%r8 \n .endif\n"
".if (m_Mr > 5)\n lea (%%r8, %%rdi, 1), %%r9 \n .endif\n" ".if (m_Mr > 5)\n lea (%%r8, %%rdi, 1), %%r9 \n .endif\n"
"mov 60(%[m_param]), %%edi\n" // accmulate_c
"test %%edi, %%edi\n"
"je L_GemmAvx2_MxN_6x16_Store_C%=\n"
" vaddps (%%rax), %%ymm0, %%ymm0 \n" " vaddps (%%rax), %%ymm0, %%ymm0 \n"
".if (m_Nr > 8)\n vaddps 32(%%rax), %%ymm1, %%ymm1 \n .endif\n" ".if (m_Nr > 8)\n vaddps 32(%%rax), %%ymm1, %%ymm1 \n .endif\n"
".if (m_Mr > 1) \n vaddps (%%rbx), %%ymm2, %%ymm2 \n .endif\n" ".if (m_Mr > 1) \n vaddps (%%rbx), %%ymm2, %%ymm2 \n .endif\n"
...@@ -307,6 +310,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16 ...@@ -307,6 +310,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".if (m_Mr > 5) \n vaddps (%%r9), %%ymm10, %%ymm10\n .endif\n" ".if (m_Mr > 5) \n vaddps (%%r9), %%ymm10, %%ymm10\n .endif\n"
".if (m_Mr > 5) && (m_Nr > 8)\n vaddps 32(%%r9), %%ymm11, %%ymm11\n .endif\n" ".if (m_Mr > 5) && (m_Nr > 8)\n vaddps 32(%%r9), %%ymm11, %%ymm11\n .endif\n"
"L_GemmAvx2_MxN_6x16_Store_C%=:\n"
".if m_NTStore == 0\n" ".if m_NTStore == 0\n"
" vmovups %%ymm0, (%%rax) \n" " vmovups %%ymm0, (%%rax) \n"
".if (m_Nr > 8)\n vmovups %%ymm1, 32(%%rax)\n .endif\n" ".if (m_Nr > 8)\n vmovups %%ymm1, 32(%%rax)\n .endif\n"
...@@ -424,6 +428,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16 ...@@ -424,6 +428,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
}; };
// clang-format off // clang-format off
if(param->accmulate_c){
ymm0 = _mm256_loadu_ps(p_c + 0 * ldc + 0 * 8); ymm0 = _mm256_loadu_ps(p_c + 0 * ldc + 0 * 8);
if constexpr ( Nr > 8) ymm1 = _mm256_loadu_ps(p_c + 0 * ldc + 1 * 8); if constexpr ( Nr > 8) ymm1 = _mm256_loadu_ps(p_c + 0 * ldc + 1 * 8);
if constexpr (Mr > 1 ) ymm2 = _mm256_loadu_ps(p_c + 1 * ldc + 0 * 8); if constexpr (Mr > 1 ) ymm2 = _mm256_loadu_ps(p_c + 1 * ldc + 0 * 8);
...@@ -436,6 +441,20 @@ struct ThreadwiseGemmAvx2_MxN_6x16 ...@@ -436,6 +441,20 @@ struct ThreadwiseGemmAvx2_MxN_6x16
if constexpr (Mr > 4 && Nr > 8) ymm9 = _mm256_loadu_ps(p_c + 4 * ldc + 1 * 8); if constexpr (Mr > 4 && Nr > 8) ymm9 = _mm256_loadu_ps(p_c + 4 * ldc + 1 * 8);
if constexpr (Mr > 5 ) ymm10 = _mm256_loadu_ps(p_c + 5 * ldc + 0 * 8); if constexpr (Mr > 5 ) ymm10 = _mm256_loadu_ps(p_c + 5 * ldc + 0 * 8);
if constexpr (Mr > 5 && Nr > 8) ymm11 = _mm256_loadu_ps(p_c + 5 * ldc + 1 * 8); if constexpr (Mr > 5 && Nr > 8) ymm11 = _mm256_loadu_ps(p_c + 5 * ldc + 1 * 8);
} else {
ymm0 = _mm256_xor_ps(ymm0, ymm0);
if constexpr ( Nr > 8) ymm1 = _mm256_xor_ps(ymm1, ymm1);
if constexpr (Mr > 1 ) ymm2 = _mm256_xor_ps(ymm2, ymm2);
if constexpr (Mr > 1 && Nr > 8) ymm3 = _mm256_xor_ps(ymm3, ymm3);
if constexpr (Mr > 2 ) ymm4 = _mm256_xor_ps(ymm4, ymm4);
if constexpr (Mr > 2 && Nr > 8) ymm5 = _mm256_xor_ps(ymm5, ymm5);
if constexpr (Mr > 3 ) ymm6 = _mm256_xor_ps(ymm6, ymm6);
if constexpr (Mr > 3 && Nr > 8) ymm7 = _mm256_xor_ps(ymm7, ymm7);
if constexpr (Mr > 4 ) ymm8 = _mm256_xor_ps(ymm8, ymm8);
if constexpr (Mr > 4 && Nr > 8) ymm9 = _mm256_xor_ps(ymm9, ymm9);
if constexpr (Mr > 5 ) ymm10 = _mm256_xor_ps(ymm10, ymm10);
if constexpr (Mr > 5 && Nr > 8) ymm11 = _mm256_xor_ps(ymm11, ymm11);
}
while (Kr > 4){ while (Kr > 4){
#pragma unroll #pragma unroll
...@@ -532,6 +551,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16 ...@@ -532,6 +551,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
} }
if constexpr (NonTemporalStore) { if constexpr (NonTemporalStore) {
_mm256_stream_ps(p_c + 0 * ldc + 0 * 8, ymm0);
if constexpr ( Nr > 8) _mm256_stream_ps(p_c + 0 * ldc + 1 * 8, ymm1); if constexpr ( Nr > 8) _mm256_stream_ps(p_c + 0 * ldc + 1 * 8, ymm1);
if constexpr (Mr > 1 ) _mm256_stream_ps(p_c + 1 * ldc + 0 * 8, ymm2); if constexpr (Mr > 1 ) _mm256_stream_ps(p_c + 1 * ldc + 0 * 8, ymm2);
if constexpr (Mr > 1 && Nr > 8) _mm256_stream_ps(p_c + 1 * ldc + 1 * 8, ymm3); if constexpr (Mr > 1 && Nr > 8) _mm256_stream_ps(p_c + 1 * ldc + 1 * 8, ymm3);
...@@ -830,19 +850,23 @@ struct ThreadwiseGemmAvx2_MxN_4x24 ...@@ -830,19 +850,23 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".if (m_Mr > 2)\n lea (%%rbx, %%rdi, 1), %%rcx\n .endif\n" ".if (m_Mr > 2)\n lea (%%rbx, %%rdi, 1), %%rcx\n .endif\n"
".if (m_Mr > 3)\n lea (%%rcx, %%rdi, 1), %%rdx\n .endif\n" ".if (m_Mr > 3)\n lea (%%rcx, %%rdi, 1), %%rdx\n .endif\n"
// " vaddps (%%rax), %%ymm0, %%ymm0 \n" "mov 60(%[m_param]), %%edi\n" // accmulate_c
// ".if (m_Nr > 8)\n vaddps 32(%%rax), %%ymm1, %%ymm1 \n .endif\n" "test %%edi, %%edi\n"
// ".if (m_Nr >16)\n vaddps 64(%%rax), %%ymm2, %%ymm2 \n .endif\n" "je L_GemmAvx2_MxN_4x24_Store_C%=\n"
// ".if (m_Mr > 1) \n vaddps (%%rbx), %%ymm3, %%ymm3 \n .endif\n" " vaddps (%%rax), %%ymm0, %%ymm0 \n"
// ".if (m_Mr > 1) && (m_Nr > 8)\n vaddps 32(%%rbx), %%ymm4, %%ymm4 \n .endif\n" ".if (m_Nr > 8)\n vaddps 32(%%rax), %%ymm1, %%ymm1 \n .endif\n"
// ".if (m_Mr > 1) && (m_Nr >16)\n vaddps 64(%%rbx), %%ymm5, %%ymm5 \n .endif\n" ".if (m_Nr >16)\n vaddps 64(%%rax), %%ymm2, %%ymm2 \n .endif\n"
// ".if (m_Mr > 2) \n vaddps (%%rcx), %%ymm6, %%ymm6 \n .endif\n" ".if (m_Mr > 1) \n vaddps (%%rbx), %%ymm3, %%ymm3 \n .endif\n"
// ".if (m_Mr > 2) && (m_Nr > 8)\n vaddps 32(%%rcx), %%ymm7, %%ymm7 \n .endif\n" ".if (m_Mr > 1) && (m_Nr > 8)\n vaddps 32(%%rbx), %%ymm4, %%ymm4 \n .endif\n"
// ".if (m_Mr > 2) && (m_Nr >16)\n vaddps 64(%%rcx), %%ymm8, %%ymm8 \n .endif\n" ".if (m_Mr > 1) && (m_Nr >16)\n vaddps 64(%%rbx), %%ymm5, %%ymm5 \n .endif\n"
// ".if (m_Mr > 3) \n vaddps (%%rdx), %%ymm9, %%ymm9 \n .endif\n" ".if (m_Mr > 2) \n vaddps (%%rcx), %%ymm6, %%ymm6 \n .endif\n"
// ".if (m_Mr > 3) && (m_Nr > 8)\n vaddps 32(%%rdx), %%ymm10, %%ymm10\n .endif\n" ".if (m_Mr > 2) && (m_Nr > 8)\n vaddps 32(%%rcx), %%ymm7, %%ymm7 \n .endif\n"
// ".if (m_Mr > 3) && (m_Nr >16)\n vaddps 64(%%rdx), %%ymm11, %%ymm11\n .endif\n" ".if (m_Mr > 2) && (m_Nr >16)\n vaddps 64(%%rcx), %%ymm8, %%ymm8 \n .endif\n"
".if (m_Mr > 3) \n vaddps (%%rdx), %%ymm9, %%ymm9 \n .endif\n"
".if (m_Mr > 3) && (m_Nr > 8)\n vaddps 32(%%rdx), %%ymm10, %%ymm10\n .endif\n"
".if (m_Mr > 3) && (m_Nr >16)\n vaddps 64(%%rdx), %%ymm11, %%ymm11\n .endif\n"
"L_GemmAvx2_MxN_4x24_Store_C%=:\n"
".if m_NTStore == 0\n" ".if m_NTStore == 0\n"
" vmovups %%ymm0, (%%rax) \n" " vmovups %%ymm0, (%%rax) \n"
".if (m_Nr > 8)\n vmovups %%ymm1, 32(%%rax)\n .endif\n" ".if (m_Nr > 8)\n vmovups %%ymm1, 32(%%rax)\n .endif\n"
...@@ -960,6 +984,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24 ...@@ -960,6 +984,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
}; };
// clang-format off // clang-format off
if(param->accmulate_c) {
ymm0 = _mm256_loadu_ps(p_c + 0 * ldc + 0 * 8); ymm0 = _mm256_loadu_ps(p_c + 0 * ldc + 0 * 8);
if constexpr ( Nr > 8) ymm1 = _mm256_loadu_ps(p_c + 0 * ldc + 1 * 8); if constexpr ( Nr > 8) ymm1 = _mm256_loadu_ps(p_c + 0 * ldc + 1 * 8);
if constexpr ( Nr >16) ymm2 = _mm256_loadu_ps(p_c + 0 * ldc + 2 * 8); if constexpr ( Nr >16) ymm2 = _mm256_loadu_ps(p_c + 0 * ldc + 2 * 8);
...@@ -972,6 +997,20 @@ struct ThreadwiseGemmAvx2_MxN_4x24 ...@@ -972,6 +997,20 @@ struct ThreadwiseGemmAvx2_MxN_4x24
if constexpr (Mr > 3 ) ymm9 = _mm256_loadu_ps(p_c + 3 * ldc + 0 * 8); if constexpr (Mr > 3 ) ymm9 = _mm256_loadu_ps(p_c + 3 * ldc + 0 * 8);
if constexpr (Mr > 3 && Nr > 8) ymm10 = _mm256_loadu_ps(p_c + 3 * ldc + 1 * 8); if constexpr (Mr > 3 && Nr > 8) ymm10 = _mm256_loadu_ps(p_c + 3 * ldc + 1 * 8);
if constexpr (Mr > 3 && Nr >16) ymm11 = _mm256_loadu_ps(p_c + 3 * ldc + 2 * 8); if constexpr (Mr > 3 && Nr >16) ymm11 = _mm256_loadu_ps(p_c + 3 * ldc + 2 * 8);
} else {
ymm0 = _mm256_xor_ps(ymm0, ymm0);
if constexpr ( Nr > 8) ymm1 = _mm256_xor_ps(ymm1, ymm1);
if constexpr ( Nr >16) ymm2 = _mm256_xor_ps(ymm2, ymm2);
if constexpr (Mr > 1 ) ymm3 = _mm256_xor_ps(ymm3, ymm3);
if constexpr (Mr > 1 && Nr > 8) ymm4 = _mm256_xor_ps(ymm4, ymm4);
if constexpr (Mr > 1 && Nr >16) ymm5 = _mm256_xor_ps(ymm5, ymm5);
if constexpr (Mr > 2 ) ymm6 = _mm256_xor_ps(ymm6, ymm6);
if constexpr (Mr > 2 && Nr > 8) ymm7 = _mm256_xor_ps(ymm7, ymm7);
if constexpr (Mr > 2 && Nr >16) ymm8 = _mm256_xor_ps(ymm8, ymm8);
if constexpr (Mr > 3 ) ymm9 = _mm256_xor_ps(ymm9, ymm9);
if constexpr (Mr > 3 && Nr > 8) ymm10 = _mm256_xor_ps(ymm10, ymm10);
if constexpr (Mr > 3 && Nr >16) ymm11 = _mm256_xor_ps(ymm11, ymm11);
}
while (Kr > 4){ while (Kr > 4){
#pragma unroll #pragma unroll
...@@ -1221,33 +1260,36 @@ struct ThreadwiseGemmAvx2_MxN_6x16_Dispatch ...@@ -1221,33 +1260,36 @@ struct ThreadwiseGemmAvx2_MxN_6x16_Dispatch
static constexpr pThreadwiseGemmAvx2Run dispatch_table[6][2] = { static constexpr pThreadwiseGemmAvx2Run dispatch_table[6][2] = {
{ {
ThreadwiseGemm_6x16_t::Run, ThreadwiseGemm_1x8_t::Run,
ThreadwiseGemm_6x8_t::Run, ThreadwiseGemm_1x16_t::Run,
}, },
{ {
ThreadwiseGemm_5x16_t::Run, ThreadwiseGemm_2x8_t::Run,
ThreadwiseGemm_5x8_t::Run, ThreadwiseGemm_2x16_t::Run,
}, },
{ {
ThreadwiseGemm_4x16_t::Run, ThreadwiseGemm_3x8_t::Run,
ThreadwiseGemm_4x8_t::Run, ThreadwiseGemm_3x16_t::Run,
}, },
{ {
ThreadwiseGemm_3x16_t::Run, ThreadwiseGemm_4x8_t::Run,
ThreadwiseGemm_3x8_t::Run, ThreadwiseGemm_4x16_t::Run,
}, },
{ {
ThreadwiseGemm_2x16_t::Run, ThreadwiseGemm_5x8_t::Run,
ThreadwiseGemm_2x8_t::Run, ThreadwiseGemm_5x16_t::Run,
}, },
{ {
ThreadwiseGemm_1x16_t::Run, ThreadwiseGemm_6x8_t::Run,
ThreadwiseGemm_1x8_t::Run, ThreadwiseGemm_6x16_t::Run,
}, },
}; };
static void Run(ThreadwiseGemmParam* param, index_t mr, index_t nr) static void Run(ThreadwiseGemmParam* param, index_t mr, index_t nr)
{ {
index_t im = mr - 1;
index_t in = (nr >> 3) - 1;
assert(im >= 0 && im <= 5 && in >= 0 && in <= 1);
return dispatch_table[mr][nr](param); return dispatch_table[im][in](param);
} }
}; };
...@@ -1371,30 +1413,33 @@ struct ThreadwiseGemmAvx2_MxN_4x24_Dispatch ...@@ -1371,30 +1413,33 @@ struct ThreadwiseGemmAvx2_MxN_4x24_Dispatch
static constexpr pThreadwiseGemmAvx2Run dispatch_table[4][3] = { static constexpr pThreadwiseGemmAvx2Run dispatch_table[4][3] = {
{ {
ThreadwiseGemm_4x24_t::Run, ThreadwiseGemm_1x8_t::Run,
ThreadwiseGemm_4x16_t::Run, ThreadwiseGemm_1x16_t::Run,
ThreadwiseGemm_4x8_t::Run, ThreadwiseGemm_1x24_t::Run,
}, },
{ {
ThreadwiseGemm_3x24_t::Run, ThreadwiseGemm_2x8_t::Run,
ThreadwiseGemm_3x16_t::Run, ThreadwiseGemm_2x16_t::Run,
ThreadwiseGemm_3x8_t::Run, ThreadwiseGemm_2x24_t::Run,
}, },
{ {
ThreadwiseGemm_2x24_t::Run, ThreadwiseGemm_3x8_t::Run,
ThreadwiseGemm_2x16_t::Run, ThreadwiseGemm_3x16_t::Run,
ThreadwiseGemm_2x8_t::Run, ThreadwiseGemm_3x24_t::Run,
}, },
{ {
ThreadwiseGemm_1x24_t::Run, ThreadwiseGemm_4x8_t::Run,
ThreadwiseGemm_1x16_t::Run, ThreadwiseGemm_4x16_t::Run,
ThreadwiseGemm_1x8_t::Run, ThreadwiseGemm_4x24_t::Run,
}, },
}; };
static void Run(ThreadwiseGemmParam* param, index_t mr, index_t nr) static void Run(ThreadwiseGemmParam* param, index_t mr, index_t nr)
{ {
return dispatch_table[mr][nr](param); index_t im = mr - 1;
index_t in = (nr >> 3) - 1;
assert(im >= 0 && im <= 3 && in >= 0 && in <= 2);
return dispatch_table[im][in](param);
} }
}; };
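// Hedged usage sketch for the dispatch entry point above (editor's
// illustration; template arguments follow the instantiation used in the
// device op earlier in this diff, numeric values are made up):
//
//   void run_one_4x24_tile(const float* a, const float* b, float* c)
//   {
//       using Dispatch = ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<
//           float, float, float,
//           ck::tensor_layout::gemm::RowMajor,
//           ck::tensor_layout::gemm::ColumnMajor,
//           false>;
//       ck::cpu::ThreadwiseGemmParam param;
//       param.p_a   = a;
//       param.p_b   = b;
//       param.p_c   = c;
//       param.Kr    = 64;
//       param.lda   = 64 * sizeof(float);     // leading dimensions are in bytes
//       param.ldb   = 64 * 8 * sizeof(float); // packed N/8 x K x 8 weight layout
//       param.ldc   = 120 * sizeof(float);
//       param.alpha = 1.0f;
//       param.accmulate_c = 0;                // overwrite C instead of accumulating
//       Dispatch::Run(&param, 4, 24);         // dispatched as dispatch_table[4 - 1][24 / 8 - 1]
//   }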
......
#ifndef CK_THREADWISE_PARAM_HPP #ifndef CK_THREADWISE_GEMM_PARAM_HPP
#define CK_THREADWISE_PARAM_HPP #define CK_THREADWISE_GEMM_PARAM_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "math.hpp" #include "math.hpp"
...@@ -17,7 +17,7 @@ struct ThreadwiseGemmParam ...@@ -17,7 +17,7 @@ struct ThreadwiseGemmParam
uint64_t ldb; // in unit of byte uint64_t ldb; // in unit of byte
uint64_t ldc; // in unit of byte uint64_t ldc; // in unit of byte
float alpha; float alpha;
uint32_t _pack0; int accmulate_c; // if 1, load C and accumulate the FMA result into it; if 0, store the result directly
} __attribute__((packed)); } __attribute__((packed));
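// Editor's sketch (hedged, not part of this commit) of how the new
// accmulate_c flag is meant to be used when GEMM-K is split into several
// KPerBlock chunks: the first chunk overwrites C, every later chunk
// accumulates on top of it.
//
//   for(ck::index_t i_k = 0; i_k < gemm_k; i_k += k_per_block)
//   {
//       param.Kr          = ck::math::min(gemm_k - i_k, k_per_block);
//       param.accmulate_c = (i_k == 0) ? 0 : 1;
//       // advance param.p_a / param.p_b to the current K chunk, then:
//       ThreadwiseGemm_Dispatch::Run(&param, mr, nr);
//   }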
} // namespace cpu } // namespace cpu
......
...@@ -53,6 +53,42 @@ struct ThreadwiseTensorSliceTransferAvx2 ...@@ -53,6 +53,42 @@ struct ThreadwiseTensorSliceTransferAvx2
{ {
static_assert(SliceLengths::At(Number<VectorDim>{}) % ScalarPerVector == 0, static_assert(SliceLengths::At(Number<VectorDim>{}) % ScalarPerVector == 0,
"wrong! cannot evenly divide"); "wrong! cannot evenly divide");
int N = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
int Hi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
int Wi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
int C = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<3>{}];
int Ho = src_desc.GetTransforms()[Number<9>{}].low_lengths_[Number<1>{}];
int Wo = src_desc.GetTransforms()[Number<9>{}].low_lengths_[Number<2>{}];
int Fy = src_desc.GetTransforms()[Number<10>{}].low_lengths_[Number<0>{}];
int Fx = src_desc.GetTransforms()[Number<10>{}].low_lengths_[Number<1>{}];
int Dy = src_desc.GetTransforms()[Number<6>{}].coefficients_[Number<0>{}];
int Sy = src_desc.GetTransforms()[Number<6>{}].coefficients_[Number<1>{}];
int Dx = src_desc.GetTransforms()[Number<7>{}].coefficients_[Number<0>{}];
int Sx = src_desc.GetTransforms()[Number<7>{}].coefficients_[Number<1>{}];
int Py = src_desc.GetTransforms()[Number<2>{}].left_pad_length_;
int Px = src_desc.GetTransforms()[Number<3>{}].left_pad_length_;
printf("N:%d, Hi:%d, Wi:%d, C:%d, Ho:%d, Wo:%d, Fy:%d, Fx:%d, Dy:%d, Sy:%d, Dx:%d, Sx:%d, "
"Py:%d, Px:%d\n",
N,
Hi,
Wi,
C,
Ho,
Wo,
Fy,
Fx,
Dy,
Sy,
Dx,
Sx,
Py,
Px);
} }
void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
...@@ -87,6 +123,10 @@ struct ThreadwiseTensorSliceTransferAvx2 ...@@ -87,6 +123,10 @@ struct ThreadwiseTensorSliceTransferAvx2
// std::cout<<"num_access:"<<num_access<<std::endl; // std::cout<<"num_access:"<<num_access<<std::endl;
std::cout << "src hidden:" << SrcDesc::GetNumOfHiddenDimension() << std::endl;
std::cout << "dst hidden:" << DstDesc::GetNumOfHiddenDimension() << std::endl;
#if 0
static_for<0, num_access, 1>{}([&](auto idx_1d) { static_for<0, num_access, 1>{}([&](auto idx_1d) {
using src_vector_type = ck::cpu::vector_type_maker_t<SrcData, ScalarPerVector>; using src_vector_type = ck::cpu::vector_type_maker_t<SrcData, ScalarPerVector>;
using src_vector_t = typename src_vector_type::type; using src_vector_t = typename src_vector_type::type;
...@@ -148,6 +188,75 @@ struct ThreadwiseTensorSliceTransferAvx2 ...@@ -148,6 +188,75 @@ struct ThreadwiseTensorSliceTransferAvx2
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step)); dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
} }
}); });
#endif
const auto src_slice_idx_zeros = typename uniform_sequence_gen<nDim, 0>::type{};
const auto src_slice_step = make_tensor_coordinate_step(
src_desc, to_multi_index(src_slice_idx_zeros.Modify(Number<nDim - 1>{}, Number<1>{})));
const auto dst_slice_idx_zeros = typename uniform_sequence_gen<nDim, 0>::type{};
const auto dst_slice_step = make_tensor_coordinate_step(
dst_desc, to_multi_index(dst_slice_idx_zeros.Modify(Number<nDim - 1>{}, Number<1>{})));
for(auto idx_id = 0; idx_id < num_access; idx_id++)
{
using src_vector_type = ck::cpu::vector_type_maker_t<SrcData, ScalarPerVector>;
using src_vector_t = typename src_vector_type::type;
using dst_vector_type = ck::cpu::vector_type_maker_t<DstData, ScalarPerVector>;
using dst_vector_t = typename dst_vector_type::type;
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
printf("[%s] ", is_src_valid ? "y" : "n");
print_multi_index(src_coord_.GetIndex());
printf("----");
// print_multi_index(src_coord_.GetHiddenIndex());
// printf(":%d", src_coord_.GetOffset());
// printf("\n");
// copy data from src_buf into src_vector_container
auto src_vector_container = src_vector_type{
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
auto dst_vector_container = dst_vector_type{};
// apply pointwise operation
// static_for<0, ScalarPerVector, 1>{}([&](auto i) {
// element_op_(dst_vector_container.template AsType<DstData>()(i),
// src_vector_container.template AsType<SrcData>()[i]);
// });
element_op_(dst_vector_container.template AsType<dst_vector_t>(),
src_vector_container.template AsType<src_vector_t>());
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
printf(" -> ");
print_multi_index(dst_coord_.GetIndex());
// printf(":%d", dst_coord_.GetOffset());
// printf(", src:0x%x, dst:0x%x",
// *reinterpret_cast<uint32_t*>(&src_vector_container.template AsType<src_vector_t>()),
// *reinterpret_cast<uint32_t*>(&dst_vector_container.template
// AsType<dst_vector_t>()));
printf("\n");
// copy data from dst_vector into dst_buf
dst_buf.template Update<DstInMemOp, dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>());
// move coordinate
if(idx_id != num_access - 1)
{
// constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
move_tensor_coordinate(src_desc, src_coord_, src_slice_step);
move_tensor_coordinate(dst_desc, dst_coord_, dst_slice_step);
}
}
// move coordinate back to slice origin (or not) // move coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun) if constexpr(SrcResetCoordinateAfterRun)
......
...@@ -121,6 +121,10 @@ template <typename... Args, typename F> ...@@ -121,6 +121,10 @@ template <typename... Args, typename F>
float launch_and_time_cpu_kernel(F kernel, int nrepeat, Args... args) float launch_and_time_cpu_kernel(F kernel, int nrepeat, Args... args)
{ {
WallTimer timer; WallTimer timer;
int nwarmup = 3;
for(int i = 0; i < nwarmup; i++)
kernel(args...); kernel(args...);
timer.Start(); timer.Start();
......
...@@ -19,7 +19,7 @@ using InLayout = ck::tensor_layout::gemm::RowMajor; / ...@@ -19,7 +19,7 @@ using InLayout = ck::tensor_layout::gemm::RowMajor; /
using WeiLayout = ck::tensor_layout::gemm::ColumnMajor; // KYXC using WeiLayout = ck::tensor_layout::gemm::ColumnMajor; // KYXC
static constexpr bool NonTemporalStore = false; static constexpr bool NonTemporalStore = false;
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough; using PT = ck::tensor_operation::cpu::element_wise::PassThrough;
using ThreadwiseGemmAvx2_MxN_4x24_Dispatch = using ThreadwiseGemmAvx2_MxN_4x24_Dispatch =
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<InType, ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<InType,
WeiType, WeiType,
...@@ -37,53 +37,37 @@ static constexpr auto ConvFwd1x1P0 = ...@@ -37,53 +37,37 @@ static constexpr auto ConvFwd1x1P0 =
static constexpr auto ConvFwd1x1S1P0 = static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0; ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;
static constexpr auto DefaultGemmKLoop =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop;
static constexpr auto GemmKLoopOverC =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC;
static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK;
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
// clang-format off
#define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf) \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>
// clang-format on
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances = std::tuple< using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances = std::tuple<
//#################################################################|InDataType|WeiDataType|OutDataType|AccDataType|InElementwiseOp|WeiElementwiseOp|OutElementwiseOp|ConvForwardSp|NumDimSpatial|MPerBlock|NPerBlock|KPerBlock|ThreadwiseGemm_Dispatch // clang-format off
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 120, 64, 4, 24, true, true, false),
float, DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 144, 128, 4, 24, true, true, false),
float, DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true, true, false),
float, // DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 192, 128, 4, 24, true, true, false),
float, DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 288, 128, 4, 24, true, true, false)>;
PassThrough, // clang-format on
PassThrough,
PassThrough,
ConvFwdDefault,
2,
256,
128,
64,
ThreadwiseGemmAvx2_MxN_4x24_Dispatch>,
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
float,
float,
float,
float,
PassThrough,
PassThrough,
PassThrough,
ConvFwdDefault,
2,
512,
256,
128,
ThreadwiseGemmAvx2_MxN_4x24_Dispatch>,
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
float,
float,
float,
float,
PassThrough,
PassThrough,
PassThrough,
ConvFwdDefault,
2,
1024,
144,
128,
ThreadwiseGemmAvx2_MxN_4x24_Dispatch>>;
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk( void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances{}); instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances{});
......
...@@ -9,8 +9,11 @@ ...@@ -9,8 +9,11 @@
#include "reference_conv_fwd.hpp" #include "reference_conv_fwd.hpp"
#include "element_wise_operation_cpu.hpp" #include "element_wise_operation_cpu.hpp"
#include "dynamic_buffer_cpu.hpp" #include "dynamic_buffer_cpu.hpp"
#include <omp.h>
#define AVX2_DATA_ALIGNMENT 32 #define AVX2_DATA_ALIGNMENT 32
using F32 = float;
using F16 = ck::half_t;
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -18,6 +21,8 @@ namespace cpu { ...@@ -18,6 +21,8 @@ namespace cpu {
namespace device { namespace device {
namespace device_conv2d_fwd_avx2_instance { namespace device_conv2d_fwd_avx2_instance {
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk( void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances); std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
...@@ -34,6 +39,7 @@ using OutElementOp = ck::tensor_operation::cpu::element_wise::PassThrough; ...@@ -34,6 +39,7 @@ using OutElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
template <typename T> template <typename T>
static bool check_out(const Tensor<T>& ref, const Tensor<T>& result) static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
{ {
int error_count = 0;
float max_diff = 1e-6; float max_diff = 1e-6;
for(int i = 0; i < ref.mData.size(); ++i) for(int i = 0; i < ref.mData.size(); ++i)
...@@ -41,28 +47,35 @@ static bool check_out(const Tensor<T>& ref, const Tensor<T>& result) ...@@ -41,28 +47,35 @@ static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
float diff = std::abs(double(ref.mData[i]) - double(result.mData[i])); float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
if(max_diff < diff) if(max_diff < diff)
{ {
return false; error_count++;
printf("idx:%3d, ref:%f, res:%f (diff:%f)\n",
i,
double(ref.mData[i]),
double(result.mData[i]),
diff);
} }
} }
return true; return error_count == 0;
} }
float calculate_gflops() {}
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
int data_type = 0; int data_type = 0;
int init_method = 0; int init_method = 0;
// Conv shape // Conv shape
ck::index_t N = 128; ck::index_t N = 2;
ck::index_t K = 256; ck::index_t K = 256;
ck::index_t C = 192; ck::index_t C = 192;
ck::index_t Y = 3; ck::index_t Y = 3;
ck::index_t X = 3; ck::index_t X = 3;
ck::index_t Hi = 71; ck::index_t Hi = 71;
ck::index_t Wi = 71; ck::index_t Wi = 71;
ck::index_t conv_stride_h = 2; ck::index_t conv_stride_h = 1;
ck::index_t conv_stride_w = 2; ck::index_t conv_stride_w = 1;
ck::index_t conv_dilation_h = 1; ck::index_t conv_dilation_h = 1;
ck::index_t conv_dilation_w = 1; ck::index_t conv_dilation_w = 1;
ck::index_t in_left_pad_h = 1; ck::index_t in_left_pad_h = 1;
...@@ -72,7 +85,7 @@ int main(int argc, char* argv[]) ...@@ -72,7 +85,7 @@ int main(int argc, char* argv[])
if(argc == 1) if(argc == 1)
{ {
data_type = 1; data_type = 0;
init_method = 1; init_method = 1;
} }
else if(argc == 3) else if(argc == 3)
...@@ -110,17 +123,14 @@ int main(int argc, char* argv[]) ...@@ -110,17 +123,14 @@ int main(int argc, char* argv[])
exit(1); exit(1);
} }
auto Run = [&](auto input_type, auto wei_type, auto out_type, auto acc_type) { auto Run = [&](auto input_type, auto wei_type, auto out_type) {
using InDataType = decltype(input_type); using InDataType = decltype(input_type);
using WeiDataType = decltype(wei_type); using WeiDataType = decltype(wei_type);
using OutDataType = decltype(out_type); using OutDataType = decltype(out_type);
using AccDataType = decltype(acc_type);
using ReferenceConvBwdInstance = using ReferenceConvFwdInstance = ck::tensor_operation::host::ReferenceConvFwd<InDataType,
ck::tensor_operation::host::ReferenceConvBwdData<InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
AccDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp>; OutElementOp>;
...@@ -147,53 +157,93 @@ int main(int argc, char* argv[]) ...@@ -147,53 +157,93 @@ int main(int argc, char* argv[])
std::vector<std::size_t>({C_ * H_ * W_, 1, W_ * C_, C_})); std::vector<std::size_t>({C_ * H_ * W_, 1, W_ * C_, C_}));
}; };
Tensor<OutDataType> out_n_ho_wo_k(f_host_tensor_descriptor(N, K, Ho, Wo)); Tensor<InDataType> in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi));
Tensor<WeiDataType> wei_k_y_x_c(f_host_tensor_descriptor(K, C, Y, X)); Tensor<WeiDataType> wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X));
Tensor<InDataType> in_n_hi_wi_c_host_result(f_host_tensor_descriptor(N, C, Hi, Wi)); Tensor<OutDataType> out_n_k_ho_wo_host_result(f_host_tensor_descriptor(N, K, Ho, Wo));
Tensor<InDataType> in_n_hi_wi_c_device_result(f_host_tensor_descriptor(N, C, Hi, Wi)); Tensor<OutDataType> out_n_k_ho_wo_device_result(f_host_tensor_descriptor(N, K, Ho, Wo));
std::cout << "in (N, C, Hi, Wi): " << in_n_hi_wi_c_host_result.mDesc << std::endl; std::cout << "in (N, C, Hi, Wi): " << in_n_c_hi_wi.mDesc << std::endl;
std::cout << "wei(K, C, Y, X): " << wei_k_y_x_c.mDesc << std::endl; std::cout << "wei(K, C, Y, X): " << wei_k_c_y_x.mDesc << std::endl;
std::cout << "out(N, K, Ho, Wo): " << out_n_ho_wo_k.mDesc << std::endl; std::cout << "out(N, K, Ho, Wo): " << out_n_k_ho_wo_host_result.mDesc << std::endl;
std::cout << "LPad(H, W):" << in_left_pad_h << "," << in_left_pad_w
<< ", RPad(H, W):" << in_right_pad_h << "," << in_right_pad_w
<< ", Stride(H, W):" << conv_stride_h << ", " << conv_stride_w
<< ", Dilation(H, W):" << conv_dilation_h << ", " << conv_dilation_w
<< ", Threads:" << omp_get_max_threads() << std::endl;
switch(init_method) switch(init_method)
{ {
case 0: break; case 0: break;
case 1: case 1:
out_n_ho_wo_k.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
wei_k_y_x_c.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5}); in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
// in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{});
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
// wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{});
break; break;
case 2: case 2:
out_n_ho_wo_k.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0}); in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{});
wei_k_y_x_c.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5}); wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{});
break;
case 3:
#define PACK_32(v24, v16, v8, v0) \
(((v24 & 0xff) << 24) | ((v16 & 0xff) << 16) | ((v8 & 0xff) << 8) | ((v0 & 0xff) << 0))
for(auto i_n = 0; i_n < N; i_n++)
{
for(auto i_c = 0; i_c < C; i_c++)
{
for(auto i_hi = 0; i_hi < Hi; i_hi++)
{
for(auto i_wi = 0; i_wi < Wi; i_wi++)
{
uint32_t v = PACK_32(i_n, i_c, i_hi, i_wi);
in_n_c_hi_wi(i_n, i_c, i_hi, i_wi) = *reinterpret_cast<float*>(&v);
}
}
}
}
for(auto i_k = 0; i_k < K; i_k++)
{
for(auto i_c = 0; i_c < C; i_c++)
{
for(auto i_y = 0; i_y < Y; i_y++)
{
for(auto i_x = 0; i_x < X; i_x++)
{
uint32_t v = PACK_32(i_k, i_c, i_y, i_x);
wei_k_c_y_x(i_k, i_c, i_y, i_x) = *reinterpret_cast<float*>(&v);
}
}
}
}
break; break;
default: default:
out_n_ho_wo_k.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1}); in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{0, 1});
wei_k_y_x_c.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1}); wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-1, 1});
} }
DeviceAlignedMemCPU in_device_buf(sizeof(InDataType) * DeviceAlignedMemCPU in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace(),
in_n_hi_wi_c_device_result.mDesc.GetElementSpace(),
AVX2_DATA_ALIGNMENT); AVX2_DATA_ALIGNMENT);
DeviceAlignedMemCPU wei_device_buf( DeviceAlignedMemCPU wei_device_buf(
sizeof(WeiDataType) * wei_k_y_x_c.mDesc.GetElementSpace(), AVX2_DATA_ALIGNMENT); sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace(), AVX2_DATA_ALIGNMENT);
DeviceAlignedMemCPU out_device_buf( DeviceAlignedMemCPU out_device_buf(sizeof(OutDataType) *
sizeof(OutDataType) * out_n_ho_wo_k.mDesc.GetElementSpace(), AVX2_DATA_ALIGNMENT); out_n_k_ho_wo_host_result.mDesc.GetElementSpace(),
AVX2_DATA_ALIGNMENT);
out_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_device_buf.ToDevice(wei_k_y_x_c.mData.data()); wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
// reset input to zero
in_n_hi_wi_c_device_result.GenerateTensorValue(GeneratorTensor_1<InDataType>{0});
in_device_buf.ToDevice(in_n_hi_wi_c_device_result.mData.data());
// get host result // get host result
{ {
auto ref_conv = ReferenceConvFwdInstance{}; auto ref_conv = ReferenceConvFwdInstance{};
auto ref_invoker = ref_conv.MakeInvoker(); auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in_n_hi_wi_c_host_result, auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi,
wei_k_y_x_c, wei_k_c_y_x,
out_n_ho_wo_k, out_n_k_ho_wo_host_result,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
...@@ -205,8 +255,8 @@ int main(int argc, char* argv[]) ...@@ -205,8 +255,8 @@ int main(int argc, char* argv[])
} }
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough; using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
using DeviceConvFwdNoOpPtr = using DeviceConvFwdNoOpPtr = ck::tensor_operation::cpu::device::
ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>; DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>;
// add device Conv instances // add device Conv instances
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs; std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
...@@ -226,6 +276,9 @@ int main(int argc, char* argv[]) ...@@ -226,6 +276,9 @@ int main(int argc, char* argv[])
// profile device Conv instances // profile device Conv instances
bool success = true; bool success = true;
double fastest_kernel_time = std::numeric_limits<double>::max();
std::string fastest_kernel_name = "";
double fastest_kernel_gflops = 0;
for(auto& conv_ptr : conv_ptrs) for(auto& conv_ptr : conv_ptrs)
{ {
auto argument_ptr = conv_ptr->MakeArgumentPointer( auto argument_ptr = conv_ptr->MakeArgumentPointer(
...@@ -249,18 +302,30 @@ int main(int argc, char* argv[]) ...@@ -249,18 +302,30 @@ int main(int argc, char* argv[])
if(conv_ptr->IsSupportedArgument(argument_ptr.get())) if(conv_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
auto invoker_ptr = conv_ptr->MakeInvokerPointer(); auto invoker_ptr = conv_ptr->MakeInvokerPointer();
invoker_ptr->Run(argument_ptr.get(), 1); double time = invoker_ptr->Run(argument_ptr.get(), 10);
double total_flop = static_cast<double>(2) * N * C * Ho * Wo * K * Y * X;
in_device_buf.FromDevice(in_n_hi_wi_c_device_result.mData.data()); double gflops = (total_flop * 1e-6) / time;
if(!check_out(in_n_hi_wi_c_host_result, in_n_hi_wi_c_device_result)) out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data());
if(!check_out(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result))
{ {
std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl; std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl;
success = false; success = false;
} }
else else
{ {
std::cout << "Pass Info: " << conv_ptr->GetTypeString() << std::endl; std::cout << "Pass Info: " << conv_ptr->GetTypeString() << ", Time:" << time
<< "ms, Gflops:" << gflops << std::endl;
if(time < fastest_kernel_time)
{
fastest_kernel_time = time;
fastest_kernel_name = conv_ptr->GetTypeString();
fastest_kernel_gflops = gflops;
}
} }
} }
else else
...@@ -269,25 +334,27 @@ int main(int argc, char* argv[]) ...@@ -269,25 +334,27 @@ int main(int argc, char* argv[])
} }
} }
if(success) if(fastest_kernel_time != std::numeric_limits<double>::max())
{ {
std::cout << "test conv2d fwd cpu : Pass" << std::endl; std::cout << " fastest:" << fastest_kernel_name << ", time:" << fastest_kernel_time
return 0; << "ms, Gflops:" << fastest_kernel_gflops << std::endl;
}
else
{
std::cout << "test conv2d fwd cpu: Fail " << std::endl;
return -1;
} }
return 0;
// if(success)
// {
// std::cout << "test conv2d fwd cpu : Pass" << std::endl;
// return 0;
// }
// else
// {
// std::cout << "test conv2d fwd cpu: Fail " << std::endl;
// return -1;
// }
}; };
if(data_type == 0) if(data_type == 0)
{ {
return Run(F32(), F32(), F32(), F32()); return Run(F32(), F32(), F32());
}
else if(data_type == 1)
{
return Run(F16(), F16(), F16(), F32());
} }
else else
{ {
......
...@@ -226,6 +226,8 @@ int main(int argc, char** argv) ...@@ -226,6 +226,8 @@ int main(int argc, char** argv)
static constexpr ck::index_t nDim = static constexpr ck::index_t nDim =
ck::remove_reference_t<decltype(input_desc)>::GetNumOfDimension(); ck::remove_reference_t<decltype(input_desc)>::GetNumOfDimension();
input_desc.Print();
auto threadwise_transfer = threadwise_transfer_t{input_desc, auto threadwise_transfer = threadwise_transfer_t{input_desc,
ck::make_zero_multi_index<nDim>(), ck::make_zero_multi_index<nDim>(),
input_cblock_desc, input_cblock_desc,
......
...@@ -321,6 +321,7 @@ void test_ukernel(ukenrel_t uk, ...@@ -321,6 +321,7 @@ void test_ukernel(ukenrel_t uk,
param.ldb = (std::is_same<Row, BLayout>::value ? n : k * 8) * sizeof(FloatB); param.ldb = (std::is_same<Row, BLayout>::value ? n : k * 8) * sizeof(FloatB);
param.ldc = n * sizeof(float); param.ldc = n * sizeof(float);
param.alpha = alpha; param.alpha = alpha;
param.accmulate_c = 0;
memset(private_c, 0, m * n * sizeof(float)); memset(private_c, 0, m * n * sizeof(float));
......