Commit afc7d431 authored by carlushuang

AVX2 GEMM now works for a single thread

parent 07af8343
...@@ -13,21 +13,10 @@ namespace cpu {
template <typename FloatA, template <typename FloatA,
typename FloatB, typename FloatB,
typename FloatC, typename FloatC,
typename AccDataType,
typename ABlockDesc, typename ABlockDesc,
typename BBlockDesc, typename BBlockDesc,
typename CBlockDesc, typename CDesc,
typename ABlockSliceLengths,
typename BBlockSliceLengths,
typename CBlockSliceLengths,
typename AThreadSliceLength,
typename BThreadSliceLength,
ck::index_t AThreadLoopOverDim, // thread slice loop over on block slice. 1d is enough for
// now
ck::index_t BThreadLoopOverDim,
ck::index_t KPerBlock, ck::index_t KPerBlock,
...@@ -47,24 +36,14 @@ struct BlockwiseGemmAvx2_MxN
static constexpr index_t nDimA = ABlockDesc::GetNumOfDimension(); static constexpr index_t nDimA = ABlockDesc::GetNumOfDimension();
static constexpr index_t nDimB = BBlockDesc::GetNumOfDimension(); static constexpr index_t nDimB = BBlockDesc::GetNumOfDimension();
static constexpr index_t nDimC = CBlockDesc::GetNumOfDimension(); static constexpr index_t nDimC = CDesc::GetNumOfDimension();
using IndexA = MultiIndex<nDimA>; using IndexA = MultiIndex<nDimA>;
using IndexB = MultiIndex<nDimB>; using IndexB = MultiIndex<nDimB>;
using IndexC = MultiIndex<nDimC>; using IndexC = MultiIndex<nDimC>;
using ACoord = decltype(make_tensor_coordinate(ABlockDesc{}, IndexA{})); using ACoord = decltype(make_tensor_coordinate(ABlockDesc{}, IndexA{}));
using BCoord = decltype(make_tensor_coordinate(BBlockDesc{}, IndexB{})); using BCoord = decltype(make_tensor_coordinate(BBlockDesc{}, IndexB{}));
using CCoord = decltype(make_tensor_coordinate(CBlockDesc{}, IndexC{})); using CCoord = decltype(make_tensor_coordinate(CDesc{}, IndexC{}));
#if 0
constexpr BlockwiseGemmAvx2_MxN(const ABlockDesc & a_block_desc, const IndexA& a_thread_origin,
const BBlockDesc & b_block_desc, const IndexB& b_thread_origin)
: a_thread_coord_(make_tensor_coordinate(a_block_desc, a_thread_origin)),
b_thread_coord_(make_tensor_coordinate(b_block_desc, b_thread_origin)),
{
}
#endif
template <typename TensorDesc> template <typename TensorDesc>
constexpr auto GetLeadingElement(const TensorDesc& desc) constexpr auto GetLeadingElement(const TensorDesc& desc)
...@@ -84,79 +63,175 @@ struct BlockwiseGemmAvx2_MxN
} }
} }
template <typename ABlockBuffer, typename BBlockBuffer, typename CBlockBuffer> ck::index_t GetALeadingElement(const ABlockDesc& a_block_desc) const
{
return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
ck::index_t GetBLeadingElement(const BBlockDesc& b_block_desc) const
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// K * N
return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
else
{
// N/8 * K * 8
return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}] *
b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
}
}
ck::index_t GetCLeadingElement(const CDesc& c_desc) const
{
return c_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
ck::index_t GetMPerBlock(const ABlockDesc& a_block_desc) const
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// M * K
return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
}
else
{
// K * M
return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
}
ck::index_t GetKPerBlock(const ABlockDesc& a_block_desc) const
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// M * K
return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
else
{
// K * M
return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
}
}
ck::index_t GetNPerBlock(const BBlockDesc& b_block_desc) const
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// K * N
return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
else
{
// N/8 * K * 8
return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}] *
b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
}
}
ck::index_t
GetABlockStartOffset(const ABlockDesc& a_block_desc, const index_t i_m, const index_t) const
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
return i_m * a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
else
{
return i_m;
}
}
ck::index_t
GetBBlockStartOffset(const BBlockDesc& b_block_desc, const index_t, const index_t i_n) const
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// K * N
return i_n;
}
else
{
// N/8 * K * 8
return i_n * b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
}
ck::index_t
GetCBlockStartOffset(const CDesc& c_desc, const index_t i_m, const index_t i_n) const
{
return i_m * c_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}] + i_n;
}
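For readers following the offset helpers above, here is a minimal standalone sketch of the same arithmetic with plain integers in place of CK descriptors; the function names and the packed row-major A / packed (N/8, K, 8) B / packed M x N C layouts are assumptions for illustration only.

#include <cstdint>

// leading dimension (in elements) of a packed row-major M x K block is K
inline int64_t lda_row_major(int64_t K) { return K; }

// start of row i_m inside the packed row-major A block
inline int64_t a_start_offset(int64_t i_m, int64_t K) { return i_m * K; }

// start of the 8-wide column group containing i_n inside a B block packed as
// (N/8, K, 8): each group of 8 columns occupies K * 8 contiguous elements
inline int64_t b_start_offset(int64_t i_n, int64_t K) { return i_n * K; }

// start of the (i_m, i_n) tile inside a packed M x N C block
inline int64_t c_start_offset(int64_t i_m, int64_t i_n, int64_t N)
{
    return i_m * N + i_n;
}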
template <typename ABlockBuffer, typename BBlockBuffer, typename CBuffer>
void Run(const ABlockDesc& a_block_desc, void Run(const ABlockDesc& a_block_desc,
const ABlockBuffer& a_block_buf, const ABlockBuffer& a_block_buf,
const IndexA& a_origin, const IndexA& /* a_origin */,
const BBlockDesc& b_block_desc, const BBlockDesc& b_block_desc,
const BBlockBuffer& b_block_buf, const BBlockBuffer& b_block_buf,
const IndexB& b_origin, const IndexB& /* b_origin */,
const CDesc& c_desc,
CBuffer& c_buf,
const IndexC& /* c_origin */,
const CBlockDesc& c_block_desc, bool is_accumulate_c = true) const
CBlockBuffer& c_block_buf,
const IndexC& c_origin) const
{ {
auto lda = GetALeadingElement(a_block_desc) * sizeof(FloatA);
auto ldb = GetBLeadingElement(b_block_desc) * sizeof(FloatB);
auto ldc = GetCLeadingElement(c_desc) * sizeof(FloatC);
constexpr auto m_n_block_length = // printf("lda:%d, ldb:%d, ldc:%d\n", lda, ldb, ldc);
ck::Sequence<ABlockSliceLengths::At(AThreadLoopOverDim),
BBlockSliceLengths::At(BThreadLoopOverDim)>{};
constexpr auto m_n_thread_length =
ck::Sequence<AThreadSliceLength::At(AThreadLoopOverDim),
BThreadSliceLength::At(BThreadLoopOverDim)>{};
constexpr auto m_n_access_length = m_n_block_length / m_n_thread_length; const auto k_per_block = GetKPerBlock(a_block_desc);
const auto m_per_block = GetMPerBlock(a_block_desc);
const auto n_per_block = GetNPerBlock(b_block_desc);
const auto m_per_thread = ThreadwiseGemm_Dispatch::ThreadMaxMr;
const auto n_per_thread = ThreadwiseGemm_Dispatch::ThreadMaxNr;
constexpr auto ordered_m_n_access_length = ck::cpu::ThreadwiseGemmParam param;
container_reorder_given_new2old(m_n_access_length, ThreadMNAccessOrder{}); param.Kr = k_per_block;
param.lda = lda;
param.ldb = ldb;
param.ldc = ldc;
param.alpha = 1.0f; // TODO
param.accmulate_c = is_accumulate_c ? 1 : 0;
if constexpr(std::is_same<ThreadMNAccessOrder, ck::Sequence<0, 1>>::value)
{
for(ck::index_t i_m = 0; i_m < m_per_block; i_m += m_per_thread)
{
auto current_mr = ck::math::min(m_per_block - i_m, m_per_thread);
param.p_a = &a_block_buf.p_data_[GetABlockStartOffset(a_block_desc, i_m, 0)];
constexpr auto a_block_idx_zeros = // printf("YYYY: %d, i_m:%d, current_mr:%d, %d, %p\n",__LINE__, i_m, current_mr,
typename uniform_sequence_gen<nDimA, 0>::type{}; // starting point of the block // GetABlockStartOffset(a_block_desc, i_m, 0), param.p_a);fflush(stdout);
constexpr auto b_block_idx_zeros = typename uniform_sequence_gen<nDimB, 0>::type{};
constexpr auto lda = GetLeadingElement(a_block_desc) * sizeof(FloatA); for(ck::index_t i_n = 0; i_n < n_per_block; i_n += n_per_thread)
constexpr auto ldb = GetLeadingElement(b_block_desc) * sizeof(FloatB); {
constexpr auto ldc = GetLeadingElement(c_block_desc) * sizeof(FloatC); auto current_nr = ck::math::min(n_per_block - i_n, n_per_thread);
ck::cpu::ThreadwiseGemmParam param; param.p_b = &b_block_buf.p_data_[GetBBlockStartOffset(b_block_desc, 0, i_n)];
param.Kr = KPerBlock; param.p_c = &c_buf.p_data_[GetCBlockStartOffset(c_desc, i_m, i_n)];
param.lda = lda;
param.ldb = ldb; // printf("YYYY: %d, i_n:%d, current_nr:%d, %d, %p, C:%d, %p\n",__LINE__, i_n,
param.ldc = ldc; // current_nr, GetBBlockStartOffset(b_block_desc, 0, i_n), param.p_b,
param.alpha = 1.0f; // TODO // GetCBlockStartOffset(c_desc, i_m, i_n),
// param.p_c);fflush(stdout);
static_ford<decltype(ordered_m_n_access_length)>{}([&](auto ordered_idx) {
constexpr auto origin_m_n_idx = ordered_idx.ReorderGivenOld2New(ThreadMNAccessOrder{}); ThreadwiseGemm_Dispatch::Run(&param, current_mr, current_nr);
}
constexpr auto current_m_idx = }
origin_m_n_idx.At(0) * AThreadSliceLength::At(AThreadLoopOverDim); }
constexpr auto current_n_idx =
origin_m_n_idx.At(1) * BThreadSliceLength::At(BThreadLoopOverDim);
constexpr auto current_mr =
ck::math::min(m_n_block_length.At(0) - current_m_idx, m_n_thread_length.At(0));
constexpr auto current_nr =
ck::math::min(m_n_block_length.At(1) - current_n_idx, m_n_thread_length.At(1));
constexpr auto a_block_idx =
a_block_idx_zeros.Modify(AThreadLoopOverDim, current_m_idx);
constexpr auto a_block_coord =
make_tensor_coordinate(a_block_desc, to_multi_index(a_origin + a_block_idx));
constexpr auto b_block_idx =
b_block_idx_zeros.Modify(BThreadLoopOverDim, current_n_idx);
constexpr auto b_block_coord =
make_tensor_coordinate(b_block_desc, to_multi_index(b_origin + b_block_idx));
constexpr auto c_block_coord =
make_tensor_coordinate(c_block_desc, to_multi_index(c_origin + origin_m_n_idx));
param.p_a = &a_block_buf.p_data_[a_block_coord.GetOffset()];
param.p_b = &b_block_buf.p_data_[b_block_coord.GetOffset()];
param.p_c = &c_block_buf.p_data_[c_block_coord.GetOffset()];
ThreadwiseGemm_Dispatch::Run(&param, current_mr, current_nr);
});
} }
};
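The Run() above walks the cache-resident block in register tiles of ThreadMaxMr x ThreadMaxNr and clamps the tail tiles with min. A self-contained sketch of that traversal, using plain row-major buffers and a naive stand-in for the AVX2 micro-kernel (the names and layouts here are illustrative, not the library's API):

#include <algorithm>
#include <cstdint>

// Naive stand-in for ThreadwiseGemm_Dispatch::Run: C[mr x nr] (+)= A[mr x k] * B[k x nr].
// All matrices are plain row-major here; the real kernel uses a packed (N/8, K, 8) B layout.
static void micro_kernel(const float* a, const float* b, float* c,
                         int64_t mr, int64_t nr, int64_t k,
                         int64_t lda, int64_t ldb, int64_t ldc, bool accumulate)
{
    for(int64_t i = 0; i < mr; ++i)
        for(int64_t j = 0; j < nr; ++j)
        {
            float acc = accumulate ? c[i * ldc + j] : 0.0f;
            for(int64_t p = 0; p < k; ++p)
                acc += a[i * lda + p] * b[p * ldb + j];
            c[i * ldc + j] = acc;
        }
}

// Sketch of the Run() traversal: tile the cache block into register tiles,
// clamping the last tile with min (current_mr / current_nr above).
void blockwise_gemm_sketch(const float* a_blk, const float* b_blk, float* c_blk,
                           int64_t m_per_block, int64_t n_per_block, int64_t k_per_block,
                           int64_t m_per_thread, int64_t n_per_thread, bool accumulate_c)
{
    const int64_t lda = k_per_block, ldb = n_per_block, ldc = n_per_block;
    // ThreadMNAccessOrder == Sequence<0, 1>: M outer, N inner
    for(int64_t i_m = 0; i_m < m_per_block; i_m += m_per_thread)
    {
        const int64_t mr = std::min(m_per_block - i_m, m_per_thread);
        for(int64_t i_n = 0; i_n < n_per_block; i_n += n_per_thread)
        {
            const int64_t nr = std::min(n_per_block - i_n, n_per_thread);
            micro_kernel(a_blk + i_m * lda, b_blk + i_n, c_blk + i_m * ldc + i_n,
                         mr, nr, k_per_block, lda, ldb, ldc, accumulate_c);
        }
    }
}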
......
...@@ -14,6 +14,19 @@ enum ConvolutionForwardSpecialization_t
OddC, OddC,
}; };
enum ConvolutionForwardGemmKSpecialization_t
{
DefaultGemmKLoop,
NHWC_GemmKLoopOverC, // do not merge c*y*x; requires c % k_per_block == 0
};
enum ConvolutionForwardBlockLoopOverSpecialization_t
{
DefaultBlockLoopOver,
LoopOver_MNK,
LoopOver_MKN,
};
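Roughly, GemmKSpecialization::NHWC_GemmKLoopOverC keeps C as the GEMM K loop dimension instead of merging C*Y*X (which later requires C % KPerBlock == 0), and BlockLoopOverSpecialization selects the block-level loop order. A hedged sketch of how these map to concrete checks and loop orders (the helper name is hypothetical):

#include <cstdint>

// LoopOver_MNK -> block loops ordered M, N, K (ck::Sequence<0, 1, 2>)
// LoopOver_MKN -> block loops ordered M, K, N (ck::Sequence<0, 2, 1>)

// NHWC_GemmKLoopOverC walks the GEMM K dimension directly over C (no C*Y*X merge),
// so the channel count must tile evenly by KPerBlock.
bool supports_gemmk_loop_over_c(int64_t conv_c, int64_t k_per_block)
{
    return conv_c % k_per_block == 0;
}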
} // namespace device } // namespace device
} // namespace cpu } // namespace cpu
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -13,6 +13,8 @@
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_avx2.hpp" #include "gridwise_gemm_avx2.hpp"
#include "threadwise_gemm_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -23,20 +25,21 @@ namespace device {
template <typename InDataType, template <typename InDataType,
typename WeiDataType, typename WeiDataType,
typename OutDataType, typename OutDataType,
typename AccDataType,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization, ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
ck::index_t NumDimSpatial, ck::index_t NumDimSpatial,
ck::index_t MPerBlock, // "block" means the data tile is sized to fit in cache (L1/L2/L3)
ck::index_t NPerBlock, ck::index_t NPerBlock,
ck::index_t KPerBlock, ck::index_t KPerBlock,
typename ThreadwiseGemm_Dispatch> ck::index_t MPerThread,
ck::index_t NPerThread,
// bool IsGemmMPadded, bool UseALocalBuffer,
// bool IsGemmNPadded, bool UseBLocalBuffer,
// bool IsGemmKPadded> bool UseCLocalBuffer>
struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
: public DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation> : public DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
{ {
...@@ -60,18 +63,89 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr bool NonTemporalStore = false;
static constexpr auto GetBlockMNKAccessOrder()
{
if constexpr(BlockLoopOverSpecialization == DefaultBlockLoopOver ||
BlockLoopOverSpecialization == LoopOver_MNK)
return ck::Sequence<0, 1, 2>{};
else if constexpr(BlockLoopOverSpecialization == LoopOver_MKN)
return ck::Sequence<0, 2, 1>{};
}
using BlockMNKAccessOrder = decltype(GetBlockMNKAccessOrder());
static constexpr auto GetThreadwiseGemm_Dispatch()
{
if constexpr(MPerThread == 4 && NPerThread == 24)
{
return ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<
InDataType,
WeiDataType,
OutDataType,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
NonTemporalStore>{};
}
else if constexpr(MPerThread == 6 && NPerThread == 16)
{
return ck::cpu::ThreadwiseGemmAvx2_MxN_6x16_Dispatch<
InDataType,
WeiDataType,
OutDataType,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
NonTemporalStore>{};
}
else
{
// static_assert(false, "invalid Mr/Nr");
}
}
using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());
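Both supported register tiles keep 12 of the 16 AVX2 ymm registers as fp32 accumulators (4 rows x 3 vectors for 4x24, 6 rows x 2 vectors for 6x16), matching ymm0-ymm11 in the micro-kernels further down; the remaining registers hold A broadcasts and B loads. A small compile-time check of that budget (illustrative only):

// each ymm register holds 8 floats, so an Mr x Nr tile needs Mr * (Nr / 8) accumulators
constexpr int ymm_accumulators(int mr, int nr) { return mr * (nr / 8); }

static_assert(ymm_accumulators(4, 24) == 12, "4x24 tile uses 12 ymm accumulators");
static_assert(ymm_accumulators(6, 16) == 12, "6x16 tile uses 12 ymm accumulators");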
static constexpr auto GetInputBlockDescriptor()
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
}
static constexpr auto GetWeightBlockDescriptor()
{
return make_naive_tensor_descriptor_packed(make_tuple(
math::integer_divide_ceil(NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
KPerBlock,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
}
static constexpr auto GetOutputBlockDescriptor()
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
}
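As a rough sizing example of the three staging blocks defined above (the tile sizes below are assumptions for illustration, not values chosen by this commit):

#include <cstdint>
#include <cstdio>

int main()
{
    const int64_t MPerBlock = 256, NPerBlock = 120, KPerBlock = 64, VecB = 8;
    const int64_t a_elems = MPerBlock * KPerBlock;                              // M x K
    const int64_t b_elems = ((NPerBlock + VecB - 1) / VecB) * KPerBlock * VecB; // N/8 x K x 8
    const int64_t c_elems = MPerBlock * NPerBlock;                              // M x N
    std::printf("A: %lld floats, B: %lld floats, C: %lld floats\n",
                (long long)a_elems, (long long)b_elems, (long long)c_elems);
    // -> A: 16384 (64 KiB), B: 7680 (30 KiB), C: 30720 (120 KiB) of fp32 staging data
    return 0;
}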
static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n) static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
{ {
ck::index_t gemm_n_padded =
math::integer_least_multiple(gemm_n, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
const auto wei_gemm_n_k_grid_desc = const auto wei_gemm_n_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_n, gemm_k)); make_naive_tensor_descriptor_packed(make_tuple(gemm_n, gemm_k));
const auto wei_gemm_n0_k_n1_grid_desc = transform_tensor_descriptor( const auto wei_gemm_padn_k_grid_desc = transform_tensor_descriptor(
wei_gemm_n_k_grid_desc, wei_gemm_n_k_grid_desc,
ck::make_tuple(ck::make_unmerge_transform( make_tuple(make_right_pad_transform(gemm_n, gemm_n_padded - gemm_n),
ck::make_tuple(wei_gemm_n_k_grid_desc.GetLength(I0) / make_pass_through_transform(gemm_k)),
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize, ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize)), ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
ck::make_pass_through_transform(wei_gemm_n_k_grid_desc.GetLength(I1))),
const auto wei_gemm_n0_k_n1_grid_desc = transform_tensor_descriptor(
wei_gemm_padn_k_grid_desc,
ck::make_tuple(
ck::make_unmerge_transform(
ck::make_tuple(wei_gemm_padn_k_grid_desc.GetLength(I0) /
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize)),
ck::make_pass_through_transform(wei_gemm_padn_k_grid_desc.GetLength(I1))),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}), ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{})); ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
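The pad + unmerge above only reinterprets the original row-major (gemm_n x gemm_k) weight tensor; the actual (N/8, K, 8) packing happens when the B threadwise copy fills the block buffer. A hand-written sketch of the two offset computations involved, assuming MatrixBMinVectorSize == 8 (function names are hypothetical):

#include <cstdint>

// source offset in the original row-major weight tensor for coordinate (n0, k, n1),
// where n = n0 * 8 + n1 and rows past gemm_n exist only as padding
int64_t wei_grid_offset(int64_t n0, int64_t k, int64_t n1, int64_t gemm_k)
{
    return (n0 * 8 + n1) * gemm_k + k;
}

// destination offset in the packed (n_per_blk/8, k_per_blk, 8) block buffer
int64_t wei_block_offset(int64_t n0, int64_t k, int64_t n1, int64_t k_per_blk)
{
    return (n0 * k_per_blk + k) * 8 + n1;
}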
...@@ -409,6 +483,13 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::multiplies<ck::index_t>()); std::multiplies<ck::index_t>());
} }
static index_t GetGemmN(ck::index_t K)
{
// return ck::math::integer_least_multiple(K,
// ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
return K;
}
static auto MakeABCGridDescriptor(ck::index_t N, static auto MakeABCGridDescriptor(ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t C, ck::index_t C,
...@@ -423,7 +504,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
using namespace ck; using namespace ck;
const index_t GemmM = GetGemmM(N, output_spatial_lengths); const index_t GemmM = GetGemmM(N, output_spatial_lengths);
const index_t GemmN = K; const index_t GemmN = GetGemmN(K);
const index_t GemmK = GetGemmK(C, filter_spatial_lengths); const index_t GemmK = GetGemmK(C, filter_spatial_lengths);
// A: // A:
...@@ -474,13 +555,44 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
using BGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I1])>; using BGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I2])>; using CGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
static constexpr bool UseCLocalBuffer = true; // static constexpr bool UseCLocalBuffer = false;
using AThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC<
InDataType,
InDataType,
AGridDesc,
decltype(GetInputBlockDescriptor()),
InElementwiseOperation,
false,
ConvForwardSpecialization,
GemmKSpecialization>;
using BThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC<
WeiDataType,
WeiDataType,
BGridDesc,
decltype(GetWeightBlockDescriptor()),
WeiElementwiseOperation,
false,
ConvForwardSpecialization,
GemmKSpecialization>;
using CThreadwiseCopy = ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN<
OutDataType,
OutDataType,
CGridDesc,
decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation,
!UseCLocalBuffer,
ConvForwardSpecialization,
GemmKSpecialization>;
using GridwiseGemm = using GridwiseGemm =
ck::cpu::GridwiseGemmAvx2_MxN<InDataType, // InDataType, ck::cpu::GridwiseGemmAvx2_MxN<InDataType, // InDataType,
WeiDataType, // WeiDataType, WeiDataType, // WeiDataType,
OutDataType, // OutDataType, OutDataType, // OutDataType,
AccDataType, // AccDataType,
AGridDesc, // AGridDesc, AGridDesc, // AGridDesc,
BGridDesc, // BGridDesc, BGridDesc, // BGridDesc,
CGridDesc, // CGridDesc, CGridDesc, // CGridDesc,
...@@ -491,8 +603,13 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
NPerBlock, // NPerBlock, NPerBlock, // NPerBlock,
KPerBlock, // KPerBlock, KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch, ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
ck::Sequence<0, 1, 2>, // BlockMNKAccessOrder, AThreadwiseCopy, // AThreadwiseCopy
BThreadwiseCopy, // BThreadwiseCopy
CThreadwiseCopy, // CThreadwiseCopy
BlockMNKAccessOrder, // BlockMNKAccessOrder,
ck::Sequence<0, 1>, // ThreadMNAccessOrder ck::Sequence<0, 1>, // ThreadMNAccessOrder
UseALocalBuffer, // UseALocalBuffer
UseBLocalBuffer, // UseBLocalBuffer
UseCLocalBuffer // UseCLocalBuffer UseCLocalBuffer // UseCLocalBuffer
>; >;
...@@ -580,6 +697,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting"); throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
} }
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize() * sizeof(OutDataType));
const auto kernel = ck::cpu::kernel_gemm_avx_mxn<GridwiseGemm, const auto kernel = ck::cpu::kernel_gemm_avx_mxn<GridwiseGemm,
InDataType, InDataType,
WeiDataType, WeiDataType,
...@@ -591,21 +710,24 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation>; CElementwiseOperation>;
float ave_time = launch_and_time_cpu_kernel(kernel, float ave_time = 0;
nrepeat,
arg.p_a_grid_, if(nrepeat != 1)
arg.p_b_grid_, ave_time = launch_and_time_cpu_kernel(kernel,
arg.p_c_grid_, nrepeat,
arg.a_grid_desc_, arg.p_a_grid_,
arg.b_grid_desc_, arg.p_b_grid_,
arg.c_grid_desc_, arg.p_c_grid_,
arg.a_element_op_, arg.a_grid_desc_,
arg.b_element_op_, arg.b_grid_desc_,
arg.c_element_op_); arg.c_grid_desc_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
// TODO: this is for benchmarking purposes: on the last run we clear the C buffer and
// calculate the result
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize() * sizeof(OutDataType));
launch_cpu_kernel(kernel, launch_cpu_kernel(kernel,
arg.p_a_grid_, arg.p_a_grid_,
...@@ -659,6 +781,13 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
} }
} }
if constexpr(GemmKSpecialization ==
ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
{
if(!(arg.Conv_C_ % KPerBlock == 0))
return false;
}
// Gridwise GEMM size // Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_); return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
} }
...@@ -748,16 +877,25 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::string GetTypeString() const override std::string GetTypeString() const override
{ {
auto str = std::stringstream(); auto str = std::stringstream();
auto string_local_buffer = [](bool is_local_buffer) {
if(is_local_buffer)
return "L";
else
return "G";
};
// clang-format off // clang-format off
str << "DeviceConv" << std::to_string(NumDimSpatial) str << "DeviceConv" << std::to_string(NumDimSpatial)
<< "DFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" << "DFwdAvx2_NHWC_KYXC"
<< "<" <<"_FS"<< static_cast<int>(ConvForwardSpecialization)
<< MPerBlock << ", " <<"_KS"<< static_cast<int>(GemmKSpecialization)
<< NPerBlock << ", " <<"_BS"<< static_cast<int>(BlockLoopOverSpecialization)
<< KPerBlock << "_BT" << MPerBlock << "x" << NPerBlock << "x" << KPerBlock
<< ">"; << "_TT" << MPerThread << "x" << NPerThread
<< "_A" << string_local_buffer(UseALocalBuffer)
<< "_B" << string_local_buffer(UseBLocalBuffer)
<< "_C" << string_local_buffer(UseCLocalBuffer)
;
// clang-format on // clang-format on
return str.str(); return str.str();
......
...@@ -7,7 +7,9 @@
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_avx2.hpp" #include "blockwise_gemm_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2.hpp" #include "threadwise_tensor_slice_transfer_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"
#include "dynamic_buffer_cpu.hpp" #include "dynamic_buffer_cpu.hpp"
#include <unistd.h>
namespace ck { namespace ck {
namespace cpu { namespace cpu {
...@@ -46,7 +48,6 @@ void kernel_gemm_avx_mxn(const FloatA* __restrict__ p_a_grid,
template <typename FloatA, template <typename FloatA,
typename FloatB, typename FloatB,
typename FloatC, typename FloatC,
typename AccDataType,
typename AGridDesc, typename AGridDesc,
typename BGridDesc, typename BGridDesc,
typename CGridDesc, typename CGridDesc,
...@@ -57,334 +58,92 @@ template <typename FloatA,
ck::index_t NPerBlock, ck::index_t NPerBlock,
ck::index_t KPerBlock, ck::index_t KPerBlock,
typename ThreadwiseGemm_Dispatch, typename ThreadwiseGemm_Dispatch,
typename AThreadwiseCopy,
typename BThreadwiseCopy,
typename CThreadwiseCopy,
typename BlockMNKAccessOrder, // how we access gemm MNK to better fit in cache
typename ThreadMNAccessOrder, // how we access gemm MN to utilize the micro kernel
bool UseALocalBuffer,
bool UseBLocalBuffer,
bool UseCLocalBuffer // if true, allocate a local buffer, write to it in the kernel, then
                     // copy back to the block buffer (needs CThreadwiseCopy).
                     // if false, write to C directly (no CThreadwiseCopy needed)
> >
struct GridwiseGemmAvx2_MxN struct GridwiseGemmAvx2_MxN
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
// static constexpr auto Avx2RegisterVector = 8; // 8 floats // static constexpr auto Avx2RegisterVector = 8; // 8 floats
static constexpr index_t MemAlignmentByte = 32; // 256bit static constexpr index_t MemAlignmentByte = 32; // 256bit
static constexpr auto GetABlockDescriptor() static auto GetABlockDescriptor(const ck::index_t m_per_blk, const ck::index_t k_per_blk)
{ {
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout, if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value) ck::tensor_layout::gemm::RowMajor>::value)
{ {
// A : M, K // A : M, K
constexpr auto a_block_desc_m_k = auto a_block_desc_m_k =
make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock)); make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, k_per_blk));
return a_block_desc_m_k; return a_block_desc_m_k;
} }
else else
{ {
// A : K, M // A : K, M
constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_packed( auto a_block_desc_k_m = make_naive_tensor_descriptor_packed(
make_tuple(KPerBlock, make_tuple(k_per_blk,
math::integer_least_multiple( math::integer_least_multiple(
MPerBlock, ThreadwiseGemm_Dispatch::MatrixAMinVectorSize))); m_per_blk, ThreadwiseGemm_Dispatch::MatrixAMinVectorSize)));
return a_block_desc_k_m; return a_block_desc_k_m;
} }
} }
static constexpr auto GetBBlockDescriptor() static auto GetBBlockDescriptor(const ck::index_t k_per_blk, const ck::index_t n_per_blk)
{ {
// n_per_blk should be a multiple of 8
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout, if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value) ck::tensor_layout::gemm::RowMajor>::value)
{ {
// B : K, N // B : K, N
constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_packed( auto b_block_desc_k_n =
make_tuple(KPerBlock, make_naive_tensor_descriptor_packed(make_tuple(k_per_blk, n_per_blk));
math::integer_least_multiple(
NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize)));
return b_block_desc_k_n; return b_block_desc_k_n;
} }
else else
{ {
// B : N/8, K, N8 // B : N/8, K, N8
constexpr auto b_block_desc_n0_k_n1 = make_naive_tensor_descriptor_packed(make_tuple( auto b_block_desc_n0_k_n1 = make_naive_tensor_descriptor_packed(make_tuple(
math::integer_divide_ceil(NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize), math::integer_divide_ceil(n_per_blk, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
KPerBlock, k_per_blk,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize)); ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
return b_block_desc_n0_k_n1; return b_block_desc_n0_k_n1;
} }
} }
static constexpr auto GetABlockSliceLength() static auto GetCBlockDescriptor(const ck::index_t m_per_blk, const ck::index_t n_per_blk)
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// A : M, K
return ck::Sequence<MPerBlock, KPerBlock>{};
}
else
{
// A : K, M
return ck::Sequence<KPerBlock, MPerBlock>{};
}
}
static constexpr auto GetBBlockSliceLength()
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// B : K, N
return ck::Sequence<KPerBlock, NPerBlock>{};
}
else
{
// B : N/8, K, N88;
return ck::Sequence<NPerBlock / ThreadwiseGemm_Dispatch::MatrixBMinVectorSize,
KPerBlock,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize>{};
}
}
static constexpr auto GetABlockDimAccessOrder() { return ck::Sequence<0, 1>{}; }
static constexpr auto GetBBlockDimAccessOrder()
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// B : K, N
return ck::Sequence<0, 1>{};
}
else
{
// B : N/8, K, N88;
return ck::Sequence<0, 1, 2>{};
}
}
static constexpr auto GetABlockMoveFwdStep()
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// A : M, K
return ck::make_multi_index(0, KPerBlock);
}
else
{
// A : K, M
return ck::make_multi_index(KPerBlock, 0);
}
}
static constexpr auto GetBBlockMoveFwdStep()
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// B : K, N
return ck::make_multi_index(KPerBlock, 0);
}
else
{
// B : N/8, K, N88;
return ck::make_multi_index(0, KPerBlock, 0);
}
}
#if 0
static constexpr auto GetAThreadDiscriptor()
{
if constexpr (std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout, ck::tensor_layout::gemm::RowMajor>::value){
// A : M, K
constexpr auto a_thread_desc_m_k = make_naive_tensor_descriptor_packed(make_tuple(ThreadwiseGemm_Dispatch::ThreadMaxMr, KPerBlock));
return a_thread_desc_m_k;
} else {
// A : K, M
constexpr auto a_thread_desc_k_m = make_naive_tensor_descriptor_packed(make_tuple(KPerBlock, ThreadwiseGemm_Dispatch::ThreadMaxMr));
return a_thread_desc_k_m;
}
}
static constexpr auto GetBThreadDescriptor()
{
if constexpr (std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout, ck::tensor_layout::gemm::RowMajor>::value){
// B : K, N
constexpr auto b_thread_desc_k_n = make_naive_tensor_descriptor_packed(make_tuple(KPerBlock, ThreadwiseGemm_Dispatch::ThreadMaxNr));
return b_thread_desc_k_n;
} else {
// B : N/8, K, N8
constexpr auto b_thread_desc_n_k_n8 = make_naive_tensor_descriptor_packed(make_tuple(math::integer_divide_ceil(ThreadwiseGemm_Dispatch::ThreadMaxNr, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize), KPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
return b_thread_desc_n_k_n8;
}
}
#endif
static constexpr auto GetAThreadSliceLength()
{ {
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout, return make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, n_per_blk));
ck::tensor_layout::gemm::RowMajor>::value)
{
// A : M, K
return ck::Sequence<ThreadwiseGemm_Dispatch::ThreadMaxMr, KPerBlock>{};
}
else
{
// A : K, M
return ck::Sequence<KPerBlock, ThreadwiseGemm_Dispatch::ThreadMaxMr>{};
}
}
static constexpr auto GetBThreadSliceLength()
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// B : K, N
return ck::Sequence<KPerBlock, ThreadwiseGemm_Dispatch::ThreadMaxNr>{};
}
else
{
// B : N/8, K, N88;
return ck::Sequence<ThreadwiseGemm_Dispatch::ThreadMaxNr /
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize,
KPerBlock,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize>{};
}
} }
static constexpr auto GetAThreadMoveFwdStep()
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// A : M, K
return ck::make_multi_index(ThreadwiseGemm_Dispatch::ThreadMaxMr, 0);
}
else
{
// A : K, M
return ck::make_multi_index(0, ThreadwiseGemm_Dispatch::ThreadMaxMr);
}
}
static constexpr auto GetBThreadMoveFwdStep()
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// B : K, N
return ck::make_multi_index(0, ThreadwiseGemm_Dispatch::ThreadMaxNr);
}
else
{
// B : N/8, K, N88;
return ck::Sequence<ThreadwiseGemm_Dispatch::ThreadMaxNr /
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize,
0,
0>{};
}
}
static constexpr ck::index_t GetAThreadLoopOverDim()
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// A : M, K
return 0;
}
else
{
// A : K, M
return 1;
}
}
static constexpr ck::index_t GetBThreadLoopOverDim()
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// B : K, N
return 1;
}
else
{
// B : N/8, K, N88;
return 0;
}
}
static constexpr auto GetCBlockDescriptor()
{
if constexpr(UseCLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
}
else
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock)); // TODO:
}
}
static constexpr auto GetCBlockSliceLength() { return ck::Sequence<MPerBlock, NPerBlock>{}; }
static constexpr bool CheckValidity(const AGridDesc& a_grid_desc, static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
const BGridDesc& b_grid_desc, const BGridDesc& b_grid_desc,
const CGridDesc& c_grid_desc) const CGridDesc& c_grid_desc)
{ {
#if 0 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
const auto M = a_grid_desc_k0_m_k1.GetLength(I1); bool is_valid = true;
const auto N = b_grid_desc_k0_n_k1.GetLength(I1); const auto GemmN = c_grid_desc.GetLength(I1);
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); if constexpr(UseCLocalBuffer)
if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
return false;
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
return false;
// check NumPrefetch
if constexpr(NumPrefetch == 1)
{ {
// 1-stage prefetch always supported if(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value && NPerBlock < GemmN)
} is_valid &= false;
else if constexpr(NumPrefetch == 2)
{
// 2-stage prefetch currently only support even number of K0 loop
// TODO: add support for odd number of K0 loop
if(!((K0 / K0PerBlock) % 2 == 0))
{
return false;
}
} }
else else
{ {
return false; // TODO: need check c grid is simple transform?
if(GemmN % 8 != 0)
is_valid &= false;
} }
return is_valid;
// check M01, N01
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
if(!(M0 % M01 == 0 && N0 % N01 == 0))
return false;
#endif
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
} }
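In words: with a local C buffer and the M-K-N block order, the same local C tile must stay live across the whole K loop, which only works when the N loop degenerates to a single NPerBlock tile; without a local C buffer, the micro-kernel stores full 8-float vectors straight into the C grid, so GemmN presumably has to be a multiple of 8. A standalone restatement of those checks (hypothetical helper, for clarity only):

#include <cstdint>

bool gridwise_setting_is_valid(bool use_c_local_buffer, bool loop_over_mkn,
                               int64_t gemm_n, int64_t n_per_block)
{
    if(use_c_local_buffer)
        // with the M-K-N order the K loop sits outside the N loop, so the local
        // C tile can only be flushed correctly when N fits in one block
        return !(loop_over_mkn && n_per_block < gemm_n);
    // writing straight to the C grid: the micro-kernel stores full 8-float vectors
    return gemm_n % 8 == 0;
}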
static void Run(const FloatA* __restrict__ p_a_grid, static void Run(const FloatA* __restrict__ p_a_grid,
...@@ -397,178 +156,149 @@ struct GridwiseGemmAvx2_MxN
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op) const CElementwiseOperation& c_element_op)
{ {
ck::index_t m_per_block; ck::index_t m_per_block = MPerBlock;
ck::index_t n_per_block; ck::index_t n_per_block = NPerBlock;
ck::index_t k_per_block; ck::index_t k_per_block = KPerBlock;
if constexpr(MPerBlock == 0 && NPerBlock == 0 && KPerBlock == 0) {} const auto GemmM = c_grid_desc.GetLength(I0);
else const auto GemmN = c_grid_desc.GetLength(I1);
{ const auto GemmK = a_grid_desc.GetLength(I1);
m_per_block = MPerBlock;
n_per_block = NPerBlock; constexpr auto a_block_copy_dim = AGridDesc::GetNumOfDimension();
k_per_block = KPerBlock;
} constexpr auto b_block_copy_dim = BGridDesc::GetNumOfDimension();
const auto M = a_grid_desc.GetLength(I0); auto a_threadwise_copy = AThreadwiseCopy(a_grid_desc,
const auto N = b_grid_desc.GetLength(I1); ck::make_zero_multi_index<a_block_copy_dim>(),
const auto K = b_grid_desc.GetLength(I0); GetABlockDescriptor(m_per_block, k_per_block),
ck::make_zero_multi_index<a_block_copy_dim>(),
const ck::index_t grid_m = math::integer_divide_ceil(M, m_per_block); AElementwiseOperation{});
const ck::index_t grid_n = math::integer_divide_ceil(N, n_per_block);
auto b_threadwise_copy = BThreadwiseCopy(b_grid_desc,
const ck::index_t grid_size = grid_m * grid_n; ck::make_zero_multi_index<b_block_copy_dim>(),
GetBBlockDescriptor(k_per_block, n_per_block),
constexpr auto a_block_desc = GetABlockDescriptor(); ck::make_zero_multi_index<b_block_copy_dim>(),
constexpr auto a_block_slice_length = GetABlockSliceLength(); BElementwiseOperation{});
constexpr auto a_block_copy_dim = decltype(a_block_slice_length)::Size();
constexpr auto a_dim_access_order = GetABlockDimAccessOrder(); auto c_threadwise_copy = CThreadwiseCopy(GetCBlockDescriptor(m_per_block, n_per_block),
constexpr auto a_block_move_step = GetABlockMoveFwdStep(); ck::make_zero_multi_index<2>(),
constexpr auto a_thread_slice_length = GetAThreadSliceLength(); c_grid_desc,
constexpr auto a_thread_loop_over_dim = GetAThreadLoopOverDim(); ck::make_zero_multi_index<2>(),
CElementwiseOperation{});
constexpr auto b_block_desc = GetBBlockDescriptor();
constexpr auto b_block_slice_length = GetBBlockSliceLength(); DeviceAlignedMemCPU a_block_mem(m_per_block * k_per_block * sizeof(FloatA),
constexpr auto b_block_copy_dim = decltype(b_block_slice_length)::Size(); MemAlignmentByte);
constexpr auto b_dim_access_order = GetBBlockDimAccessOrder(); DeviceAlignedMemCPU b_block_mem(k_per_block * n_per_block * sizeof(FloatB),
constexpr auto b_block_move_step = GetBBlockMoveFwdStep(); MemAlignmentByte);
constexpr auto b_thread_slice_length = GetBThreadSliceLength(); DeviceAlignedMemCPU c_block_mem(m_per_block * n_per_block * sizeof(FloatC),
constexpr auto b_thread_loop_over_dim = GetBThreadLoopOverDim(); MemAlignmentByte);
constexpr auto c_block_desc = GetCBlockDescriptor(); auto a_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
constexpr auto c_block_slice_length = GetCBlockSliceLength();
constexpr auto c_block_move_step = ck::make_multi_index(0, NPerBlock);
auto a_threadwise_copy = ck::cpu::ThreadwiseTensorSliceTransferAvx2<
FloatA, // SrcData
FloatA, // DstData
decltype(a_grid_desc), // SrcDesc
decltype(a_block_desc), // DstDesc
AElementwiseOperation, // ElementwiseOperation
decltype(a_block_slice_length), // SliceLengths
decltype(a_dim_access_order), // DimAccessOrder
1, // VectorDim
1, // ScalarPerVector
ck::InMemoryDataOperationEnum_t::Set, // InMemoryDataOperationEnum_t
false, // SrcResetCoordinateAfterRun
true // DstResetCoordinateAfterRun
>(a_grid_desc,
ck::make_zero_multi_index<a_block_copy_dim>(),
a_block_desc,
ck::make_zero_multi_index<a_block_copy_dim>(),
AElementwiseOperation{});
auto b_threadwise_copy = ck::cpu::ThreadwiseTensorSliceTransferAvx2<
FloatB, // SrcData
FloatB, // DstData
decltype(b_grid_desc), // SrcDesc
decltype(b_block_desc), // DstDesc
BElementwiseOperation, // ElementwiseOperation
decltype(b_block_slice_length), // SliceLengths
decltype(b_dim_access_order), // DimAccessOrder
1, // VectorDim
1, // ScalarPerVector
ck::InMemoryDataOperationEnum_t::Set, // InMemoryDataOperationEnum_t
false, // SrcResetCoordinateAfterRun
true // DstResetCoordinateAfterRun
>(b_grid_desc,
ck::make_zero_multi_index<b_block_copy_dim>(),
b_block_desc,
ck::make_zero_multi_index<b_block_copy_dim>(),
BElementwiseOperation{});
auto c_threadwise_copy = ck::cpu::ThreadwiseTensorSliceTransferAvx2<
FloatC, // SrcData
FloatC, // DstData
decltype(c_block_desc), // SrcDesc
decltype(c_grid_desc), // DstDesc
BElementwiseOperation, // ElementwiseOperation
ck::Sequence<MPerBlock, NPerBlock>, // SliceLengths
ck::Sequence<0, 1>, // DimAccessOrder
1, // VectorDim
1, // ScalarPerVector
ck::InMemoryDataOperationEnum_t::Set, // InMemoryDataOperationEnum_t
true, // SrcResetCoordinateAfterRun
false // DstResetCoordinateAfterRun
>(c_block_desc,
ck::make_zero_multi_index<2>(),
c_grid_desc,
ck::make_zero_multi_index<2>(),
CElementwiseOperation{});
DeviceAlignedMemCPU a_block_mem(MPerBlock * KPerBlock * sizeof(FloatA), MemAlignmentByte);
DeviceAlignedMemCPU b_block_mem(KPerBlock * NPerBlock * sizeof(FloatB), MemAlignmentByte);
DeviceAlignedMemCPU c_block_mem(MPerBlock * NPerBlock * sizeof(FloatC), MemAlignmentByte);
auto a_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum_t::Global>(
reinterpret_cast<const FloatA*>(p_a_grid), a_grid_desc.GetElementSpaceSize()); reinterpret_cast<const FloatA*>(p_a_grid), a_grid_desc.GetElementSpaceSize());
auto b_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum_t::Global>( auto b_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<const FloatB*>(p_b_grid), b_grid_desc.GetElementSpaceSize()); reinterpret_cast<const FloatB*>(p_b_grid), b_grid_desc.GetElementSpaceSize());
auto c_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum_t::Global>( auto c_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<FloatC*>(p_c_grid), c_grid_desc.GetElementSpaceSize()); reinterpret_cast<FloatC*>(p_c_grid), c_grid_desc.GetElementSpaceSize());
auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum_t::Global>( auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<FloatA*>(a_block_mem.mpDeviceBuf), reinterpret_cast<FloatA*>(a_block_mem.mpDeviceBuf),
a_block_mem.mMemSize / sizeof(FloatA)); a_block_mem.mMemSize / sizeof(FloatA));
auto b_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum_t::Global>( auto b_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<FloatB*>(b_block_mem.mpDeviceBuf), reinterpret_cast<FloatB*>(b_block_mem.mpDeviceBuf),
b_block_mem.mMemSize / sizeof(FloatB)); b_block_mem.mMemSize / sizeof(FloatB));
auto c_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum_t::Global>( auto c_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<FloatC*>(c_block_mem.mpDeviceBuf), UseCLocalBuffer ? reinterpret_cast<FloatC*>(c_block_mem.mpDeviceBuf)
c_block_mem.mMemSize / sizeof(FloatC)); : reinterpret_cast<FloatC*>(p_c_grid),
UseCLocalBuffer ? c_block_mem.mMemSize / sizeof(FloatC)
auto blockwise_gemm = : c_grid_desc.GetElementSpaceSize());
BlockwiseGemmAvx2_MxN<FloatA, // FloatA,
FloatB, // FloatB, auto blockwise_gemm = BlockwiseGemmAvx2_MxN<
FloatC, // FloatC, FloatA, // FloatA,
AccDataType, // AccDataType, FloatB, // FloatB,
decltype(a_block_desc), // ABlockDesc, FloatC, // FloatC,
decltype(b_block_desc), // BBlockDesc, decltype(GetABlockDescriptor(m_per_block, k_per_block)), // ABlockDesc,
decltype(c_block_desc), // CBlockDesc, decltype(GetBBlockDescriptor(k_per_block, n_per_block)), // BBlockDesc,
decltype(a_block_slice_length), // ABlockSliceLengths, decltype(GetCBlockDescriptor(m_per_block, n_per_block)), // CBlockDesc,
decltype(b_block_slice_length), // BBlockSliceLengths, KPerBlock, // KPerBlock,
decltype(c_block_slice_length), // CBlockSliceLengths, ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
decltype(a_thread_slice_length), // AThreadSliceLength, ThreadMNAccessOrder>{}; // ThreadMNAccessOrder // how we acces
decltype(b_thread_slice_length), // BThreadSliceLength, // gemm MN to utilize micro kernel>{};
a_thread_loop_over_dim, // AThreadLoopOverDim, // thread slice
// loop over on block slice. 1d is enough
// for now
b_thread_loop_over_dim, // BThreadLoopOverDim,
KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
ThreadMNAccessOrder>{}; // ThreadMNAccessOrder // how we acces
// gemm MN to utilize micro kernel>{};
// TODO: OpenMP-aware ordering
//
if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 1, 2>>::value) if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 1, 2>>::value)
{ {
auto a_move_k_step = ck::make_multi_index(0, k_per_block);
auto b_move_k_step = ck::make_multi_index(0, k_per_block, 0);
const ck::index_t grid_m = math::integer_divide_ceil(GemmM, m_per_block);
const ck::index_t grid_n = math::integer_divide_ceil(GemmN, n_per_block);
const ck::index_t grid_size = grid_m * grid_n;
// This version does not consider K-panel reuse; it keeps the OpenMP parallelization simple
#pragma omp parallel for #pragma omp parallel for
for(ck::index_t gid = 0; gid < grid_size; gid++) for(ck::index_t gid = 0; gid < grid_size; gid++)
{ {
ck::index_t i_mc = (gid / grid_n) * m_per_block; ck::index_t i_mc = (gid / grid_n) * m_per_block;
ck::index_t i_nc = (gid % grid_n) * n_per_block; ck::index_t i_nc = (gid % grid_n) * n_per_block;
ck::index_t mc_size = ck::math::min(M - i_mc, m_per_block); ck::index_t mc_size = ck::math::min(GemmM - i_mc, m_per_block);
ck::index_t nc_size = ck::math::min(N - i_nc, n_per_block); ck::index_t nc_size =
ck::math::min(GemmN - i_nc, n_per_block); // TODO: nc need be 8x
// pack_b nc_size = math::integer_least_multiple(
b_threadwise_copy.RunGeneric(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf); nc_size, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_move_step);
if(i_nc == 0) a_threadwise_copy.SetSrcSliceOrigin(a_grid_desc, ck::make_multi_index(i_mc, 0));
b_threadwise_copy.SetSrcSliceOrigin(
b_grid_desc,
ck::make_multi_index(math::integer_divide_ceil(
i_nc, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
0,
0));
auto c_block_desc =
UseCLocalBuffer ? GetCBlockDescriptor(mc_size, nc_size) : c_grid_desc;
if constexpr(UseCLocalBuffer)
{
c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
ck::make_multi_index(i_mc, i_nc));
}
else
{ {
// pack_a c_threadwise_copy.SetSrcSliceOrigin(c_block_desc,
a_threadwise_copy.RunGeneric( ck::make_multi_index(i_mc, i_nc));
a_grid_desc, a_grid_buf, a_block_desc, a_block_buf); c_threadwise_copy.Run(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_move_step);
} }
for(ck::index_t i_kc = 0; i_kc < K; i_kc += k_per_block) for(ck::index_t i_kc = 0; i_kc < GemmK; i_kc += k_per_block)
{ {
ck::index_t kc_size = ck::math::min(K - i_kc, k_per_block); ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block);
auto a_block_desc = GetABlockDescriptor(mc_size, kc_size);
auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size);
// printf("==> i_m:%d, i_n:%d, i_k:%d, mc:%d, nc:%d, kc:%d(%d, %d)\n", i_mc,
// i_nc, i_kc, mc_size, nc_size, kc_size, KPerBlock, GemmK); fflush(stdout);
a_threadwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
b_threadwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
// for(auto i_elem = 0; i_elem < (mc_size * kc_size) ; i_elem++){
// printf("A ==> %3d : %f(0x%08x)\n", i_elem,
// (reinterpret_cast<float*>(a_block_buf.p_data_))[i_elem],
// (reinterpret_cast<uint32_t*>(a_block_buf.p_data_))[i_elem]);
//}
// for(auto i_elem = 0; i_elem < (kc_size * nc_size) ; i_elem++){
// printf("B ==> %3d : %f(0x%08x)\n", i_elem,
// (reinterpret_cast<float*>(b_block_buf.p_data_))[i_elem],
// (reinterpret_cast<uint32_t*>(b_block_buf.p_data_))[i_elem]);
// }
// printf("[%d] 2222 \n",__LINE__);
blockwise_gemm.Run(a_block_desc, blockwise_gemm.Run(a_block_desc,
a_block_buf, a_block_buf,
make_zero_multi_index<a_block_copy_dim>(), make_zero_multi_index<a_block_copy_dim>(),
...@@ -577,14 +307,108 @@ struct GridwiseGemmAvx2_MxN
make_zero_multi_index<b_block_copy_dim>(), make_zero_multi_index<b_block_copy_dim>(),
c_block_desc, c_block_desc,
c_block_buf, c_block_buf,
make_zero_multi_index<2>()); make_zero_multi_index<2>(),
i_kc != 0);
// printf("[%d] 2222 \n",__LINE__);
if((i_kc + k_per_block) < GemmK)
{
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc, a_move_k_step);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc, b_move_k_step);
}
// printf("[%d] 2222 \n",__LINE__);
// for(auto i_elem = 0; i_elem < (10) ; i_elem++){
// printf("C ==> %3d : %f(0x%08x)\n", i_elem,
// (reinterpret_cast<float*>(c_block_buf.p_data_))[i_elem],
// (reinterpret_cast<uint32_t*>(c_block_buf.p_data_))[i_elem]);
// }
} }
// for(auto i_elem = 0; i_elem < (c_block_mem.mMemSize / sizeof(FloatC)) ;
// i_elem++){
// printf("C ==> %3d : %f(0x%08x)\n", i_elem,
// (reinterpret_cast<float*>(c_block_buf.p_data_))[i_elem],
// (reinterpret_cast<uint32_t*>(c_block_buf.p_data_))[i_elem]);
// }
if constexpr(UseCLocalBuffer) if constexpr(UseCLocalBuffer)
c_threadwise_copy.Run(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
}
}
else if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value)
{
auto a_move_k_step = ck::make_multi_index(0, k_per_block);
auto b_move_k_step = ck::make_multi_index(
math::integer_divide_ceil(n_per_block,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
0,
0);
// only parallelize along the gemm M dimension
#pragma omp parallel for
for(ck::index_t i_mc = 0; i_mc < GemmM; i_mc += m_per_block)
{
ck::index_t mc_size = ck::math::min(GemmM - i_mc, m_per_block);
a_threadwise_copy.SetSrcSliceOrigin(a_grid_desc, ck::make_multi_index(i_mc, 0));
for(ck::index_t i_kc = 0; i_kc < GemmK; i_kc += k_per_block)
{ {
c_threadwise_copy.RunGeneric( ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block);
c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
c_threadwise_copy.MoveDstSliceWindow(c_grid_desc, c_block_move_step); auto a_block_desc = GetABlockDescriptor(mc_size, kc_size);
a_threadwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
b_threadwise_copy.SetSrcSliceOrigin(b_grid_desc,
ck::make_multi_index(0, i_kc, 0));
// TODO: if using a local C buffer, this nc loop needs to run only once
for(ck::index_t i_nc = 0; i_nc < GemmN; i_nc += n_per_block)
{
ck::index_t nc_size =
ck::math::min(GemmN - i_nc, n_per_block); // TODO: nc need be 8x
nc_size = math::integer_least_multiple(
nc_size, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size);
b_threadwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
auto c_block_desc =
UseCLocalBuffer ? GetCBlockDescriptor(mc_size, nc_size) : c_grid_desc;
if constexpr(!UseCLocalBuffer)
{
c_threadwise_copy.SetSrcSliceOrigin(c_block_desc,
ck::make_multi_index(i_mc, i_nc));
c_threadwise_copy.Run(
c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
}
blockwise_gemm.Run(a_block_desc,
a_block_buf,
make_zero_multi_index<a_block_copy_dim>(),
b_block_desc,
b_block_buf,
make_zero_multi_index<b_block_copy_dim>(),
c_block_desc,
c_block_buf,
make_zero_multi_index<2>(),
i_kc != 0);
if((i_nc + n_per_block) < GemmN)
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc, b_move_k_step);
if constexpr(UseCLocalBuffer)
{
c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
ck::make_multi_index(i_mc, i_nc));
c_threadwise_copy.Run(
c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
}
}
if((i_kc + k_per_block) < GemmK)
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc, a_move_k_step);
} }
} }
} }
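For orientation, here is a compact plain-C++ sketch of the Sequence<0, 1, 2> path above with the local C buffer enabled: pack the A and B panels per K step, run the blockwise GEMM with accumulate = (i_kc != 0) so partial K results add up in the C tile, and flush the tile to the C grid after the K loop. Descriptors, threadwise copies and the AVX2 micro-kernel are replaced by naive row-major loops, so this is illustrative only.

#include <algorithm>
#include <cstdint>
#include <vector>

void gridwise_gemm_sketch(const float* A, const float* B, float* C,
                          int64_t M, int64_t N, int64_t K,
                          int64_t Mb, int64_t Nb, int64_t Kb)
{
    std::vector<float> a_blk(Mb * Kb), b_blk(Kb * Nb), c_blk(Mb * Nb);
    for(int64_t i_mc = 0; i_mc < M; i_mc += Mb)
        for(int64_t i_nc = 0; i_nc < N; i_nc += Nb)
        {
            const int64_t mc = std::min(M - i_mc, Mb);
            const int64_t nc = std::min(N - i_nc, Nb);
            for(int64_t i_kc = 0; i_kc < K; i_kc += Kb)
            {
                const int64_t kc = std::min(K - i_kc, Kb);
                // pack the A and B panels (stand-in for AThreadwiseCopy / BThreadwiseCopy)
                for(int64_t i = 0; i < mc; ++i)
                    for(int64_t p = 0; p < kc; ++p)
                        a_blk[i * kc + p] = A[(i_mc + i) * K + i_kc + p];
                for(int64_t p = 0; p < kc; ++p)
                    for(int64_t j = 0; j < nc; ++j)
                        b_blk[p * nc + j] = B[(i_kc + p) * N + i_nc + j];
                // blockwise GEMM; i_kc != 0 selects accumulate-into-C
                const bool acc = i_kc != 0;
                for(int64_t i = 0; i < mc; ++i)
                    for(int64_t j = 0; j < nc; ++j)
                    {
                        float v = acc ? c_blk[i * nc + j] : 0.0f;
                        for(int64_t p = 0; p < kc; ++p)
                            v += a_blk[i * kc + p] * b_blk[p * nc + j];
                        c_blk[i * nc + j] = v;
                    }
            }
            // flush the local C tile back to the C grid (CThreadwiseCopy when UseCLocalBuffer)
            for(int64_t i = 0; i < mc; ++i)
                for(int64_t j = 0; j < nc; ++j)
                    C[(i_mc + i) * N + i_nc + j] = c_blk[i * nc + j];
        }
}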
......
...@@ -7,7 +7,7 @@
#include "common_header.hpp" #include "common_header.hpp"
#include "../../gpu/device/tensor_layout.hpp" #include "../../gpu/device/tensor_layout.hpp"
#include "math.hpp" #include "math.hpp"
#include "threadwise_param.hpp" #include "threadwise_gemm_param.hpp"
namespace ck { namespace ck {
namespace cpu { namespace cpu {
...@@ -294,6 +294,9 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".if (m_Mr > 4)\n lea (%%rdx, %%rdi, 1), %%r8 \n .endif\n" ".if (m_Mr > 4)\n lea (%%rdx, %%rdi, 1), %%r8 \n .endif\n"
".if (m_Mr > 5)\n lea (%%r8, %%rdi, 1), %%r9 \n .endif\n" ".if (m_Mr > 5)\n lea (%%r8, %%rdi, 1), %%r9 \n .endif\n"
"mov 60(%[m_param]), %%edi\n" // accmulate_c
"test %%edi, %%edi\n"
"je L_GemmAvx2_MxN_6x16_Store_C%=\n"
" vaddps (%%rax), %%ymm0, %%ymm0 \n" " vaddps (%%rax), %%ymm0, %%ymm0 \n"
".if (m_Nr > 8)\n vaddps 32(%%rax), %%ymm1, %%ymm1 \n .endif\n" ".if (m_Nr > 8)\n vaddps 32(%%rax), %%ymm1, %%ymm1 \n .endif\n"
".if (m_Mr > 1) \n vaddps (%%rbx), %%ymm2, %%ymm2 \n .endif\n" ".if (m_Mr > 1) \n vaddps (%%rbx), %%ymm2, %%ymm2 \n .endif\n"
...@@ -307,6 +310,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".if (m_Mr > 5) \n vaddps (%%r9), %%ymm10, %%ymm10\n .endif\n" ".if (m_Mr > 5) \n vaddps (%%r9), %%ymm10, %%ymm10\n .endif\n"
".if (m_Mr > 5) && (m_Nr > 8)\n vaddps 32(%%r9), %%ymm11, %%ymm11\n .endif\n" ".if (m_Mr > 5) && (m_Nr > 8)\n vaddps 32(%%r9), %%ymm11, %%ymm11\n .endif\n"
"L_GemmAvx2_MxN_6x16_Store_C%=:\n"
".if m_NTStore == 0\n" ".if m_NTStore == 0\n"
" vmovups %%ymm0, (%%rax) \n" " vmovups %%ymm0, (%%rax) \n"
".if (m_Nr > 8)\n vmovups %%ymm1, 32(%%rax)\n .endif\n" ".if (m_Nr > 8)\n vmovups %%ymm1, 32(%%rax)\n .endif\n"
...@@ -424,18 +428,33 @@ struct ThreadwiseGemmAvx2_MxN_6x16
}; };
// clang-format off // clang-format off
ymm0 = _mm256_loadu_ps(p_c + 0 * ldc + 0 * 8); if(param->accmulate_c){
if constexpr ( Nr > 8) ymm1 = _mm256_loadu_ps(p_c + 0 * ldc + 1 * 8); ymm0 = _mm256_loadu_ps(p_c + 0 * ldc + 0 * 8);
if constexpr (Mr > 1 ) ymm2 = _mm256_loadu_ps(p_c + 1 * ldc + 0 * 8); if constexpr ( Nr > 8) ymm1 = _mm256_loadu_ps(p_c + 0 * ldc + 1 * 8);
if constexpr (Mr > 1 && Nr > 8) ymm3 = _mm256_loadu_ps(p_c + 1 * ldc + 1 * 8); if constexpr (Mr > 1 ) ymm2 = _mm256_loadu_ps(p_c + 1 * ldc + 0 * 8);
if constexpr (Mr > 2 ) ymm4 = _mm256_loadu_ps(p_c + 2 * ldc + 0 * 8); if constexpr (Mr > 1 && Nr > 8) ymm3 = _mm256_loadu_ps(p_c + 1 * ldc + 1 * 8);
if constexpr (Mr > 2 && Nr > 8) ymm5 = _mm256_loadu_ps(p_c + 2 * ldc + 1 * 8); if constexpr (Mr > 2 ) ymm4 = _mm256_loadu_ps(p_c + 2 * ldc + 0 * 8);
if constexpr (Mr > 3 ) ymm6 = _mm256_loadu_ps(p_c + 3 * ldc + 0 * 8); if constexpr (Mr > 2 && Nr > 8) ymm5 = _mm256_loadu_ps(p_c + 2 * ldc + 1 * 8);
if constexpr (Mr > 3 && Nr > 8) ymm7 = _mm256_loadu_ps(p_c + 3 * ldc + 1 * 8); if constexpr (Mr > 3 ) ymm6 = _mm256_loadu_ps(p_c + 3 * ldc + 0 * 8);
if constexpr (Mr > 4 ) ymm8 = _mm256_loadu_ps(p_c + 4 * ldc + 0 * 8); if constexpr (Mr > 3 && Nr > 8) ymm7 = _mm256_loadu_ps(p_c + 3 * ldc + 1 * 8);
if constexpr (Mr > 4 && Nr > 8) ymm9 = _mm256_loadu_ps(p_c + 4 * ldc + 1 * 8); if constexpr (Mr > 4 ) ymm8 = _mm256_loadu_ps(p_c + 4 * ldc + 0 * 8);
if constexpr (Mr > 5 ) ymm10 = _mm256_loadu_ps(p_c + 5 * ldc + 0 * 8); if constexpr (Mr > 4 && Nr > 8) ymm9 = _mm256_loadu_ps(p_c + 4 * ldc + 1 * 8);
if constexpr (Mr > 5 && Nr > 8) ymm11 = _mm256_loadu_ps(p_c + 5 * ldc + 1 * 8); if constexpr (Mr > 5 ) ymm10 = _mm256_loadu_ps(p_c + 5 * ldc + 0 * 8);
if constexpr (Mr > 5 && Nr > 8) ymm11 = _mm256_loadu_ps(p_c + 5 * ldc + 1 * 8);
} else {
ymm0 = _mm256_xor_ps(ymm0, ymm0);
if constexpr ( Nr > 8) ymm1 = _mm256_xor_ps(ymm1, ymm1);
if constexpr (Mr > 1 ) ymm2 = _mm256_xor_ps(ymm2, ymm2);
if constexpr (Mr > 1 && Nr > 8) ymm3 = _mm256_xor_ps(ymm3, ymm3);
if constexpr (Mr > 2 ) ymm4 = _mm256_xor_ps(ymm4, ymm4);
if constexpr (Mr > 2 && Nr > 8) ymm5 = _mm256_xor_ps(ymm5, ymm5);
if constexpr (Mr > 3 ) ymm6 = _mm256_xor_ps(ymm6, ymm6);
if constexpr (Mr > 3 && Nr > 8) ymm7 = _mm256_xor_ps(ymm7, ymm7);
if constexpr (Mr > 4 ) ymm8 = _mm256_xor_ps(ymm8, ymm8);
if constexpr (Mr > 4 && Nr > 8) ymm9 = _mm256_xor_ps(ymm9, ymm9);
if constexpr (Mr > 5 ) ymm10 = _mm256_xor_ps(ymm10, ymm10);
if constexpr (Mr > 5 && Nr > 8) ymm11 = _mm256_xor_ps(ymm11, ymm11);
}
while (Kr > 4){ while (Kr > 4){
#pragma unroll #pragma unroll
...@@ -532,6 +551,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
} }
if constexpr (NonTemporalStore) { if constexpr (NonTemporalStore) {
_mm256_stream_ps(p_c + 0 * ldc + 0 * 8, ymm0);
if constexpr ( Nr > 8) _mm256_stream_ps(p_c + 0 * ldc + 1 * 8, ymm1); if constexpr ( Nr > 8) _mm256_stream_ps(p_c + 0 * ldc + 1 * 8, ymm1);
if constexpr (Mr > 1 ) _mm256_stream_ps(p_c + 1 * ldc + 0 * 8, ymm2); if constexpr (Mr > 1 ) _mm256_stream_ps(p_c + 1 * ldc + 0 * 8, ymm2);
if constexpr (Mr > 1 && Nr > 8) _mm256_stream_ps(p_c + 1 * ldc + 1 * 8, ymm3); if constexpr (Mr > 1 && Nr > 8) _mm256_stream_ps(p_c + 1 * ldc + 1 * 8, ymm3);
...@@ -830,19 +850,23 @@ struct ThreadwiseGemmAvx2_MxN_4x24 ...@@ -830,19 +850,23 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".if (m_Mr > 2)\n lea (%%rbx, %%rdi, 1), %%rcx\n .endif\n" ".if (m_Mr > 2)\n lea (%%rbx, %%rdi, 1), %%rcx\n .endif\n"
".if (m_Mr > 3)\n lea (%%rcx, %%rdi, 1), %%rdx\n .endif\n" ".if (m_Mr > 3)\n lea (%%rcx, %%rdi, 1), %%rdx\n .endif\n"
// " vaddps (%%rax), %%ymm0, %%ymm0 \n" "mov 60(%[m_param]), %%edi\n" // accmulate_c
// ".if (m_Nr > 8)\n vaddps 32(%%rax), %%ymm1, %%ymm1 \n .endif\n" "test %%edi, %%edi\n"
// ".if (m_Nr >16)\n vaddps 64(%%rax), %%ymm2, %%ymm2 \n .endif\n" "je L_GemmAvx2_MxN_4x24_Store_C%=\n"
// ".if (m_Mr > 1) \n vaddps (%%rbx), %%ymm3, %%ymm3 \n .endif\n" " vaddps (%%rax), %%ymm0, %%ymm0 \n"
// ".if (m_Mr > 1) && (m_Nr > 8)\n vaddps 32(%%rbx), %%ymm4, %%ymm4 \n .endif\n" ".if (m_Nr > 8)\n vaddps 32(%%rax), %%ymm1, %%ymm1 \n .endif\n"
// ".if (m_Mr > 1) && (m_Nr >16)\n vaddps 64(%%rbx), %%ymm5, %%ymm5 \n .endif\n" ".if (m_Nr >16)\n vaddps 64(%%rax), %%ymm2, %%ymm2 \n .endif\n"
// ".if (m_Mr > 2) \n vaddps (%%rcx), %%ymm6, %%ymm6 \n .endif\n" ".if (m_Mr > 1) \n vaddps (%%rbx), %%ymm3, %%ymm3 \n .endif\n"
// ".if (m_Mr > 2) && (m_Nr > 8)\n vaddps 32(%%rcx), %%ymm7, %%ymm7 \n .endif\n" ".if (m_Mr > 1) && (m_Nr > 8)\n vaddps 32(%%rbx), %%ymm4, %%ymm4 \n .endif\n"
// ".if (m_Mr > 2) && (m_Nr >16)\n vaddps 64(%%rcx), %%ymm8, %%ymm8 \n .endif\n" ".if (m_Mr > 1) && (m_Nr >16)\n vaddps 64(%%rbx), %%ymm5, %%ymm5 \n .endif\n"
// ".if (m_Mr > 3) \n vaddps (%%rdx), %%ymm9, %%ymm9 \n .endif\n" ".if (m_Mr > 2) \n vaddps (%%rcx), %%ymm6, %%ymm6 \n .endif\n"
// ".if (m_Mr > 3) && (m_Nr > 8)\n vaddps 32(%%rdx), %%ymm10, %%ymm10\n .endif\n" ".if (m_Mr > 2) && (m_Nr > 8)\n vaddps 32(%%rcx), %%ymm7, %%ymm7 \n .endif\n"
// ".if (m_Mr > 3) && (m_Nr >16)\n vaddps 64(%%rdx), %%ymm11, %%ymm11\n .endif\n" ".if (m_Mr > 2) && (m_Nr >16)\n vaddps 64(%%rcx), %%ymm8, %%ymm8 \n .endif\n"
".if (m_Mr > 3) \n vaddps (%%rdx), %%ymm9, %%ymm9 \n .endif\n"
".if (m_Mr > 3) && (m_Nr > 8)\n vaddps 32(%%rdx), %%ymm10, %%ymm10\n .endif\n"
".if (m_Mr > 3) && (m_Nr >16)\n vaddps 64(%%rdx), %%ymm11, %%ymm11\n .endif\n"
"L_GemmAvx2_MxN_4x24_Store_C%=:\n"
".if m_NTStore == 0\n" ".if m_NTStore == 0\n"
" vmovups %%ymm0, (%%rax) \n" " vmovups %%ymm0, (%%rax) \n"
".if (m_Nr > 8)\n vmovups %%ymm1, 32(%%rax)\n .endif\n" ".if (m_Nr > 8)\n vmovups %%ymm1, 32(%%rax)\n .endif\n"
...@@ -960,18 +984,33 @@ struct ThreadwiseGemmAvx2_MxN_4x24 ...@@ -960,18 +984,33 @@ struct ThreadwiseGemmAvx2_MxN_4x24
}; };
// clang-format off // clang-format off
ymm0 = _mm256_loadu_ps(p_c + 0 * ldc + 0 * 8); if(param->accmulate_c) {
if constexpr ( Nr > 8) ymm1 = _mm256_loadu_ps(p_c + 0 * ldc + 1 * 8); ymm0 = _mm256_loadu_ps(p_c + 0 * ldc + 0 * 8);
if constexpr ( Nr >16) ymm2 = _mm256_loadu_ps(p_c + 0 * ldc + 2 * 8); if constexpr ( Nr > 8) ymm1 = _mm256_loadu_ps(p_c + 0 * ldc + 1 * 8);
if constexpr (Mr > 1 ) ymm3 = _mm256_loadu_ps(p_c + 1 * ldc + 0 * 8); if constexpr ( Nr >16) ymm2 = _mm256_loadu_ps(p_c + 0 * ldc + 2 * 8);
if constexpr (Mr > 1 && Nr > 8) ymm4 = _mm256_loadu_ps(p_c + 1 * ldc + 1 * 8); if constexpr (Mr > 1 ) ymm3 = _mm256_loadu_ps(p_c + 1 * ldc + 0 * 8);
if constexpr (Mr > 1 && Nr >16) ymm5 = _mm256_loadu_ps(p_c + 1 * ldc + 2 * 8); if constexpr (Mr > 1 && Nr > 8) ymm4 = _mm256_loadu_ps(p_c + 1 * ldc + 1 * 8);
if constexpr (Mr > 2 ) ymm6 = _mm256_loadu_ps(p_c + 2 * ldc + 0 * 8); if constexpr (Mr > 1 && Nr >16) ymm5 = _mm256_loadu_ps(p_c + 1 * ldc + 2 * 8);
if constexpr (Mr > 2 && Nr > 8) ymm7 = _mm256_loadu_ps(p_c + 2 * ldc + 1 * 8); if constexpr (Mr > 2 ) ymm6 = _mm256_loadu_ps(p_c + 2 * ldc + 0 * 8);
if constexpr (Mr > 2 && Nr >16) ymm8 = _mm256_loadu_ps(p_c + 2 * ldc + 2 * 8); if constexpr (Mr > 2 && Nr > 8) ymm7 = _mm256_loadu_ps(p_c + 2 * ldc + 1 * 8);
if constexpr (Mr > 3 ) ymm9 = _mm256_loadu_ps(p_c + 3 * ldc + 0 * 8); if constexpr (Mr > 2 && Nr >16) ymm8 = _mm256_loadu_ps(p_c + 2 * ldc + 2 * 8);
if constexpr (Mr > 3 && Nr > 8) ymm10 = _mm256_loadu_ps(p_c + 3 * ldc + 1 * 8); if constexpr (Mr > 3 ) ymm9 = _mm256_loadu_ps(p_c + 3 * ldc + 0 * 8);
if constexpr (Mr > 3 && Nr >16) ymm11 = _mm256_loadu_ps(p_c + 3 * ldc + 2 * 8); if constexpr (Mr > 3 && Nr > 8) ymm10 = _mm256_loadu_ps(p_c + 3 * ldc + 1 * 8);
if constexpr (Mr > 3 && Nr >16) ymm11 = _mm256_loadu_ps(p_c + 3 * ldc + 2 * 8);
} else {
ymm0 = _mm256_xor_ps(ymm0, ymm0);
if constexpr ( Nr > 8) ymm1 = _mm256_xor_ps(ymm1, ymm1);
if constexpr ( Nr >16) ymm2 = _mm256_xor_ps(ymm2, ymm2);
if constexpr (Mr > 1 ) ymm3 = _mm256_xor_ps(ymm3, ymm3);
if constexpr (Mr > 1 && Nr > 8) ymm4 = _mm256_xor_ps(ymm4, ymm4);
if constexpr (Mr > 1 && Nr >16) ymm5 = _mm256_xor_ps(ymm5, ymm5);
if constexpr (Mr > 2 ) ymm6 = _mm256_xor_ps(ymm6, ymm6);
if constexpr (Mr > 2 && Nr > 8) ymm7 = _mm256_xor_ps(ymm7, ymm7);
if constexpr (Mr > 2 && Nr >16) ymm8 = _mm256_xor_ps(ymm8, ymm8);
if constexpr (Mr > 3 ) ymm9 = _mm256_xor_ps(ymm9, ymm9);
if constexpr (Mr > 3 && Nr > 8) ymm10 = _mm256_xor_ps(ymm10, ymm10);
if constexpr (Mr > 3 && Nr >16) ymm11 = _mm256_xor_ps(ymm11, ymm11);
}
while (Kr > 4){ while (Kr > 4){
#pragma unroll #pragma unroll
...@@ -1221,33 +1260,36 @@ struct ThreadwiseGemmAvx2_MxN_6x16_Dispatch ...@@ -1221,33 +1260,36 @@ struct ThreadwiseGemmAvx2_MxN_6x16_Dispatch
static constexpr pThreadwiseGemmAvx2Run dispatch_table[6][2] = { static constexpr pThreadwiseGemmAvx2Run dispatch_table[6][2] = {
{ {
ThreadwiseGemm_6x16_t::Run, ThreadwiseGemm_1x8_t::Run,
ThreadwiseGemm_6x8_t::Run, ThreadwiseGemm_1x16_t::Run,
}, },
{ {
ThreadwiseGemm_5x16_t::Run, ThreadwiseGemm_2x8_t::Run,
ThreadwiseGemm_5x8_t::Run, ThreadwiseGemm_2x16_t::Run,
}, },
{ {
ThreadwiseGemm_4x16_t::Run, ThreadwiseGemm_3x8_t::Run,
ThreadwiseGemm_4x8_t::Run, ThreadwiseGemm_3x16_t::Run,
}, },
{ {
ThreadwiseGemm_3x16_t::Run, ThreadwiseGemm_4x8_t::Run,
ThreadwiseGemm_3x8_t::Run, ThreadwiseGemm_4x16_t::Run,
}, },
{ {
ThreadwiseGemm_2x16_t::Run, ThreadwiseGemm_5x8_t::Run,
ThreadwiseGemm_2x8_t::Run, ThreadwiseGemm_5x16_t::Run,
}, },
{ {
ThreadwiseGemm_1x16_t::Run, ThreadwiseGemm_6x8_t::Run,
ThreadwiseGemm_1x8_t::Run, ThreadwiseGemm_6x16_t::Run,
}, },
}; };
static void Run(ThreadwiseGemmParam* param, index_t mr, index_t nr) static void Run(ThreadwiseGemmParam* param, index_t mr, index_t nr)
{ {
index_t im = mr - 1;
index_t in = (nr >> 3) - 1;
assert(im >= 0 && im <= 5 && in >= 0 && in <= 1);
return dispatch_table[mr][nr](param); return dispatch_table[im][in](param);
} }
}; };
...@@ -1371,30 +1413,33 @@ struct ThreadwiseGemmAvx2_MxN_4x24_Dispatch ...@@ -1371,30 +1413,33 @@ struct ThreadwiseGemmAvx2_MxN_4x24_Dispatch
static constexpr pThreadwiseGemmAvx2Run dispatch_table[4][3] = { static constexpr pThreadwiseGemmAvx2Run dispatch_table[4][3] = {
{ {
ThreadwiseGemm_4x24_t::Run, ThreadwiseGemm_1x8_t::Run,
ThreadwiseGemm_4x16_t::Run, ThreadwiseGemm_1x16_t::Run,
ThreadwiseGemm_4x8_t::Run, ThreadwiseGemm_1x24_t::Run,
}, },
{ {
ThreadwiseGemm_3x24_t::Run, ThreadwiseGemm_2x8_t::Run,
ThreadwiseGemm_3x16_t::Run, ThreadwiseGemm_2x16_t::Run,
ThreadwiseGemm_3x8_t::Run, ThreadwiseGemm_2x24_t::Run,
}, },
{ {
ThreadwiseGemm_2x24_t::Run, ThreadwiseGemm_3x8_t::Run,
ThreadwiseGemm_2x16_t::Run, ThreadwiseGemm_3x16_t::Run,
ThreadwiseGemm_2x8_t::Run, ThreadwiseGemm_3x24_t::Run,
}, },
{ {
ThreadwiseGemm_1x24_t::Run, ThreadwiseGemm_4x8_t::Run,
ThreadwiseGemm_1x16_t::Run, ThreadwiseGemm_4x16_t::Run,
ThreadwiseGemm_1x8_t::Run, ThreadwiseGemm_4x24_t::Run,
}, },
}; };
static void Run(ThreadwiseGemmParam* param, index_t mr, index_t nr) static void Run(ThreadwiseGemmParam* param, index_t mr, index_t nr)
{ {
return dispatch_table[mr][nr](param); index_t im = mr - 1;
index_t in = (nr >> 3) - 1;
assert(im >= 0 && im <= 3 && in >= 0 && in <= 2);
return dispatch_table[im][in](param);
} }
}; };
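// note (editorial sketch, not part of the original kernels): both dispatch structs map the
// residual tile size (mr, nr) onto table indices as im = mr - 1 and in = (nr >> 3) - 1,
// i.e. rows count M lines and columns count 8-wide N groups.
#if 0 // minimal illustration, assuming mr in [1, MaxMr] and nr a multiple of 8
inline ck::index_t dispatch_row(ck::index_t mr) { return mr - 1; }        // 1..6 -> 0..5 (6x16), 1..4 -> 0..3 (4x24)
inline ck::index_t dispatch_col(ck::index_t nr) { return (nr >> 3) - 1; } // 8/16/24 -> 0/1/2
// e.g. an M/N tail of (3, 16) selects dispatch_table[2][1], i.e. ThreadwiseGemm_3x16_t::Run
#endif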
......
#ifndef CK_THREADWISE_PARAM_HPP #ifndef CK_THREADWISE_GEMM_PARAM_HPP
#define CK_THREADWISE_PARAM_HPP #define CK_THREADWISE_GEMM_PARAM_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "math.hpp" #include "math.hpp"
...@@ -17,7 +17,7 @@ struct ThreadwiseGemmParam ...@@ -17,7 +17,7 @@ struct ThreadwiseGemmParam
uint64_t ldb; // in unit of byte uint64_t ldb; // in unit of byte
uint64_t ldc; // in unit of byte uint64_t ldc; // in unit of byte
float alpha; float alpha;
uint32_t _pack0; int accmulate_c; // if 1, load C and accumulate the current FMA result into it; if 0, store the C result out directly
} __attribute__((packed)); } __attribute__((packed));
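// note (editorial sketch, not part of the original source): a caller that tiles the GEMM K
// dimension would typically clear accmulate_c for the first k-block (the kernel writes C
// directly) and set it for the remaining k-blocks (the kernel reloads C and accumulates).
#if 0 // hedged usage sketch; K and k_per_block are placeholder names
ThreadwiseGemmParam param{};
for(ck::index_t i_k = 0; i_k < K; i_k += k_per_block)
{
    param.accmulate_c = (i_k == 0) ? 0 : 1; // first k-block overwrites C, later ones accumulate
    // ... fill in the A/B/C pointers, strides and alpha, then call the dispatched micro kernel
}
#endif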
} // namespace cpu } // namespace cpu
......
...@@ -53,6 +53,42 @@ struct ThreadwiseTensorSliceTransferAvx2 ...@@ -53,6 +53,42 @@ struct ThreadwiseTensorSliceTransferAvx2
{ {
static_assert(SliceLengths::At(Number<VectorDim>{}) % ScalarPerVector == 0, static_assert(SliceLengths::At(Number<VectorDim>{}) % ScalarPerVector == 0,
"wrong! cannot evenly divide"); "wrong! cannot evenly divide");
int N = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
int Hi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
int Wi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
int C = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<3>{}];
int Ho = src_desc.GetTransforms()[Number<9>{}].low_lengths_[Number<1>{}];
int Wo = src_desc.GetTransforms()[Number<9>{}].low_lengths_[Number<2>{}];
int Fy = src_desc.GetTransforms()[Number<10>{}].low_lengths_[Number<0>{}];
int Fx = src_desc.GetTransforms()[Number<10>{}].low_lengths_[Number<1>{}];
int Dy = src_desc.GetTransforms()[Number<6>{}].coefficients_[Number<0>{}];
int Sy = src_desc.GetTransforms()[Number<6>{}].coefficients_[Number<1>{}];
int Dx = src_desc.GetTransforms()[Number<7>{}].coefficients_[Number<0>{}];
int Sx = src_desc.GetTransforms()[Number<7>{}].coefficients_[Number<1>{}];
int Py = src_desc.GetTransforms()[Number<2>{}].left_pad_length_;
int Px = src_desc.GetTransforms()[Number<3>{}].left_pad_length_;
printf("N:%d, Hi:%d, Wi:%d, C:%d, Ho:%d, Wo:%d, Fy:%d, Fx:%d, Dy:%d, Sy:%d, Dx:%d, Sx:%d, "
"Py:%d, Px:%d\n",
N,
Hi,
Wi,
C,
Ho,
Wo,
Fy,
Fx,
Dy,
Sy,
Dx,
Sx,
Py,
Px);
} }
void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
...@@ -87,6 +123,10 @@ struct ThreadwiseTensorSliceTransferAvx2 ...@@ -87,6 +123,10 @@ struct ThreadwiseTensorSliceTransferAvx2
// std::cout<<"num_access:"<<num_access<<std::endl; // std::cout<<"num_access:"<<num_access<<std::endl;
std::cout << "src hidden:" << SrcDesc::GetNumOfHiddenDimension() << std::endl;
std::cout << "dst hidden:" << DstDesc::GetNumOfHiddenDimension() << std::endl;
#if 0
static_for<0, num_access, 1>{}([&](auto idx_1d) { static_for<0, num_access, 1>{}([&](auto idx_1d) {
using src_vector_type = ck::cpu::vector_type_maker_t<SrcData, ScalarPerVector>; using src_vector_type = ck::cpu::vector_type_maker_t<SrcData, ScalarPerVector>;
using src_vector_t = typename src_vector_type::type; using src_vector_t = typename src_vector_type::type;
...@@ -148,6 +188,75 @@ struct ThreadwiseTensorSliceTransferAvx2 ...@@ -148,6 +188,75 @@ struct ThreadwiseTensorSliceTransferAvx2
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step)); dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
} }
}); });
#endif
const auto src_slice_idx_zeros = typename uniform_sequence_gen<nDim, 0>::type{};
const auto src_slice_step = make_tensor_coordinate_step(
src_desc, to_multi_index(src_slice_idx_zeros.Modify(Number<nDim - 1>{}, Number<1>{})));
const auto dst_slice_idx_zeros = typename uniform_sequence_gen<nDim, 0>::type{};
const auto dst_slice_step = make_tensor_coordinate_step(
dst_desc, to_multi_index(dst_slice_idx_zeros.Modify(Number<nDim - 1>{}, Number<1>{})));
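// note: both steps above are a +1 move of the visible index along the last dimension,
// so the loop below advances src_coord_ and dst_coord_ by that unit step once per access.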
for(auto idx_id = 0; idx_id < num_access; idx_id++)
{
using src_vector_type = ck::cpu::vector_type_maker_t<SrcData, ScalarPerVector>;
using src_vector_t = typename src_vector_type::type;
using dst_vector_type = ck::cpu::vector_type_maker_t<DstData, ScalarPerVector>;
using dst_vector_t = typename dst_vector_type::type;
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
printf("[%s] ", is_src_valid ? "y" : "n");
print_multi_index(src_coord_.GetIndex());
printf("----");
// print_multi_index(src_coord_.GetHiddenIndex());
// printf(":%d", src_coord_.GetOffset());
// printf("\n");
// copy data from src_buf into src_vector_container
auto src_vector_container = src_vector_type{
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
auto dst_vector_container = dst_vector_type{};
// apply pointwise operation
// static_for<0, ScalarPerVector, 1>{}([&](auto i) {
// element_op_(dst_vector_container.template AsType<DstData>()(i),
// src_vector_container.template AsType<SrcData>()[i]);
// });
element_op_(dst_vector_container.template AsType<dst_vector_t>(),
src_vector_container.template AsType<src_vector_t>());
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
printf(" -> ");
print_multi_index(dst_coord_.GetIndex());
// printf(":%d", dst_coord_.GetOffset());
// printf(", src:0x%x, dst:0x%x",
// *reinterpret_cast<uint32_t*>(&src_vector_container.template AsType<src_vector_t>()),
// *reinterpret_cast<uint32_t*>(&dst_vector_container.template
// AsType<dst_vector_t>()));
printf("\n");
// copy data from dst_vector into dst_buf
dst_buf.template Update<DstInMemOp, dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>());
// move coordinate
if(idx_id != num_access - 1)
{
// constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
move_tensor_coordinate(src_desc, src_coord_, src_slice_step);
move_tensor_coordinate(dst_desc, dst_coord_, dst_slice_step);
}
}
// move coordinate back to slice origin (or not) // move coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun) if constexpr(SrcResetCoordinateAfterRun)
......
#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_AVX2_SPECIALIZED_HPP
#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_AVX2_SPECIALIZED_HPP
#include "common_header.hpp"
#include "data_type_cpu.hpp"
#include "../../gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_space_filling_curve.hpp"
#include "dynamic_buffer_cpu.hpp"
#include <immintrin.h>
#include "convolution_forward_specialization_cpu.hpp"
namespace ck {
namespace cpu {
namespace avx2_util {
inline void memcpy32_avx2(void* dst, const void* src, const ck::index_t n)
{
// 16-8-4-2-1 pattern
ck::index_t i_n = n;
float* p_dst = reinterpret_cast<float*>(dst);
const float* p_src = reinterpret_cast<const float*>(src);
while(i_n >= 16)
{
_mm256_storeu_ps(p_dst + 0, _mm256_loadu_ps(p_src + 0));
_mm256_storeu_ps(p_dst + 8, _mm256_loadu_ps(p_src + 8));
p_dst += 16;
p_src += 16;
i_n -= 16;
}
if(i_n & 8)
{
_mm256_storeu_ps(p_dst, _mm256_loadu_ps(p_src));
p_dst += 8;
p_src += 8;
}
if(i_n & 4)
{
_mm_storeu_ps(p_dst, _mm_loadu_ps(p_src));
p_dst += 4;
p_src += 4;
}
if(i_n & 2)
{
_mm_storeu_si64(p_dst, _mm_loadu_si64(p_src));
p_dst += 2;
p_src += 2;
}
if(i_n & 1)
{
*p_dst = *p_src;
}
}
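// note (editorial): memcpy32_avx2 copies n 32-bit elements, not n bytes; for non-overlapping
// buffers it is equivalent to std::memcpy(dst, src, n * sizeof(float)), issued as the
// 16-8-4-2-1 split above (e.g. n = 37 is copied as 16 + 16 + 4 + 1).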
inline void memset32_avx2(void* dst, const int32_t value, const ck::index_t n)
{
// 16-8-4-2-1 pattern
ck::index_t i_n = n;
float* p_dst = reinterpret_cast<float*>(dst);
__m256 ymm = _mm256_set1_ps(*reinterpret_cast<const float*>(&value));
__m128 xmm = _mm_set1_ps(*reinterpret_cast<const float*>(&value));
while(i_n >= 16)
{
_mm256_storeu_ps(p_dst + 0, ymm);
_mm256_storeu_ps(p_dst + 8, ymm);
p_dst += 16;
i_n -= 16;
}
if(i_n & 8)
{
_mm256_storeu_ps(p_dst, ymm);
p_dst += 8;
}
if(i_n & 4)
{
_mm_storeu_ps(p_dst, xmm);
p_dst += 4;
}
if(i_n & 2)
{
_mm_storeu_si64(p_dst, _mm_castps_si128(xmm)); // cast: _mm_storeu_si64 expects an __m128i
p_dst += 2;
}
if(i_n & 1)
{
*p_dst = *reinterpret_cast<const float*>(&value);
}
}
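// note (editorial): unlike a byte-wise memset, memset32_avx2 replicates the full 32-bit
// pattern of `value` into n elements (n counts 32-bit lanes, not bytes).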
inline void
transpose8x8_avx2(void* dst, ck::index_t stride_dst, const void* src, ck::index_t stride_src)
{
// TODO: use vinsertf128 for better port usage; vperm2f128 is slow
__m256 r0, r1, r2, r3, r4, r5, r6, r7;
__m256 t0, t1, t2, t3, t4, t5, t6, t7;
float* p_dst = reinterpret_cast<float*>(dst);
const float* p_src = reinterpret_cast<const float*>(src);
r0 = _mm256_loadu_ps(p_src + 0 * stride_src);
r1 = _mm256_loadu_ps(p_src + 1 * stride_src);
r2 = _mm256_loadu_ps(p_src + 2 * stride_src);
r3 = _mm256_loadu_ps(p_src + 3 * stride_src);
r4 = _mm256_loadu_ps(p_src + 4 * stride_src);
r5 = _mm256_loadu_ps(p_src + 5 * stride_src);
r6 = _mm256_loadu_ps(p_src + 6 * stride_src);
r7 = _mm256_loadu_ps(p_src + 7 * stride_src);
t0 = _mm256_unpacklo_ps(r0, r1);
t1 = _mm256_unpackhi_ps(r0, r1);
t2 = _mm256_unpacklo_ps(r2, r3);
t3 = _mm256_unpackhi_ps(r2, r3);
t4 = _mm256_unpacklo_ps(r4, r5);
t5 = _mm256_unpackhi_ps(r4, r5);
t6 = _mm256_unpacklo_ps(r6, r7);
t7 = _mm256_unpackhi_ps(r6, r7);
r0 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(1, 0, 1, 0));
r1 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(3, 2, 3, 2));
r2 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(1, 0, 1, 0));
r3 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(3, 2, 3, 2));
r4 = _mm256_shuffle_ps(t4, t6, _MM_SHUFFLE(1, 0, 1, 0));
r5 = _mm256_shuffle_ps(t4, t6, _MM_SHUFFLE(3, 2, 3, 2));
r6 = _mm256_shuffle_ps(t5, t7, _MM_SHUFFLE(1, 0, 1, 0));
r7 = _mm256_shuffle_ps(t5, t7, _MM_SHUFFLE(3, 2, 3, 2));
t0 = _mm256_permute2f128_ps(r0, r4, 0x20);
t1 = _mm256_permute2f128_ps(r1, r5, 0x20);
t2 = _mm256_permute2f128_ps(r2, r6, 0x20);
t3 = _mm256_permute2f128_ps(r3, r7, 0x20);
t4 = _mm256_permute2f128_ps(r0, r4, 0x31);
t5 = _mm256_permute2f128_ps(r1, r5, 0x31);
t6 = _mm256_permute2f128_ps(r2, r6, 0x31);
t7 = _mm256_permute2f128_ps(r3, r7, 0x31);
_mm256_storeu_ps(p_dst + 0 * stride_dst, t0);
_mm256_storeu_ps(p_dst + 1 * stride_dst, t1);
_mm256_storeu_ps(p_dst + 2 * stride_dst, t2);
_mm256_storeu_ps(p_dst + 3 * stride_dst, t3);
_mm256_storeu_ps(p_dst + 4 * stride_dst, t4);
_mm256_storeu_ps(p_dst + 5 * stride_dst, t5);
_mm256_storeu_ps(p_dst + 6 * stride_dst, t6);
_mm256_storeu_ps(p_dst + 7 * stride_dst, t7);
}
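// note (editorial sketch, not part of the original source): the shuffle/permute sequence
// above computes a plain 8x8 float transpose; a scalar reference (strides in elements):
#if 0
inline void transpose8x8_ref(float* dst, ck::index_t stride_dst, const float* src, ck::index_t stride_src)
{
    for(ck::index_t i = 0; i < 8; i++)
        for(ck::index_t j = 0; j < 8; j++)
            dst[j * stride_dst + i] = src[i * stride_src + j]; // dst row j gathers column j of src
}
#endif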
} // namespace avx2_util
using ConvolutionForwardSpecialization_t =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t;
using ConvolutionForwardGemmKSpecialization_t =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t;
// assume the input tensor maps to the A matrix of the GEMM
// assume the A block is packed as MC * KC
template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename ElementwiseOperation,
bool BypassTransfer,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
{
static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
using Index = MultiIndex<nDim>;
constexpr ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC(
const SrcDesc& src_desc,
const Index&,
const DstDesc&,
const Index&,
const ElementwiseOperation& element_op)
: element_op_(element_op)
{
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
N = 1;
Hi = 1;
Wi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}]; // gemm_m
C = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}]; // gemm_k
Ho = 1;
Wo = Wi;
Fy = 1;
Fx = 1;
Dy = 1;
Sy = 1;
Dx = 1;
Sx = 1;
Py = 0;
Px = 0;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
N = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
Hi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
Wi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
C = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<3>{}];
Ho = src_desc.GetTransforms()[Number<2>{}].GetUpperLengths()[Number<0>{}];
Wo = src_desc.GetTransforms()[Number<3>{}].GetUpperLengths()[Number<0>{}];
Fy = 1;
Fx = 1;
Dy = 1;
Sy = src_desc.GetTransforms()[Number<2>{}].coefficients_[Number<0>{}];
Dx = 1;
Sx = src_desc.GetTransforms()[Number<3>{}].coefficients_[Number<0>{}];
Py = 0;
Px = 0;
}
else
{
N = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
Hi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
Wi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
C = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<3>{}];
Ho = src_desc.GetTransforms()[Number<9>{}].low_lengths_[Number<1>{}];
Wo = src_desc.GetTransforms()[Number<9>{}].low_lengths_[Number<2>{}];
Fy = src_desc.GetTransforms()[Number<10>{}].low_lengths_[Number<0>{}];
Fx = src_desc.GetTransforms()[Number<10>{}].low_lengths_[Number<1>{}];
Dy = src_desc.GetTransforms()[Number<6>{}].coefficients_[Number<0>{}];
Sy = src_desc.GetTransforms()[Number<6>{}].coefficients_[Number<1>{}];
Dx = src_desc.GetTransforms()[Number<7>{}].coefficients_[Number<0>{}];
Sx = src_desc.GetTransforms()[Number<7>{}].coefficients_[Number<1>{}];
Py = src_desc.GetTransforms()[Number<2>{}].left_pad_length_;
Px = src_desc.GetTransforms()[Number<3>{}].left_pad_length_;
}
// ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
// iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
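// offset increments, with the NHWC input addressed as n*Hi*Wi*C + ihi*Wi*C + iwi*C + ic:
//   stepping iwo by 1 moves iwi by Sx                       -> +Sx*C
//   wrapping iwo -> iho moves ihi by Sy and rewinds Wo cols  -> +Sy*Wi*C - Wo*Sx*C
//   wrapping iho -> n rewinds Ho output rows                 -> +Hi*Wi*C - Ho*Sy*Wi*C
//   the c -> x and x -> y overflow terms follow the same pattern using Dx / Dy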
input_offset_acc_wi = Sx * C;
input_offset_ovf_wi_acc_hi = Sy * Wi * C - Wo * Sx * C;
input_offset_ovf_hi_acc_n = Hi * Wi * C - Ho * Sy * Wi * C;
// input_offset_acc_c = 1;
input_offset_ovf_c_acc_x = Dx * C - C;
input_offset_ovf_x_acc_y = Dy * Wi * C - Fx * Dx * C;
src_offset = -Py * Wi * C - Px * C;
i_n = 0;
i_c = 0;
i_hi = -Py;
i_wi = -Px;
i_ho = 0;
i_wo = 0;
i_y = 0;
i_x = 0;
i_gemm_k = 0;
#if 0
printf("N:%d, Hi:%d, Wi:%d, C:%d, Ho:%d, Wo:%d, Fy:%d, Fx:%d, Dy:%d, Sy:%d, Dx:%d, Sx:%d, "
"Py:%d, Px:%d\n",
N,
Hi,
Wi,
C,
Ho,
Wo,
Fy,
Fx,
Dy,
Sy,
Dx,
Sx,
Py,
Px);
#endif
}
void SetSrcSliceOrigin(const SrcDesc&, const Index& src_slice_origin_idx)
{
ck::index_t idx_m = src_slice_origin_idx[Number<0>{}];
ck::index_t idx_k = src_slice_origin_idx[Number<1>{}];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
i_wi = idx_m;
i_c = idx_k;
src_offset = i_wi * C + i_c;
// printf("src_offset:%d, i_wi:%d, i_c:%d\n", src_offset, i_wi, i_c);
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
i_wo = idx_m % Wo;
i_ho = (idx_m / Wo) % Ho;
i_n = (idx_m / Wo) / Ho;
// ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
// iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
i_c = idx_k;
i_x = 0;
i_y = 0;
i_hi = i_ho * Sy;
i_wi = i_wo * Sx;
src_offset = i_n * Hi * Wi * C + i_hi * Wi * C + i_wi * C + i_c;
i_gemm_k = idx_k;
}
else
{
i_wo = idx_m % Wo;
i_ho = (idx_m / Wo) % Ho;
i_n = (idx_m / Wo) / Ho;
// ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
// iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
if(idx_k == 0)
{
i_c = 0;
i_x = 0;
i_y = 0;
i_hi = i_ho * Sy - Py;
i_wi = i_wo * Sx - Px;
}
else
{
i_c = idx_k % C;
i_x = (idx_k / C) % Fx;
i_y = (idx_k / C) / Fx;
i_hi = i_ho * Sy + i_y * Dy - Py;
i_wi = i_wo * Sx + i_x * Dx - Px;
}
src_offset = i_n * Hi * Wi * C + i_hi * Wi * C + i_wi * C + i_c;
i_gemm_k = idx_k;
// printf("[%d] i_wo:%d, i_ho:%d, i_wi:%d, i_hi:%d, src_offset:%d\n",
// __LINE__, i_wo, i_ho, i_wi, i_hi, src_offset);
}
}
void SetDstSliceOrigin(const DstDesc&, const Index&) {}
template <typename SrcBuffer, typename DstBuffer>
void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
if constexpr(BypassTransfer)
{
float* p_src = reinterpret_cast<float*>(src_buf.p_data_) + src_offset;
dst_buf.p_data_ = p_src;
}
else
{
const ck::index_t m_per_block =
dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
const ck::index_t k_per_block =
dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
const float* p_src = reinterpret_cast<const float*>(src_buf.p_data_) + src_offset;
float* p_dst = reinterpret_cast<float*>(dst_buf.p_data_);
// printf("src offset:%d, k_per_block:%d, m_per_block:%d\n", src_offset, k_per_block,
// m_per_block);
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
ck::index_t i_m_itr = m_per_block;
// standard 8-4-2-1 pattern
while(i_m_itr >= 8)
{
avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 2 * k_per_block, p_src + 2 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 3 * k_per_block, p_src + 3 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 4 * k_per_block, p_src + 4 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 5 * k_per_block, p_src + 5 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 6 * k_per_block, p_src + 6 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 7 * k_per_block, p_src + 7 * C, k_per_block);
i_m_itr -= 8;
p_dst += 8 * k_per_block;
p_src += 8 * C;
}
if(i_m_itr & 4)
{
avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 2 * k_per_block, p_src + 2 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 3 * k_per_block, p_src + 3 * C, k_per_block);
p_dst += 4 * k_per_block;
p_src += 4 * C;
}
if(i_m_itr & 2)
{
avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block);
p_dst += 2 * k_per_block;
p_src += 2 * C;
}
if(i_m_itr & 1)
{
avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
}
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
ck::index_t i_m_itr = m_per_block;
ck::index_t i_wo_itr = i_wo;
ck::index_t i_ho_itr = i_ho;
while(i_m_itr > 0)
{
avx2_util::memcpy32_avx2(p_dst, p_src, k_per_block);
p_dst += k_per_block;
i_wo_itr++;
p_src += input_offset_acc_wi;
if(i_wo_itr >= Wo)
{
i_wo_itr = 0;
i_ho_itr++;
p_src += input_offset_ovf_wi_acc_hi;
}
if(i_ho_itr >= Ho)
{
i_ho_itr = 0;
// i_n++;
p_src += input_offset_ovf_hi_acc_n;
}
i_m_itr--;
}
}
else
{
// ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
// iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
if constexpr(GemmKSpecialization ==
ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
{
// C % k_per_block == 0, so k_per_block is the same on every iteration
ck::index_t i_m_itr = m_per_block;
ck::index_t i_wo_itr = i_wo;
ck::index_t i_ho_itr = i_ho;
ck::index_t i_wi_itr = i_wi;
ck::index_t i_hi_itr = i_hi;
// printf("[%d] i_m_itr:%d, i_wo_itr:%d, i_ho_itr:%d, i_wi_itr:%d, i_hi_itr:%d,
// src_offset:%d, input_offset_acc_wi:%d,
// input_offset_ovf_wi_acc_hi:%d,input_offset_ovf_hi_acc_n:%d, %p(%p)\n",
// __LINE__, i_m_itr, i_wo_itr, i_ho_itr, i_wi_itr, i_hi_itr,
// src_offset, input_offset_acc_wi, input_offset_ovf_wi_acc_hi,
// input_offset_ovf_hi_acc_n, src_buf.p_data_, p_src);
// printf("%p %p %p, %d, %x, %p\n",src_buf.p_data_, reinterpret_cast<const
// float*>(src_buf.p_data_) + 1, reinterpret_cast<const float*>(src_buf.p_data_)
// + ck::index_t(-1),
// sizeof(src_offset), *reinterpret_cast<uint32_t*>(&src_offset),
// reinterpret_cast<const float*>(src_buf.p_data_) + (-1088));
while(i_m_itr > 0)
{
// printf("[%d] i_m_itr:%d, i_wo_itr:%d, i_ho_itr:%d, i_wi_itr:%d,
// i_hi_itr:%d, src_offset:%d -> %p\n",
// __LINE__, i_m_itr, i_wo_itr, i_ho_itr, i_wi_itr, i_hi_itr, src_offset,
// p_src);
if((*reinterpret_cast<uint32_t*>(&i_hi_itr) < Hi) &&
(*reinterpret_cast<uint32_t*>(&i_wi_itr) < Wi))
avx2_util::memcpy32_avx2(p_dst, p_src, k_per_block);
else
avx2_util::memset32_avx2(p_dst, 0, k_per_block);
p_dst += k_per_block;
i_wo_itr++;
i_wi_itr += Sx;
p_src += input_offset_acc_wi;
if(i_wo_itr >= Wo)
{
i_wo_itr = 0;
i_wi_itr -= Wo * Sx;
i_ho_itr++;
i_hi_itr += Sy;
p_src += input_offset_ovf_wi_acc_hi;
}
if(i_ho_itr >= Ho)
{
i_ho_itr = 0;
i_hi_itr -= Ho * Sy;
// i_n++;
p_src += input_offset_ovf_hi_acc_n;
}
i_m_itr--;
}
// printf("[%d] \n", __LINE__);
}
else
{
ck::index_t i_m_itr = m_per_block;
ck::index_t i_wo_itr = i_wo;
ck::index_t i_ho_itr = i_ho;
ck::index_t i_wi_itr = i_wi;
ck::index_t i_hi_itr = i_hi;
// ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
// iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
while(i_m_itr > 0)
{
/*** go along Gemm K ***/
const float* p_src_k = p_src;
float* p_dst_k = p_dst;
ck::index_t i_wi_itr_k = i_wi_itr;
ck::index_t i_hi_itr_k = i_hi_itr;
ck::index_t i_c_itr_k = i_c;
ck::index_t i_y_itr_k = i_y;
ck::index_t i_x_itr_k = i_x;
ck::index_t i_k_itr = k_per_block;
while(i_k_itr > 0)
{
ck::index_t current_k_block = ck::math::min(C - i_c_itr_k, i_k_itr); // cap by the remaining k in this block as well as by the channel boundary
if((*reinterpret_cast<uint32_t*>(&i_hi_itr_k) < Hi) &&
(*reinterpret_cast<uint32_t*>(&i_wi_itr_k) < Wi))
avx2_util::memcpy32_avx2(p_dst_k, p_src_k, current_k_block);
else
avx2_util::memset32_avx2(p_dst_k, 0, current_k_block);
p_dst_k += current_k_block;
p_src_k += current_k_block;
i_c_itr_k += current_k_block;
if(i_c_itr_k >= C)
{
i_c_itr_k = 0;
i_x_itr_k++;
i_wi_itr_k += Dx;
p_src_k += input_offset_ovf_c_acc_x;
}
if(i_x_itr_k >= Fx)
{
i_x_itr_k = 0;
i_y_itr_k++;
i_hi_itr_k += Dy;
p_src_k += input_offset_ovf_x_acc_y;
}
i_k_itr -= current_k_block;
}
/*** go along Gemm K ***/
p_dst += k_per_block;
i_wo_itr++;
i_wi_itr += Sx;
p_src += input_offset_acc_wi;
if(i_wo_itr >= Wo)
{
i_wo_itr = 0;
i_wi_itr -= Wo * Sx;
i_ho_itr++;
i_hi_itr += Sy;
p_src += input_offset_ovf_wi_acc_hi;
}
if(i_ho_itr >= Ho)
{
i_ho_itr = 0;
i_hi_itr -= Ho * Sy;
// i_n++;
p_src += input_offset_ovf_hi_acc_n;
}
i_m_itr--;
}
}
}
}
}
void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& src_slice_origin_step_idx)
{
ck::index_t move_k = src_slice_origin_step_idx[Number<1>{}];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
// printf(" => move_k:%d, src offset:%d\n", move_k, src_offset);
i_c += move_k;
src_offset += move_k;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
i_c += move_k;
src_offset += move_k;
}
else
{
if constexpr(GemmKSpecialization ==
ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
{
// C % k_per_block == 0, so k_per_block is the same on every iteration
// ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
// iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
// printf("222222 C:%d, src_offset:%d, i_c:%d, i_x:%d\n", C, src_offset, i_c, i_x);
// fflush(stdout);
// TODO: branch seems weird
i_c += move_k;
src_offset += move_k;
// printf("3333[%d] src_offset:%d\n", __LINE__, src_offset);
if(i_c >= C)
{
i_c = 0;
i_x++;
i_wi += Dx;
src_offset += Dx * C - C;
// printf("3333[%d] src_offset:%d\n", __LINE__, src_offset);
}
if(i_x >= Fx)
{
i_x = 0;
i_y++;
i_wi = i_wi - Fx * Dx;
i_hi += Dy;
src_offset += Dy * Wi * C - Fx * Dx * C;
// printf("3333[%d] src_offset:%d\n", __LINE__, src_offset);
}
// printf("inp move:%d, i_c:%d, i_hi:%d, i_wi:%d src_offset:%d\n", move_k, i_c,
// i_hi, i_wi, src_offset); fflush(stdout);
}
else
{
i_gemm_k += move_k;
i_c = i_gemm_k % C;
i_x = (i_gemm_k / C) % Fx;
i_y = (i_gemm_k / C) / Fx;
i_hi = i_ho * Sy + i_y * Dy - Py;
i_wi = i_wo * Sx + i_x * Dx - Px;
src_offset = i_n * Hi * Wi * C + i_hi * Wi * C + i_wi * C + i_c;
}
}
}
void MoveDstSliceWindow(const DstDesc&, const Index&) {}
private:
const ElementwiseOperation element_op_;
ck::index_t i_n;
ck::index_t i_c;
ck::index_t i_hi;
ck::index_t i_wi;
ck::index_t i_ho;
ck::index_t i_wo;
ck::index_t i_y;
ck::index_t i_x;
ck::index_t i_gemm_k;
ck::index_t N;
// ck::index_t K;
ck::index_t C;
ck::index_t Hi;
ck::index_t Wi;
ck::index_t Ho;
ck::index_t Wo;
ck::index_t Sy;
ck::index_t Sx;
ck::index_t Dy;
ck::index_t Dx;
ck::index_t Py;
ck::index_t Px;
ck::index_t Fy;
ck::index_t Fx;
intptr_t input_offset_acc_wi;
intptr_t input_offset_ovf_wi_acc_hi;
intptr_t input_offset_ovf_hi_acc_n;
// intptr_t input_offset_acc_c;
intptr_t input_offset_ovf_c_acc_x;
intptr_t input_offset_ovf_x_acc_y;
intptr_t src_offset; // keep this as a pointer-sized signed type since the offset can be negative
};
template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename ElementwiseOperation,
bool BypassTransfer,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC
{
static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
using Index = MultiIndex<nDim>;
// using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
// using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
constexpr ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC(
const SrcDesc& src_desc,
const Index& src_slice_origin,
const DstDesc& dst_desc,
const Index& dst_slice_origin,
const ElementwiseOperation& element_op)
: element_op_(element_op)
{
GemmN1 = src_desc.GetTransforms()[Number<3>{}].GetUpperLengths()[Number<1>{}];
GemmN = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
GemmK = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
void SetSrcSliceOrigin(const SrcDesc&, const Index& src_slice_origin_idx)
{
ck::index_t idx_n0 = src_slice_origin_idx[Number<0>{}];
ck::index_t idx_k = src_slice_origin_idx[Number<1>{}];
ck::index_t idx_n1 = src_slice_origin_idx[Number<2>{}];
i_gemm_n = idx_n0 * GemmN1 + idx_n1;
// i_gemm_k = idx_k;
src_offset = idx_n0 * GemmK * GemmN1 + idx_k + idx_n1 * GemmN1; // Note we transpose here
// printf("xxxx i_gemm_n:%d, i_gemm_k:%d, src_offset:%d\n", i_gemm_n, i_gemm_k,
// src_offset);
}
void SetDstSliceOrigin(const DstDesc&, const Index&) {}
template <typename SrcBuffer, typename DstBuffer>
void Run(const SrcDesc&, const SrcBuffer& src_buf, const DstDesc& dst_desc, DstBuffer& dst_buf)
{
if constexpr(BypassTransfer)
{
// TODO: weight NHWC not support this
}
else
{
const ck::index_t n_per_block =
dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}] *
dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
const ck::index_t k_per_block =
dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
// printf(" >>>> %d, %d, %d -> %d(%dx%d), %d\n", GemmN, GemmK, GemmN1, n_per_block,
// dst_desc.GetTransforms()[Number<0>{}]
// .GetUpperLengths()[Number<0>{}],
// dst_desc.GetTransforms()[Number<0>{}]
// .GetUpperLengths()[Number<2>{}],
// k_per_block);
const float* p_src = reinterpret_cast<const float*>(src_buf.p_data_) + src_offset;
float* p_dst = reinterpret_cast<float*>(dst_buf.p_data_);
// n * k -> n0 * k * n1, n1 = 8, n0 = n/8
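// i.e. relative to the block origin, src[n * GemmK + k] lands at
// dst[(n / 8) * (k_per_block * 8) + k * 8 + (n % 8)]: each 8-wide N group is packed as
// k_per_block rows of 8 contiguous N lanes (N innermost)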
for(index_t i_n_itr = 0; i_n_itr < n_per_block; i_n_itr += 8)
{
ck::index_t current_n_8 = ck::math::min(GemmN - (i_n_itr + i_gemm_n), 8);
ck::index_t i_k_itr = k_per_block;
if(current_n_8 == 8)
{
const float* p_src_k = p_src;
float* p_dst_k = p_dst;
while(i_k_itr >= 8)
{
avx2_util::transpose8x8_avx2(p_dst_k, 8, p_src_k, GemmK);
p_dst_k += 8 * 8;
p_src_k += 8;
i_k_itr -= 8;
}
if(i_k_itr & 4)
{
p_dst_k[0 * 8 + 0] = p_src_k[0 * GemmK + 0];
p_dst_k[0 * 8 + 1] = p_src_k[1 * GemmK + 0];
p_dst_k[0 * 8 + 2] = p_src_k[2 * GemmK + 0];
p_dst_k[0 * 8 + 3] = p_src_k[3 * GemmK + 0];
p_dst_k[0 * 8 + 4] = p_src_k[4 * GemmK + 0];
p_dst_k[0 * 8 + 5] = p_src_k[5 * GemmK + 0];
p_dst_k[0 * 8 + 6] = p_src_k[6 * GemmK + 0];
p_dst_k[0 * 8 + 7] = p_src_k[7 * GemmK + 0];
p_dst_k[1 * 8 + 0] = p_src_k[0 * GemmK + 1];
p_dst_k[1 * 8 + 1] = p_src_k[1 * GemmK + 1];
p_dst_k[1 * 8 + 2] = p_src_k[2 * GemmK + 1];
p_dst_k[1 * 8 + 3] = p_src_k[3 * GemmK + 1];
p_dst_k[1 * 8 + 4] = p_src_k[4 * GemmK + 1];
p_dst_k[1 * 8 + 5] = p_src_k[5 * GemmK + 1];
p_dst_k[1 * 8 + 6] = p_src_k[6 * GemmK + 1];
p_dst_k[1 * 8 + 7] = p_src_k[7 * GemmK + 1];
p_dst_k[2 * 8 + 0] = p_src_k[0 * GemmK + 2];
p_dst_k[2 * 8 + 1] = p_src_k[1 * GemmK + 2];
p_dst_k[2 * 8 + 2] = p_src_k[2 * GemmK + 2];
p_dst_k[2 * 8 + 3] = p_src_k[3 * GemmK + 2];
p_dst_k[2 * 8 + 4] = p_src_k[4 * GemmK + 2];
p_dst_k[2 * 8 + 5] = p_src_k[5 * GemmK + 2];
p_dst_k[2 * 8 + 6] = p_src_k[6 * GemmK + 2];
p_dst_k[2 * 8 + 7] = p_src_k[7 * GemmK + 2];
p_dst_k[3 * 8 + 0] = p_src_k[0 * GemmK + 3];
p_dst_k[3 * 8 + 1] = p_src_k[1 * GemmK + 3];
p_dst_k[3 * 8 + 2] = p_src_k[2 * GemmK + 3];
p_dst_k[3 * 8 + 3] = p_src_k[3 * GemmK + 3];
p_dst_k[3 * 8 + 4] = p_src_k[4 * GemmK + 3];
p_dst_k[3 * 8 + 5] = p_src_k[5 * GemmK + 3];
p_dst_k[3 * 8 + 6] = p_src_k[6 * GemmK + 3];
p_dst_k[3 * 8 + 7] = p_src_k[7 * GemmK + 3];
p_dst_k += 4 * 8;
p_src_k += 4;
}
if(i_k_itr & 2)
{
p_dst_k[0 * 8 + 0] = p_src_k[0 * GemmK + 0];
p_dst_k[0 * 8 + 1] = p_src_k[1 * GemmK + 0];
p_dst_k[0 * 8 + 2] = p_src_k[2 * GemmK + 0];
p_dst_k[0 * 8 + 3] = p_src_k[3 * GemmK + 0];
p_dst_k[0 * 8 + 4] = p_src_k[4 * GemmK + 0];
p_dst_k[0 * 8 + 5] = p_src_k[5 * GemmK + 0];
p_dst_k[0 * 8 + 6] = p_src_k[6 * GemmK + 0];
p_dst_k[0 * 8 + 7] = p_src_k[7 * GemmK + 0];
p_dst_k[1 * 8 + 0] = p_src_k[0 * GemmK + 1];
p_dst_k[1 * 8 + 1] = p_src_k[1 * GemmK + 1];
p_dst_k[1 * 8 + 2] = p_src_k[2 * GemmK + 1];
p_dst_k[1 * 8 + 3] = p_src_k[3 * GemmK + 1];
p_dst_k[1 * 8 + 4] = p_src_k[4 * GemmK + 1];
p_dst_k[1 * 8 + 5] = p_src_k[5 * GemmK + 1];
p_dst_k[1 * 8 + 6] = p_src_k[6 * GemmK + 1];
p_dst_k[1 * 8 + 7] = p_src_k[7 * GemmK + 1];
p_dst_k += 2 * 8;
p_src_k += 2;
}
if(i_k_itr & 1)
{
p_dst_k[0 * 8 + 0] = p_src_k[0 * GemmK + 0];
p_dst_k[0 * 8 + 1] = p_src_k[1 * GemmK + 0];
p_dst_k[0 * 8 + 2] = p_src_k[2 * GemmK + 0];
p_dst_k[0 * 8 + 3] = p_src_k[3 * GemmK + 0];
p_dst_k[0 * 8 + 4] = p_src_k[4 * GemmK + 0];
p_dst_k[0 * 8 + 5] = p_src_k[5 * GemmK + 0];
p_dst_k[0 * 8 + 6] = p_src_k[6 * GemmK + 0];
p_dst_k[0 * 8 + 7] = p_src_k[7 * GemmK + 0];
}
}
else
{
const float* p_src_k = p_src;
float* p_dst_k = p_dst;
for(index_t i_sub_n = 0; i_sub_n < 8; i_sub_n++)
{
for(index_t i_sub_k = 0; i_sub_k < k_per_block; i_sub_k++)
{
ck::index_t i_current_n_itr = i_n_itr + i_sub_n + i_gemm_n;
float v =
i_current_n_itr < GemmN ? p_src_k[i_sub_n * GemmK + i_sub_k] : .0f;
p_dst_k[i_sub_k * 8 + i_sub_n] = v;
}
}
}
p_dst += 8 * k_per_block;
p_src += 8 * GemmK;
}
}
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& src_slice_origin_step_idx)
{
ck::index_t move_k = src_slice_origin_step_idx[Number<1>{}];
ck::index_t move_n0 = src_slice_origin_step_idx[Number<0>{}];
// i_gemm_k += move_k;
// printf("wei move:%d\n", move_k); fflush(stdout);
src_offset += move_k + move_n0 * GemmK * GemmN1;
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
void MoveDstSliceWindow(const DstDesc&, const Index&) {}
private:
const ElementwiseOperation element_op_;
ck::index_t i_gemm_n;
// ck::index_t i_gemm_k;
// ck::index_t GemmN0;
ck::index_t GemmN1;
ck::index_t GemmN;
ck::index_t GemmK;
intptr_t src_offset;
};
template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename ElementwiseOperation,
bool BypassTransfer,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
{
static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
using Index = MultiIndex<nDim>;
constexpr ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN(
const SrcDesc& src_desc,
const Index&,
const DstDesc& dst_desc,
const Index&,
const ElementwiseOperation& element_op)
: element_op_(element_op)
{
DstGemmM = dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
DstGemmN = dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
src_offset = 0;
dst_offset = 0;
}
void SetSrcSliceOrigin(const SrcDesc&, const Index& src_slice_origin_idx)
{
if constexpr(BypassTransfer)
{
auto i_src_gemm_m = src_slice_origin_idx[Number<0>{}];
auto i_src_gemm_n = src_slice_origin_idx[Number<1>{}];
src_offset = i_src_gemm_m * DstGemmN + i_src_gemm_n;
}
}
void SetDstSliceOrigin(const DstDesc&, const Index& dst_slice_origin_idx)
{
i_dst_gemm_m = dst_slice_origin_idx[Number<0>{}];
i_dst_gemm_n = dst_slice_origin_idx[Number<1>{}];
dst_offset = i_dst_gemm_m * DstGemmN + i_dst_gemm_n;
}
template <typename SrcBuffer, typename DstBuffer>
void
Run(const SrcDesc& src_desc, SrcBuffer& src_buf, const DstDesc& dst_desc, DstBuffer& dst_buf)
{
if constexpr(BypassTransfer)
{
src_buf.p_data_ = reinterpret_cast<float*>(dst_buf.p_data_) + src_offset;
}
else
{
const ck::index_t m_per_block =
src_desc.GetTransforms()[Number<0>{}]
.GetUpperLengths()[Number<0>{}]; // must be multiple of 8
const ck::index_t n_per_block =
src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
const ck::index_t current_n = ck::math::min(DstGemmN - i_dst_gemm_n, n_per_block);
const float* p_src = reinterpret_cast<float*>(src_buf.p_data_) + src_offset;
float* p_dst = reinterpret_cast<float*>(dst_buf.p_data_) + dst_offset;
ck::index_t i_m_itr = m_per_block;
// printf("xxxx %d, current_n:%d, DstGemmN:%d, n_per_block:%d\n",__LINE__, current_n,
// DstGemmN, n_per_block);fflush(stdout);
// standard 8-4-2-1 pattern
while(i_m_itr >= 8)
{
avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 2 * DstGemmN, p_src + 2 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 3 * DstGemmN, p_src + 3 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 4 * DstGemmN, p_src + 4 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 5 * DstGemmN, p_src + 5 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 6 * DstGemmN, p_src + 6 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 7 * DstGemmN, p_src + 7 * n_per_block, current_n);
i_m_itr -= 8;
p_dst += 8 * DstGemmN;
p_src += 8 * n_per_block;
}
if(i_m_itr & 4)
{
avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 2 * DstGemmN, p_src + 2 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 3 * DstGemmN, p_src + 3 * n_per_block, current_n);
p_dst += 4 * DstGemmN;
p_src += 4 * n_per_block;
}
if(i_m_itr & 2)
{
avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n);
p_dst += 2 * DstGemmN;
p_src += 2 * n_per_block;
}
if(i_m_itr & 1)
{
avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
}
// printf("xxxx %d\n",__LINE__);fflush(stdout);
}
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
void MoveSrcSliceWindow(const SrcDesc&, const Index&) {}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
void MoveDstSliceWindow(const DstDesc&, const Index&) {}
private:
const ElementwiseOperation element_op_;
ck::index_t i_dst_gemm_m;
ck::index_t i_dst_gemm_n;
ck::index_t DstGemmM;
ck::index_t DstGemmN;
intptr_t src_offset;
intptr_t dst_offset;
};
} // namespace cpu
} // namespace ck
#endif
...@@ -121,7 +121,11 @@ template <typename... Args, typename F> ...@@ -121,7 +121,11 @@ template <typename... Args, typename F>
float launch_and_time_cpu_kernel(F kernel, int nrepeat, Args... args) float launch_and_time_cpu_kernel(F kernel, int nrepeat, Args... args)
{ {
WallTimer timer; WallTimer timer;
kernel(args...);
int nwarmup = 3;
for(int i = 0; i < nwarmup; i++)
kernel(args...);
timer.Start(); timer.Start();
for(int i = 0; i < nrepeat; i++) for(int i = 0; i < nrepeat; i++)
......
...@@ -19,7 +19,7 @@ using InLayout = ck::tensor_layout::gemm::RowMajor; / ...@@ -19,7 +19,7 @@ using InLayout = ck::tensor_layout::gemm::RowMajor; /
using WeiLayout = ck::tensor_layout::gemm::ColumnMajor; // KYXC using WeiLayout = ck::tensor_layout::gemm::ColumnMajor; // KYXC
static constexpr bool NonTemporalStore = false; static constexpr bool NonTemporalStore = false;
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough; using PT = ck::tensor_operation::cpu::element_wise::PassThrough;
using ThreadwiseGemmAvx2_MxN_4x24_Dispatch = using ThreadwiseGemmAvx2_MxN_4x24_Dispatch =
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<InType, ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<InType,
WeiType, WeiType,
...@@ -37,53 +37,37 @@ static constexpr auto ConvFwd1x1P0 = ...@@ -37,53 +37,37 @@ static constexpr auto ConvFwd1x1P0 =
static constexpr auto ConvFwd1x1S1P0 = static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0; ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;
static constexpr auto DefaultGemmKLoop =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop;
static constexpr auto GemmKLoopOverC =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC;
static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK;
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
// clang-format off
#define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf) \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>
// clang-format on
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances = std::tuple< using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances = std::tuple<
//#################################################################|InDataType|WeiDataType|OutDataType|AccDataType|InElementwiseOp|WeiElementwiseOp|OutElementwiseOp|ConvForwardSp|NumDimSpatial|MPerBlock|NPerBlock|KPerBlock|ThreadwiseGemm_Dispatch // clang-format off
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 120, 64, 4, 24, true, true, false),
float, DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 144, 128, 4, 24, true, true, false),
float, DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true, true, false),
float, // DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 192, 128, 4, 24, true, true, false),
float, DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 288, 128, 4, 24, true, true, false)>;
PassThrough, // clang-format on
PassThrough,
PassThrough,
ConvFwdDefault,
2,
256,
128,
64,
ThreadwiseGemmAvx2_MxN_4x24_Dispatch>,
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
float,
float,
float,
float,
PassThrough,
PassThrough,
PassThrough,
ConvFwdDefault,
2,
512,
256,
128,
ThreadwiseGemmAvx2_MxN_4x24_Dispatch>,
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
float,
float,
float,
float,
PassThrough,
PassThrough,
PassThrough,
ConvFwdDefault,
2,
1024,
144,
128,
ThreadwiseGemmAvx2_MxN_4x24_Dispatch>>;
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk( void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances{}); instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances{});
......
#include "config.hpp" #include "config.hpp"
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "host_tensor_generator.hpp" #include "host_tensor_generator.hpp"
#include "tensor_layout.hpp" #include "tensor_layout.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp" #include "device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation_cpu.hpp" #include "element_wise_operation_cpu.hpp"
#include "reference_conv_fwd.hpp" #include "reference_conv_fwd.hpp"
#include "element_wise_operation_cpu.hpp" #include "element_wise_operation_cpu.hpp"
#include "dynamic_buffer_cpu.hpp" #include "dynamic_buffer_cpu.hpp"
#include <omp.h>
#define AVX2_DATA_ALIGNMENT 32
#define AVX2_DATA_ALIGNMENT 32
namespace ck { using F32 = float;
namespace tensor_operation { using F16 = ck::half_t;
namespace cpu {
namespace device { namespace ck {
namespace device_conv2d_fwd_avx2_instance { namespace tensor_operation {
namespace cpu {
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk( namespace device {
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances); namespace device_conv2d_fwd_avx2_instance {
} // namespace device_conv2d_fwd_avx2_instance using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
} // namespace device
} // namespace cpu void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(
} // namespace tensor_operation std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
} // namespace ck
} // namespace device_conv2d_fwd_avx2_instance
using InElementOp = ck::tensor_operation::cpu::element_wise::PassThrough; } // namespace device
using WeiElementOp = ck::tensor_operation::cpu::element_wise::PassThrough; } // namespace cpu
using OutElementOp = ck::tensor_operation::cpu::element_wise::PassThrough; } // namespace tensor_operation
} // namespace ck
template <typename T>
static bool check_out(const Tensor<T>& ref, const Tensor<T>& result) using InElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
{ using WeiElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
float max_diff = 1e-6; using OutElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
for(int i = 0; i < ref.mData.size(); ++i) template <typename T>
{ static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
float diff = std::abs(double(ref.mData[i]) - double(result.mData[i])); {
if(max_diff < diff) int error_count = 0;
{ float max_diff = 1e-6;
return false;
} for(int i = 0; i < ref.mData.size(); ++i)
} {
float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
return true; if(max_diff < diff)
} {
error_count++;
int main(int argc, char* argv[]) printf("idx:%3d, ref:%f, res:%f (diff:%f)\n",
{ i,
int data_type = 0; double(ref.mData[i]),
int init_method = 0; double(result.mData[i]),
diff);
// Conv shape }
ck::index_t N = 128; }
ck::index_t K = 256;
ck::index_t C = 192; return error_count == 0;
ck::index_t Y = 3; }
ck::index_t X = 3;
ck::index_t Hi = 71; float calculate_gflops() {}
ck::index_t Wi = 71;
ck::index_t conv_stride_h = 2; int main(int argc, char* argv[])
ck::index_t conv_stride_w = 2; {
ck::index_t conv_dilation_h = 1; int data_type = 0;
ck::index_t conv_dilation_w = 1; int init_method = 0;
ck::index_t in_left_pad_h = 1;
ck::index_t in_left_pad_w = 1; // Conv shape
ck::index_t in_right_pad_h = 1; ck::index_t N = 2;
ck::index_t in_right_pad_w = 1; ck::index_t K = 256;
ck::index_t C = 192;
if(argc == 1) ck::index_t Y = 3;
{ ck::index_t X = 3;
data_type = 1; ck::index_t Hi = 71;
init_method = 1; ck::index_t Wi = 71;
} ck::index_t conv_stride_h = 1;
else if(argc == 3) ck::index_t conv_stride_w = 1;
{ ck::index_t conv_dilation_h = 1;
data_type = std::stoi(argv[1]); ck::index_t conv_dilation_w = 1;
init_method = std::stoi(argv[2]); ck::index_t in_left_pad_h = 1;
} ck::index_t in_left_pad_w = 1;
else if(argc == 18) ck::index_t in_right_pad_h = 1;
{ ck::index_t in_right_pad_w = 1;
data_type = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); if(argc == 1)
{
N = std::stoi(argv[3]); data_type = 0;
K = std::stoi(argv[4]); init_method = 1;
C = std::stoi(argv[5]); }
Y = std::stoi(argv[6]); else if(argc == 3)
X = std::stoi(argv[7]); {
Hi = std::stoi(argv[8]); data_type = std::stoi(argv[1]);
Wi = std::stoi(argv[9]); init_method = std::stoi(argv[2]);
conv_stride_h = std::stoi(argv[10]); }
conv_stride_w = std::stoi(argv[11]); else if(argc == 18)
conv_dilation_h = std::stoi(argv[12]); {
conv_dilation_w = std::stoi(argv[13]); data_type = std::stoi(argv[1]);
in_left_pad_h = std::stoi(argv[14]); init_method = std::stoi(argv[2]);
in_left_pad_w = std::stoi(argv[15]);
in_right_pad_h = std::stoi(argv[16]); N = std::stoi(argv[3]);
in_right_pad_w = std::stoi(argv[17]); K = std::stoi(argv[4]);
} C = std::stoi(argv[5]);
else Y = std::stoi(argv[6]);
{ X = std::stoi(argv[7]);
printf("arg1: data type (0=fp32, 1=fp16)\n"); Hi = std::stoi(argv[8]);
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); Wi = std::stoi(argv[9]);
printf("arg3 to 17: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " conv_stride_h = std::stoi(argv[10]);
"RightPx\n"); conv_stride_w = std::stoi(argv[11]);
exit(1); conv_dilation_h = std::stoi(argv[12]);
} conv_dilation_w = std::stoi(argv[13]);
in_left_pad_h = std::stoi(argv[14]);
auto Run = [&](auto input_type, auto wei_type, auto out_type, auto acc_type) { in_left_pad_w = std::stoi(argv[15]);
using InDataType = decltype(input_type); in_right_pad_h = std::stoi(argv[16]);
using WeiDataType = decltype(wei_type); in_right_pad_w = std::stoi(argv[17]);
using OutDataType = decltype(out_type); }
using AccDataType = decltype(acc_type); else
{
using ReferenceConvBwdInstance = printf("arg1: data type (0=fp32, 1=fp16)\n");
ck::tensor_operation::host::ReferenceConvBwdData<InDataType, printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
WeiDataType, printf("arg3 to 17: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
OutDataType, "RightPx\n");
AccDataType, exit(1);
InElementOp, }
    auto Run = [&](auto input_type, auto wei_type, auto out_type) {
        using InDataType  = decltype(input_type);
        using WeiDataType = decltype(wei_type);
        using OutDataType = decltype(out_type);

        using ReferenceConvFwdInstance = ck::tensor_operation::host::ReferenceConvFwd<InDataType,
                                                                                      WeiDataType,
                                                                                      OutDataType,
                                                                                      InElementOp,
                                                                                      WeiElementOp,
                                                                                      OutElementOp>;

        const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1;
        const ck::index_t XEff = (X - 1) * conv_dilation_w + 1;

        const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
        const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
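        // Worked example with the defaults above (Hi = Wi = 71, 3x3 filter, stride 1,
        // dilation 1, pad 1 on each side): YEff = (3 - 1) * 1 + 1 = 3, so
        // Ho = (71 + 1 + 1 - 3) / 1 + 1 = 71, and likewise Wo = 71 ("same" padding).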
        const std::vector<ck::index_t> input_spatial_lengths{{Hi, Wi}};
        const std::vector<ck::index_t> filter_spatial_lengths{{Y, X}};
        const std::vector<ck::index_t> output_spatial_lengths{{Ho, Wo}};
        const std::vector<ck::index_t> conv_filter_strides{{conv_stride_h, conv_stride_w}};
        const std::vector<ck::index_t> conv_filter_dilations{{conv_dilation_h, conv_dilation_w}};
        const std::vector<ck::index_t> input_left_pads{{in_left_pad_h, in_left_pad_w}};
        const std::vector<ck::index_t> input_right_pads{{in_right_pad_h, in_right_pad_w}};

        auto f_host_tensor_descriptor = [](std::size_t N_,
                                           std::size_t C_,
                                           std::size_t H_,
                                           std::size_t W_) {
            return HostTensorDescriptor(std::vector<std::size_t>({N_, C_, H_, W_}),
                                        std::vector<std::size_t>({C_ * H_ * W_, 1, W_ * C_, C_}));
        };
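        // Note (added): the logical dimension order above is (N, C, H, W), but the strides
        // (C*H*W, 1, W*C, C) make C the fastest-varying index, so the host buffers are laid
        // out as NHWC in memory, matching the nhwc_kyxc_nhwk avx2 instance registered below.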
        Tensor<InDataType> in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi));
        Tensor<WeiDataType> wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X));
        Tensor<OutDataType> out_n_k_ho_wo_host_result(f_host_tensor_descriptor(N, K, Ho, Wo));
        Tensor<OutDataType> out_n_k_ho_wo_device_result(f_host_tensor_descriptor(N, K, Ho, Wo));

        std::cout << "in (N, C, Hi, Wi): " << in_n_c_hi_wi.mDesc << std::endl;
        std::cout << "wei(K, C, Y, X): " << wei_k_c_y_x.mDesc << std::endl;
        std::cout << "out(N, K, Ho, Wo): " << out_n_k_ho_wo_host_result.mDesc << std::endl;
        std::cout << "LPad(H, W):" << in_left_pad_h << "," << in_left_pad_w
                  << ", RPad(H, W):" << in_right_pad_h << "," << in_right_pad_w
                  << ", Stride(H, W):" << conv_stride_h << ", " << conv_stride_w
                  << ", Dilation(H, W):" << conv_dilation_h << ", " << conv_dilation_w
                  << ", Threads:" << omp_get_max_threads() << std::endl;

        switch(init_method)
        {
        case 0: break;
        case 1:
            in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
            // in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{});
            wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
            // wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{});
            break;
        case 2:
            in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{});
            wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{});
            break;
        case 3:

#define PACK_32(v24, v16, v8, v0) \
    (((v24 & 0xff) << 24) | ((v16 & 0xff) << 16) | ((v8 & 0xff) << 8) | ((v0 & 0xff) << 0))

            for(auto i_n = 0; i_n < N; i_n++)
            {
                for(auto i_c = 0; i_c < C; i_c++)
                {
                    for(auto i_hi = 0; i_hi < Hi; i_hi++)
                    {
                        for(auto i_wi = 0; i_wi < Wi; i_wi++)
                        {
                            uint32_t v = PACK_32(i_n, i_c, i_hi, i_wi);
                            in_n_c_hi_wi(i_n, i_c, i_hi, i_wi) = *reinterpret_cast<float*>(&v);
                        }
                    }
                }
            }

            for(auto i_k = 0; i_k < K; i_k++)
            {
                for(auto i_c = 0; i_c < C; i_c++)
                {
                    for(auto i_y = 0; i_y < Y; i_y++)
                    {
                        for(auto i_x = 0; i_x < X; i_x++)
                        {
                            uint32_t v = PACK_32(i_k, i_c, i_y, i_x);
                            wei_k_c_y_x(i_k, i_c, i_y, i_x) = *reinterpret_cast<float*>(&v);
                        }
                    }
                }
            }

            break;
        default:
            in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{0, 1});
            wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-1, 1});
        }
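        // Added note on init_method == 3 above: PACK_32 packs the four loop indices into one
        // byte each, e.g. PACK_32(1, 2, 3, 4) == 0x01020304, and the bit pattern is then
        // reinterpreted as a float, so every element encodes its own coordinates (indices
        // above 255 wrap because of the 0xff masks). This is presumably a debugging aid.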
        DeviceAlignedMemCPU in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace(),
                                          AVX2_DATA_ALIGNMENT);
        DeviceAlignedMemCPU wei_device_buf(
            sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace(), AVX2_DATA_ALIGNMENT);
        DeviceAlignedMemCPU out_device_buf(sizeof(OutDataType) *
                                               out_n_k_ho_wo_host_result.mDesc.GetElementSpace(),
                                           AVX2_DATA_ALIGNMENT);
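        // Assumption (not stated here): AVX2_DATA_ALIGNMENT is expected to provide at least
        // the 32-byte alignment that 256-bit AVX2 loads/stores favor; its actual value is
        // defined elsewhere in the repository.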
        in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
        wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());

        // get host result
        {
            auto ref_conv    = ReferenceConvFwdInstance{};
            auto ref_invoker = ref_conv.MakeInvoker();

            auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi,
                                                      wei_k_c_y_x,
                                                      out_n_k_ho_wo_host_result,
                                                      conv_filter_strides,
                                                      conv_filter_dilations,
                                                      input_left_pads,
                                                      input_right_pads,
                                                      InElementOp{},
                                                      WeiElementOp{},
                                                      OutElementOp{});
            ref_invoker.Run(ref_argument);
        }

        using PassThrough          = ck::tensor_operation::cpu::element_wise::PassThrough;
        using DeviceConvFwdNoOpPtr = ck::tensor_operation::cpu::device::
            DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>;

        // add device Conv instances
        std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;

        if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, float> &&
                     ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> &&
                     ck::is_same_v<ck::remove_cv_t<OutDataType>, float>)
        {
            ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(conv_ptrs);
        }

        if(conv_ptrs.size() <= 0)
        {
            throw std::runtime_error("wrong! no device Conv instance found");
        }

        // profile device Conv instances
        bool success                    = true;
        double fastest_kernel_time      = std::numeric_limits<double>::max();
        std::string fastest_kernel_name = "";
        double fastest_kernel_gflops    = 0;

        for(auto& conv_ptr : conv_ptrs)
        {
            auto argument_ptr = conv_ptr->MakeArgumentPointer(
                static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
                static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
                static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
                N,
                K,
                C,
                input_spatial_lengths,
                filter_spatial_lengths,
                output_spatial_lengths,
                conv_filter_strides,
                conv_filter_dilations,
                input_left_pads,
                input_right_pads,
                InElementOp{},
                WeiElementOp{},
                OutElementOp{});

            if(conv_ptr->IsSupportedArgument(argument_ptr.get()))
            {
                auto invoker_ptr = conv_ptr->MakeInvokerPointer();
double time = invoker_ptr->Run(argument_ptr.get(), 10);
double total_flop = static_cast<double>(2) * N * C * Ho * Wo * K * Y * X;
double gflops = (total_flop * 1e-6) / time;
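                // Unit check (added note): the Run() call above reports time in milliseconds
                // (it is printed with an "ms" suffix below), so total_flop * 1e-6 / time_ms
                // yields GFLOP/s. With the default shape (N=2, C=192, K=256, Y=X=3, Ho=Wo=71),
                // total_flop = 2*2*192*71*71*256*3*3 ≈ 8.92e9, so a 10 ms run prints ~892 Gflops.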
out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data());
if(!check_out(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result))
{
std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl;
success = false;
}
else
{
std::cout << "Pass Info: " << conv_ptr->GetTypeString() << ", Time:" << time
<< "ms, Gflops:" << gflops << std::endl;
if(time < fastest_kernel_time)
{
fastest_kernel_time = time;
fastest_kernel_name = conv_ptr->GetTypeString();
fastest_kernel_gflops = gflops;
}
}
}
else
{
std::cout << "Not support Info: " << conv_ptr->GetTypeString() << std::endl;
}
}
if(fastest_kernel_time != std::numeric_limits<double>::max())
{
std::cout << " fastest:" << fastest_kernel_name << ", time:" << fastest_kernel_time
<< "ms, Gflops:" << fastest_kernel_gflops << std::endl;
}
return 0;
// if(success)
// {
// std::cout << "test conv2d fwd cpu : Pass" << std::endl;
// return 0;
// }
// else
// {
// std::cout << "test conv2d fwd cpu: Fail " << std::endl;
// return -1;
// }
};
if(data_type == 0)
{
return Run(F32(), F32(), F32());
}
else
{
return 1;
}
}
...@@ -226,6 +226,8 @@ int main(int argc, char** argv)
    static constexpr ck::index_t nDim =
        ck::remove_reference_t<decltype(input_desc)>::GetNumOfDimension();

    input_desc.Print();

    auto threadwise_transfer = threadwise_transfer_t{input_desc,
                                                     ck::make_zero_multi_index<nDim>(),
                                                     input_cblock_desc,
...
...@@ -313,14 +313,15 @@ void test_ukernel(ukenrel_t uk,
    float* private_c = mat_c + tid * m * n;

    ck::cpu::ThreadwiseGemmParam param;
    param.p_a = mat_a;
    param.p_b = mat_b;
    param.p_c = private_c;
    param.Kr  = k;
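    // Reading of the fields below (added comment): lda/ldb/ldc are byte strides, hence the
    // sizeof() factors. Row-major A uses k elements per row (otherwise m); for non-row-major
    // B the k * 8 factor suggests B is packed in groups of 8 columns, i.e. one 8-wide fp32
    // AVX2 register per group.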
    param.lda = (std::is_same<Row, ALayout>::value ? k : m) * sizeof(FloatA);
    param.ldb = (std::is_same<Row, BLayout>::value ? n : k * 8) * sizeof(FloatB);
    param.ldc = n * sizeof(float);
    param.alpha = alpha;
    param.accmulate_c = 0;

    memset(private_c, 0, m * n * sizeof(float));
...