Commit afc7d431 authored by carlushuang's avatar carlushuang
Browse files

avx2 gemm now works for single thread

parent 07af8343
......@@ -13,21 +13,10 @@ namespace cpu {
template <typename FloatA,
typename FloatB,
typename FloatC,
typename AccDataType,
typename ABlockDesc,
typename BBlockDesc,
typename CBlockDesc,
typename ABlockSliceLengths,
typename BBlockSliceLengths,
typename CBlockSliceLengths,
typename AThreadSliceLength,
typename BThreadSliceLength,
ck::index_t AThreadLoopOverDim, // thread slice loop over on block slice. 1d is enough for
// now
ck::index_t BThreadLoopOverDim,
typename CDesc,
ck::index_t KPerBlock,
......@@ -47,24 +36,14 @@ struct BlockwiseGemmAvx2_MxN
static constexpr index_t nDimA = ABlockDesc::GetNumOfDimension();
static constexpr index_t nDimB = BBlockDesc::GetNumOfDimension();
static constexpr index_t nDimC = CBlockDesc::GetNumOfDimension();
static constexpr index_t nDimC = CDesc::GetNumOfDimension();
using IndexA = MultiIndex<nDimA>;
using IndexB = MultiIndex<nDimB>;
using IndexC = MultiIndex<nDimC>;
using ACoord = decltype(make_tensor_coordinate(ABlockDesc{}, IndexA{}));
using BCoord = decltype(make_tensor_coordinate(BBlockDesc{}, IndexB{}));
using CCoord = decltype(make_tensor_coordinate(CBlockDesc{}, IndexC{}));
#if 0
constexpr BlockwiseGemmAvx2_MxN(const ABlockDesc & a_block_desc, const IndexA& a_thread_origin,
const BBlockDesc & b_block_desc, const IndexB& b_thread_origin)
: a_thread_coord_(make_tensor_coordinate(a_block_desc, a_thread_origin)),
b_thread_coord_(make_tensor_coordinate(b_block_desc, b_thread_origin)),
{
}
#endif
using CCoord = decltype(make_tensor_coordinate(CDesc{}, IndexC{}));
template <typename TensorDesc>
constexpr auto GetLeadingElement(const TensorDesc& desc)
......@@ -84,79 +63,175 @@ struct BlockwiseGemmAvx2_MxN
}
}
template <typename ABlockBuffer, typename BBlockBuffer, typename CBlockBuffer>
ck::index_t GetALeadingElement(const ABlockDesc& a_block_desc) const
{
return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
ck::index_t GetBLeadingElement(const BBlockDesc& b_block_desc) const
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// K * N
return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
else
{
// N/8 * K * 8
return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}] *
b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
}
}
ck::index_t GetCLeadingElement(const CDesc& c_desc) const
{
return c_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
ck::index_t GetMPerBlock(const ABlockDesc& a_block_desc) const
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// M * K
return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
}
else
{
// K * M
return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
}
ck::index_t GetKPerBlock(const ABlockDesc& a_block_desc) const
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// M * K
return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
else
{
// K * M
return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
}
}
ck::index_t GetNPerBlock(const BBlockDesc& b_block_desc) const
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// K * N
return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
else
{
// N/8 * K * 8
return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}] *
b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
}
}
ck::index_t
GetABlockStartOffset(const ABlockDesc& a_block_desc, const index_t i_m, const index_t) const
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
return i_m * a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
else
{
return i_m;
}
}
ck::index_t
GetBBlockStartOffset(const BBlockDesc& b_block_desc, const index_t, const index_t i_n) const
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// K * N
return i_n;
}
else
{
// N/8 * K * 8
return i_n * b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
}
ck::index_t
GetCBlockStartOffset(const CDesc& c_desc, const index_t i_m, const index_t i_n) const
{
return i_m * c_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}] + i_n;
}
template <typename ABlockBuffer, typename BBlockBuffer, typename CBuffer>
void Run(const ABlockDesc& a_block_desc,
const ABlockBuffer& a_block_buf,
const IndexA& a_origin,
const IndexA& /* a_origin */,
const BBlockDesc& b_block_desc,
const BBlockBuffer& b_block_buf,
const IndexB& b_origin,
const IndexB& /* b_origin */,
const CDesc& c_desc,
CBuffer& c_buf,
const IndexC& /* c_origin */,
const CBlockDesc& c_block_desc,
CBlockBuffer& c_block_buf,
const IndexC& c_origin) const
bool is_accumulate_c = true) const
{
auto lda = GetALeadingElement(a_block_desc) * sizeof(FloatA);
auto ldb = GetBLeadingElement(b_block_desc) * sizeof(FloatB);
auto ldc = GetCLeadingElement(c_desc) * sizeof(FloatC);
constexpr auto m_n_block_length =
ck::Sequence<ABlockSliceLengths::At(AThreadLoopOverDim),
BBlockSliceLengths::At(BThreadLoopOverDim)>{};
constexpr auto m_n_thread_length =
ck::Sequence<AThreadSliceLength::At(AThreadLoopOverDim),
BThreadSliceLength::At(BThreadLoopOverDim)>{};
// printf("lda:%d, ldb:%d, ldc:%d\n", lda, ldb, ldc);
constexpr auto m_n_access_length = m_n_block_length / m_n_thread_length;
const auto k_per_block = GetKPerBlock(a_block_desc);
const auto m_per_block = GetMPerBlock(a_block_desc);
const auto n_per_block = GetNPerBlock(b_block_desc);
const auto m_per_thread = ThreadwiseGemm_Dispatch::ThreadMaxMr;
const auto n_per_thread = ThreadwiseGemm_Dispatch::ThreadMaxNr;
constexpr auto ordered_m_n_access_length =
container_reorder_given_new2old(m_n_access_length, ThreadMNAccessOrder{});
ck::cpu::ThreadwiseGemmParam param;
param.Kr = k_per_block;
param.lda = lda;
param.ldb = ldb;
param.ldc = ldc;
param.alpha = 1.0f; // TODO
param.accmulate_c = is_accumulate_c ? 1 : 0;
if constexpr(std::is_same<ThreadMNAccessOrder, ck::Sequence<0, 1>>::value)
{
for(ck::index_t i_m = 0; i_m < m_per_block; i_m += m_per_thread)
{
auto current_mr = ck::math::min(m_per_block - i_m, m_per_thread);
param.p_a = &a_block_buf.p_data_[GetABlockStartOffset(a_block_desc, i_m, 0)];
constexpr auto a_block_idx_zeros =
typename uniform_sequence_gen<nDimA, 0>::type{}; // starting point of the block
constexpr auto b_block_idx_zeros = typename uniform_sequence_gen<nDimB, 0>::type{};
// printf("YYYY: %d, i_m:%d, current_mr:%d, %d, %p\n",__LINE__, i_m, current_mr,
// GetABlockStartOffset(a_block_desc, i_m, 0), param.p_a);fflush(stdout);
constexpr auto lda = GetLeadingElement(a_block_desc) * sizeof(FloatA);
constexpr auto ldb = GetLeadingElement(b_block_desc) * sizeof(FloatB);
constexpr auto ldc = GetLeadingElement(c_block_desc) * sizeof(FloatC);
for(ck::index_t i_n = 0; i_n < n_per_block; i_n += n_per_thread)
{
auto current_nr = ck::math::min(n_per_block - i_n, n_per_thread);
ck::cpu::ThreadwiseGemmParam param;
param.Kr = KPerBlock;
param.lda = lda;
param.ldb = ldb;
param.ldc = ldc;
param.alpha = 1.0f; // TODO
static_ford<decltype(ordered_m_n_access_length)>{}([&](auto ordered_idx) {
constexpr auto origin_m_n_idx = ordered_idx.ReorderGivenOld2New(ThreadMNAccessOrder{});
constexpr auto current_m_idx =
origin_m_n_idx.At(0) * AThreadSliceLength::At(AThreadLoopOverDim);
constexpr auto current_n_idx =
origin_m_n_idx.At(1) * BThreadSliceLength::At(BThreadLoopOverDim);
constexpr auto current_mr =
ck::math::min(m_n_block_length.At(0) - current_m_idx, m_n_thread_length.At(0));
constexpr auto current_nr =
ck::math::min(m_n_block_length.At(1) - current_n_idx, m_n_thread_length.At(1));
constexpr auto a_block_idx =
a_block_idx_zeros.Modify(AThreadLoopOverDim, current_m_idx);
constexpr auto a_block_coord =
make_tensor_coordinate(a_block_desc, to_multi_index(a_origin + a_block_idx));
constexpr auto b_block_idx =
b_block_idx_zeros.Modify(BThreadLoopOverDim, current_n_idx);
constexpr auto b_block_coord =
make_tensor_coordinate(b_block_desc, to_multi_index(b_origin + b_block_idx));
constexpr auto c_block_coord =
make_tensor_coordinate(c_block_desc, to_multi_index(c_origin + origin_m_n_idx));
param.p_a = &a_block_buf.p_data_[a_block_coord.GetOffset()];
param.p_b = &b_block_buf.p_data_[b_block_coord.GetOffset()];
param.p_c = &c_block_buf.p_data_[c_block_coord.GetOffset()];
ThreadwiseGemm_Dispatch::Run(&param, current_mr, current_nr);
});
param.p_b = &b_block_buf.p_data_[GetBBlockStartOffset(b_block_desc, 0, i_n)];
param.p_c = &c_buf.p_data_[GetCBlockStartOffset(c_desc, i_m, i_n)];
// printf("YYYY: %d, i_n:%d, current_nr:%d, %d, %p, C:%d, %p\n",__LINE__, i_n,
// current_nr, GetBBlockStartOffset(b_block_desc, 0, i_n), param.p_b,
// GetCBlockStartOffset(c_desc, i_m, i_n),
// param.p_c);fflush(stdout);
ThreadwiseGemm_Dispatch::Run(&param, current_mr, current_nr);
}
}
}
}
};
......
......@@ -14,6 +14,19 @@ enum ConvolutionForwardSpecialization_t
OddC,
};
enum ConvolutionForwardGemmKSpecialization_t
{
DefaultGemmKLoop,
NHWC_GemmKLoopOverC, // not merge c*y*x, and c % k_per_block == 0
};
enum ConvolutionForwardBlockLoopOverSpecialization_t
{
DefaultBlockLoopOver,
LoopOver_MNK,
LoopOver_MKN,
};
} // namespace device
} // namespace cpu
} // namespace tensor_operation
......
......@@ -13,6 +13,8 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_avx2.hpp"
#include "threadwise_gemm_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"
namespace ck {
namespace tensor_operation {
......@@ -23,20 +25,21 @@ namespace device {
template <typename InDataType,
typename WeiDataType,
typename OutDataType,
typename AccDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
ck::index_t NumDimSpatial,
ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
ck::index_t NPerBlock,
ck::index_t KPerBlock,
typename ThreadwiseGemm_Dispatch>
// bool IsGemmMPadded,
// bool IsGemmNPadded,
// bool IsGemmKPadded>
ck::index_t MPerThread,
ck::index_t NPerThread,
bool UseALocalBuffer,
bool UseBLocalBuffer,
bool UseCLocalBuffer>
struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
: public DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
{
......@@ -60,18 +63,89 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr bool NonTemporalStore = false;
static constexpr auto GetBlockMNKAccessOrder()
{
if constexpr(BlockLoopOverSpecialization == DefaultBlockLoopOver ||
BlockLoopOverSpecialization == LoopOver_MNK)
return ck::Sequence<0, 1, 2>{};
else if constexpr(BlockLoopOverSpecialization == LoopOver_MKN)
return ck::Sequence<0, 2, 1>{};
}
using BlockMNKAccessOrder = decltype(GetBlockMNKAccessOrder());
static constexpr auto GetThreadwiseGemm_Dispatch()
{
if constexpr(MPerThread == 4 && NPerThread == 24)
{
return ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<
InDataType,
WeiDataType,
OutDataType,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
NonTemporalStore>{};
}
else if constexpr(MPerThread == 6 && NPerThread == 16)
{
return ck::cpu::ThreadwiseGemmAvx2_MxN_6x16_Dispatch<
InDataType,
WeiDataType,
OutDataType,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
NonTemporalStore>{};
}
else
{
// static_assert(false, "invalid Mr/Nr");
}
}
using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());
static constexpr auto GetInputBlockDescriptor()
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
}
static constexpr auto GetWeightBlockDescriptor()
{
return make_naive_tensor_descriptor_packed(make_tuple(
math::integer_divide_ceil(NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
KPerBlock,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
}
static constexpr auto GetOutputBlockDescriptor()
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
}
static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
{
ck::index_t gemm_n_padded =
math::integer_least_multiple(gemm_n, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
const auto wei_gemm_n_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_n, gemm_k));
const auto wei_gemm_n0_k_n1_grid_desc = transform_tensor_descriptor(
const auto wei_gemm_padn_k_grid_desc = transform_tensor_descriptor(
wei_gemm_n_k_grid_desc,
ck::make_tuple(ck::make_unmerge_transform(
ck::make_tuple(wei_gemm_n_k_grid_desc.GetLength(I0) /
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize)),
ck::make_pass_through_transform(wei_gemm_n_k_grid_desc.GetLength(I1))),
make_tuple(make_right_pad_transform(gemm_n, gemm_n_padded - gemm_n),
make_pass_through_transform(gemm_k)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
const auto wei_gemm_n0_k_n1_grid_desc = transform_tensor_descriptor(
wei_gemm_padn_k_grid_desc,
ck::make_tuple(
ck::make_unmerge_transform(
ck::make_tuple(wei_gemm_padn_k_grid_desc.GetLength(I0) /
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize)),
ck::make_pass_through_transform(wei_gemm_padn_k_grid_desc.GetLength(I1))),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
......@@ -409,6 +483,13 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::multiplies<ck::index_t>());
}
static index_t GetGemmN(ck::index_t K)
{
// return ck::math::integer_least_multiple(K,
// ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
return K;
}
static auto MakeABCGridDescriptor(ck::index_t N,
ck::index_t K,
ck::index_t C,
......@@ -423,7 +504,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
using namespace ck;
const index_t GemmM = GetGemmM(N, output_spatial_lengths);
const index_t GemmN = K;
const index_t GemmN = GetGemmN(K);
const index_t GemmK = GetGemmK(C, filter_spatial_lengths);
// A:
......@@ -474,13 +555,44 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
using BGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
static constexpr bool UseCLocalBuffer = true;
// static constexpr bool UseCLocalBuffer = false;
using AThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC<
InDataType,
InDataType,
AGridDesc,
decltype(GetInputBlockDescriptor()),
InElementwiseOperation,
false,
ConvForwardSpecialization,
GemmKSpecialization>;
using BThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC<
WeiDataType,
WeiDataType,
BGridDesc,
decltype(GetWeightBlockDescriptor()),
WeiElementwiseOperation,
false,
ConvForwardSpecialization,
GemmKSpecialization>;
using CThreadwiseCopy = ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN<
OutDataType,
OutDataType,
CGridDesc,
decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation,
!UseCLocalBuffer,
ConvForwardSpecialization,
GemmKSpecialization>;
using GridwiseGemm =
ck::cpu::GridwiseGemmAvx2_MxN<InDataType, // InDataType,
WeiDataType, // WeiDataType,
OutDataType, // OutDataType,
AccDataType, // AccDataType,
AGridDesc, // AGridDesc,
BGridDesc, // BGridDesc,
CGridDesc, // CGridDesc,
......@@ -491,8 +603,13 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
NPerBlock, // NPerBlock,
KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
ck::Sequence<0, 1, 2>, // BlockMNKAccessOrder,
AThreadwiseCopy, // AThreadwiseCopy
BThreadwiseCopy, // BThreadwiseCopy
CThreadwiseCopy, // CThreadwiseCopy
BlockMNKAccessOrder, // BlockMNKAccessOrder,
ck::Sequence<0, 1>, // ThreadMNAccessOrder
UseALocalBuffer, // UseALocalBuffer
UseBLocalBuffer, // UseBLocalBuffer
UseCLocalBuffer // UseCLocalBuffer
>;
......@@ -580,6 +697,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
}
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
const auto kernel = ck::cpu::kernel_gemm_avx_mxn<GridwiseGemm,
InDataType,
WeiDataType,
......@@ -591,21 +710,24 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
BElementwiseOperation,
CElementwiseOperation>;
float ave_time = launch_and_time_cpu_kernel(kernel,
nrepeat,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_,
arg.b_grid_desc_,
arg.c_grid_desc_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
float ave_time = 0;
if(nrepeat != 1)
ave_time = launch_and_time_cpu_kernel(kernel,
nrepeat,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_,
arg.b_grid_desc_,
arg.c_grid_desc_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
// TODO: this is for benchmark purpose, so last time we clear c buffer and calculate the
// result
memset(arg.p_c_grid_, 0, arg.a_grid_desc_.GetElementSpaceSize());
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
launch_cpu_kernel(kernel,
arg.p_a_grid_,
......@@ -659,6 +781,13 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
}
}
if constexpr(GemmKSpecialization ==
ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
{
if(!(arg.Conv_C_ % KPerBlock == 0))
return false;
}
// Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
}
......@@ -748,16 +877,25 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::string GetTypeString() const override
{
auto str = std::stringstream();
auto str = std::stringstream();
auto string_local_buffer = [](bool is_local_buffer) {
if(is_local_buffer)
return "L";
else
return "G";
};
// clang-format off
str << "DeviceConv" << std::to_string(NumDimSpatial)
<< "DFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
<< "<"
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock
<< ">";
<< "DFwdAvx2_NHWC_KYXC"
<<"_FS"<< static_cast<int>(ConvForwardSpecialization)
<<"_KS"<< static_cast<int>(GemmKSpecialization)
<<"_BS"<< static_cast<int>(BlockLoopOverSpecialization)
<< "_BT" << MPerBlock << "x" << NPerBlock << "x" << KPerBlock
<< "_TT" << MPerThread << "x" << NPerThread
<< "_A" << string_local_buffer(UseALocalBuffer)
<< "_B" << string_local_buffer(UseBLocalBuffer)
<< "_C" << string_local_buffer(UseCLocalBuffer)
;
// clang-format on
return str.str();
......
......@@ -7,7 +7,7 @@
#include "common_header.hpp"
#include "../../gpu/device/tensor_layout.hpp"
#include "math.hpp"
#include "threadwise_param.hpp"
#include "threadwise_gemm_param.hpp"
namespace ck {
namespace cpu {
......@@ -294,6 +294,9 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".if (m_Mr > 4)\n lea (%%rdx, %%rdi, 1), %%r8 \n .endif\n"
".if (m_Mr > 5)\n lea (%%r8, %%rdi, 1), %%r9 \n .endif\n"
"mov 60(%[m_param]), %%edi\n" // accmulate_c
"test %%edi, %%edi\n"
"je L_GemmAvx2_MxN_6x16_Store_C%=\n"
" vaddps (%%rax), %%ymm0, %%ymm0 \n"
".if (m_Nr > 8)\n vaddps 32(%%rax), %%ymm1, %%ymm1 \n .endif\n"
".if (m_Mr > 1) \n vaddps (%%rbx), %%ymm2, %%ymm2 \n .endif\n"
......@@ -307,6 +310,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".if (m_Mr > 5) \n vaddps (%%r9), %%ymm10, %%ymm10\n .endif\n"
".if (m_Mr > 5) && (m_Nr > 8)\n vaddps 32(%%r9), %%ymm11, %%ymm11\n .endif\n"
"L_GemmAvx2_MxN_6x16_Store_C%=:\n"
".if m_NTStore == 0\n"
" vmovups %%ymm0, (%%rax) \n"
".if (m_Nr > 8)\n vmovups %%ymm1, 32(%%rax)\n .endif\n"
......@@ -424,18 +428,33 @@ struct ThreadwiseGemmAvx2_MxN_6x16
};
// clang-format off
ymm0 = _mm256_loadu_ps(p_c + 0 * ldc + 0 * 8);
if constexpr ( Nr > 8) ymm1 = _mm256_loadu_ps(p_c + 0 * ldc + 1 * 8);
if constexpr (Mr > 1 ) ymm2 = _mm256_loadu_ps(p_c + 1 * ldc + 0 * 8);
if constexpr (Mr > 1 && Nr > 8) ymm3 = _mm256_loadu_ps(p_c + 1 * ldc + 1 * 8);
if constexpr (Mr > 2 ) ymm4 = _mm256_loadu_ps(p_c + 2 * ldc + 0 * 8);
if constexpr (Mr > 2 && Nr > 8) ymm5 = _mm256_loadu_ps(p_c + 2 * ldc + 1 * 8);
if constexpr (Mr > 3 ) ymm6 = _mm256_loadu_ps(p_c + 3 * ldc + 0 * 8);
if constexpr (Mr > 3 && Nr > 8) ymm7 = _mm256_loadu_ps(p_c + 3 * ldc + 1 * 8);
if constexpr (Mr > 4 ) ymm8 = _mm256_loadu_ps(p_c + 4 * ldc + 0 * 8);
if constexpr (Mr > 4 && Nr > 8) ymm9 = _mm256_loadu_ps(p_c + 4 * ldc + 1 * 8);
if constexpr (Mr > 5 ) ymm10 = _mm256_loadu_ps(p_c + 5 * ldc + 0 * 8);
if constexpr (Mr > 5 && Nr > 8) ymm11 = _mm256_loadu_ps(p_c + 5 * ldc + 1 * 8);
if(param->accmulate_c){
ymm0 = _mm256_loadu_ps(p_c + 0 * ldc + 0 * 8);
if constexpr ( Nr > 8) ymm1 = _mm256_loadu_ps(p_c + 0 * ldc + 1 * 8);
if constexpr (Mr > 1 ) ymm2 = _mm256_loadu_ps(p_c + 1 * ldc + 0 * 8);
if constexpr (Mr > 1 && Nr > 8) ymm3 = _mm256_loadu_ps(p_c + 1 * ldc + 1 * 8);
if constexpr (Mr > 2 ) ymm4 = _mm256_loadu_ps(p_c + 2 * ldc + 0 * 8);
if constexpr (Mr > 2 && Nr > 8) ymm5 = _mm256_loadu_ps(p_c + 2 * ldc + 1 * 8);
if constexpr (Mr > 3 ) ymm6 = _mm256_loadu_ps(p_c + 3 * ldc + 0 * 8);
if constexpr (Mr > 3 && Nr > 8) ymm7 = _mm256_loadu_ps(p_c + 3 * ldc + 1 * 8);
if constexpr (Mr > 4 ) ymm8 = _mm256_loadu_ps(p_c + 4 * ldc + 0 * 8);
if constexpr (Mr > 4 && Nr > 8) ymm9 = _mm256_loadu_ps(p_c + 4 * ldc + 1 * 8);
if constexpr (Mr > 5 ) ymm10 = _mm256_loadu_ps(p_c + 5 * ldc + 0 * 8);
if constexpr (Mr > 5 && Nr > 8) ymm11 = _mm256_loadu_ps(p_c + 5 * ldc + 1 * 8);
} else {
ymm0 = _mm256_xor_ps(ymm0, ymm0);
if constexpr ( Nr > 8) ymm1 = _mm256_xor_ps(ymm1, ymm1);
if constexpr (Mr > 1 ) ymm2 = _mm256_xor_ps(ymm2, ymm2);
if constexpr (Mr > 1 && Nr > 8) ymm3 = _mm256_xor_ps(ymm3, ymm3);
if constexpr (Mr > 2 ) ymm4 = _mm256_xor_ps(ymm4, ymm4);
if constexpr (Mr > 2 && Nr > 8) ymm5 = _mm256_xor_ps(ymm5, ymm5);
if constexpr (Mr > 3 ) ymm6 = _mm256_xor_ps(ymm6, ymm6);
if constexpr (Mr > 3 && Nr > 8) ymm7 = _mm256_xor_ps(ymm7, ymm7);
if constexpr (Mr > 4 ) ymm8 = _mm256_xor_ps(ymm8, ymm8);
if constexpr (Mr > 4 && Nr > 8) ymm9 = _mm256_xor_ps(ymm9, ymm9);
if constexpr (Mr > 5 ) ymm10 = _mm256_xor_ps(ymm10, ymm10);
if constexpr (Mr > 5 && Nr > 8) ymm11 = _mm256_xor_ps(ymm11, ymm11);
}
while (Kr > 4){
#pragma unroll
......@@ -532,6 +551,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
}
if constexpr (NonTemporalStore) {
_mm256_stream_ps(p_c + 0 * ldc + 0 * 8, ymm1);
if constexpr ( Nr > 8) _mm256_stream_ps(p_c + 0 * ldc + 1 * 8, ymm1);
if constexpr (Mr > 1 ) _mm256_stream_ps(p_c + 1 * ldc + 0 * 8, ymm2);
if constexpr (Mr > 1 && Nr > 8) _mm256_stream_ps(p_c + 1 * ldc + 1 * 8, ymm3);
......@@ -830,19 +850,23 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".if (m_Mr > 2)\n lea (%%rbx, %%rdi, 1), %%rcx\n .endif\n"
".if (m_Mr > 3)\n lea (%%rcx, %%rdi, 1), %%rdx\n .endif\n"
// " vaddps (%%rax), %%ymm0, %%ymm0 \n"
// ".if (m_Nr > 8)\n vaddps 32(%%rax), %%ymm1, %%ymm1 \n .endif\n"
// ".if (m_Nr >16)\n vaddps 64(%%rax), %%ymm2, %%ymm2 \n .endif\n"
// ".if (m_Mr > 1) \n vaddps (%%rbx), %%ymm3, %%ymm3 \n .endif\n"
// ".if (m_Mr > 1) && (m_Nr > 8)\n vaddps 32(%%rbx), %%ymm4, %%ymm4 \n .endif\n"
// ".if (m_Mr > 1) && (m_Nr >16)\n vaddps 64(%%rbx), %%ymm5, %%ymm5 \n .endif\n"
// ".if (m_Mr > 2) \n vaddps (%%rcx), %%ymm6, %%ymm6 \n .endif\n"
// ".if (m_Mr > 2) && (m_Nr > 8)\n vaddps 32(%%rcx), %%ymm7, %%ymm7 \n .endif\n"
// ".if (m_Mr > 2) && (m_Nr >16)\n vaddps 64(%%rcx), %%ymm8, %%ymm8 \n .endif\n"
// ".if (m_Mr > 3) \n vaddps (%%rdx), %%ymm9, %%ymm9 \n .endif\n"
// ".if (m_Mr > 3) && (m_Nr > 8)\n vaddps 32(%%rdx), %%ymm10, %%ymm10\n .endif\n"
// ".if (m_Mr > 3) && (m_Nr >16)\n vaddps 64(%%rdx), %%ymm11, %%ymm11\n .endif\n"
"mov 60(%[m_param]), %%edi\n" // accmulate_c
"test %%edi, %%edi\n"
"je L_GemmAvx2_MxN_4x24_Store_C%=\n"
" vaddps (%%rax), %%ymm0, %%ymm0 \n"
".if (m_Nr > 8)\n vaddps 32(%%rax), %%ymm1, %%ymm1 \n .endif\n"
".if (m_Nr >16)\n vaddps 64(%%rax), %%ymm2, %%ymm2 \n .endif\n"
".if (m_Mr > 1) \n vaddps (%%rbx), %%ymm3, %%ymm3 \n .endif\n"
".if (m_Mr > 1) && (m_Nr > 8)\n vaddps 32(%%rbx), %%ymm4, %%ymm4 \n .endif\n"
".if (m_Mr > 1) && (m_Nr >16)\n vaddps 64(%%rbx), %%ymm5, %%ymm5 \n .endif\n"
".if (m_Mr > 2) \n vaddps (%%rcx), %%ymm6, %%ymm6 \n .endif\n"
".if (m_Mr > 2) && (m_Nr > 8)\n vaddps 32(%%rcx), %%ymm7, %%ymm7 \n .endif\n"
".if (m_Mr > 2) && (m_Nr >16)\n vaddps 64(%%rcx), %%ymm8, %%ymm8 \n .endif\n"
".if (m_Mr > 3) \n vaddps (%%rdx), %%ymm9, %%ymm9 \n .endif\n"
".if (m_Mr > 3) && (m_Nr > 8)\n vaddps 32(%%rdx), %%ymm10, %%ymm10\n .endif\n"
".if (m_Mr > 3) && (m_Nr >16)\n vaddps 64(%%rdx), %%ymm11, %%ymm11\n .endif\n"
"L_GemmAvx2_MxN_4x24_Store_C%=:\n"
".if m_NTStore == 0\n"
" vmovups %%ymm0, (%%rax) \n"
".if (m_Nr > 8)\n vmovups %%ymm1, 32(%%rax)\n .endif\n"
......@@ -960,18 +984,33 @@ struct ThreadwiseGemmAvx2_MxN_4x24
};
// clang-format off
ymm0 = _mm256_loadu_ps(p_c + 0 * ldc + 0 * 8);
if constexpr ( Nr > 8) ymm1 = _mm256_loadu_ps(p_c + 0 * ldc + 1 * 8);
if constexpr ( Nr >16) ymm2 = _mm256_loadu_ps(p_c + 0 * ldc + 2 * 8);
if constexpr (Mr > 1 ) ymm3 = _mm256_loadu_ps(p_c + 1 * ldc + 0 * 8);
if constexpr (Mr > 1 && Nr > 8) ymm4 = _mm256_loadu_ps(p_c + 1 * ldc + 1 * 8);
if constexpr (Mr > 1 && Nr >16) ymm5 = _mm256_loadu_ps(p_c + 1 * ldc + 2 * 8);
if constexpr (Mr > 2 ) ymm6 = _mm256_loadu_ps(p_c + 2 * ldc + 0 * 8);
if constexpr (Mr > 2 && Nr > 8) ymm7 = _mm256_loadu_ps(p_c + 2 * ldc + 1 * 8);
if constexpr (Mr > 2 && Nr >16) ymm8 = _mm256_loadu_ps(p_c + 2 * ldc + 2 * 8);
if constexpr (Mr > 3 ) ymm9 = _mm256_loadu_ps(p_c + 3 * ldc + 0 * 8);
if constexpr (Mr > 3 && Nr > 8) ymm10 = _mm256_loadu_ps(p_c + 3 * ldc + 1 * 8);
if constexpr (Mr > 3 && Nr >16) ymm11 = _mm256_loadu_ps(p_c + 3 * ldc + 2 * 8);
if(param->accmulate_c) {
ymm0 = _mm256_loadu_ps(p_c + 0 * ldc + 0 * 8);
if constexpr ( Nr > 8) ymm1 = _mm256_loadu_ps(p_c + 0 * ldc + 1 * 8);
if constexpr ( Nr >16) ymm2 = _mm256_loadu_ps(p_c + 0 * ldc + 2 * 8);
if constexpr (Mr > 1 ) ymm3 = _mm256_loadu_ps(p_c + 1 * ldc + 0 * 8);
if constexpr (Mr > 1 && Nr > 8) ymm4 = _mm256_loadu_ps(p_c + 1 * ldc + 1 * 8);
if constexpr (Mr > 1 && Nr >16) ymm5 = _mm256_loadu_ps(p_c + 1 * ldc + 2 * 8);
if constexpr (Mr > 2 ) ymm6 = _mm256_loadu_ps(p_c + 2 * ldc + 0 * 8);
if constexpr (Mr > 2 && Nr > 8) ymm7 = _mm256_loadu_ps(p_c + 2 * ldc + 1 * 8);
if constexpr (Mr > 2 && Nr >16) ymm8 = _mm256_loadu_ps(p_c + 2 * ldc + 2 * 8);
if constexpr (Mr > 3 ) ymm9 = _mm256_loadu_ps(p_c + 3 * ldc + 0 * 8);
if constexpr (Mr > 3 && Nr > 8) ymm10 = _mm256_loadu_ps(p_c + 3 * ldc + 1 * 8);
if constexpr (Mr > 3 && Nr >16) ymm11 = _mm256_loadu_ps(p_c + 3 * ldc + 2 * 8);
} else {
ymm0 = _mm256_xor_ps(ymm0, ymm0);
if constexpr ( Nr > 8) ymm1 = _mm256_xor_ps(ymm1, ymm1);
if constexpr ( Nr >16) ymm2 = _mm256_xor_ps(ymm2, ymm2);
if constexpr (Mr > 1 ) ymm3 = _mm256_xor_ps(ymm3, ymm3);
if constexpr (Mr > 1 && Nr > 8) ymm4 = _mm256_xor_ps(ymm4, ymm4);
if constexpr (Mr > 1 && Nr >16) ymm5 = _mm256_xor_ps(ymm5, ymm5);
if constexpr (Mr > 2 ) ymm6 = _mm256_xor_ps(ymm6, ymm6);
if constexpr (Mr > 2 && Nr > 8) ymm7 = _mm256_xor_ps(ymm7, ymm7);
if constexpr (Mr > 2 && Nr >16) ymm8 = _mm256_xor_ps(ymm8, ymm8);
if constexpr (Mr > 3 ) ymm9 = _mm256_xor_ps(ymm9, ymm9);
if constexpr (Mr > 3 && Nr > 8) ymm10 = _mm256_xor_ps(ymm10, ymm10);
if constexpr (Mr > 3 && Nr >16) ymm11 = _mm256_xor_ps(ymm11, ymm11);
}
while (Kr > 4){
#pragma unroll
......@@ -1221,33 +1260,36 @@ struct ThreadwiseGemmAvx2_MxN_6x16_Dispatch
static constexpr pThreadwiseGemmAvx2Run dispatch_table[6][2] = {
{
ThreadwiseGemm_6x16_t::Run,
ThreadwiseGemm_6x8_t::Run,
ThreadwiseGemm_1x8_t::Run,
ThreadwiseGemm_1x16_t::Run,
},
{
ThreadwiseGemm_5x16_t::Run,
ThreadwiseGemm_5x8_t::Run,
ThreadwiseGemm_2x8_t::Run,
ThreadwiseGemm_2x16_t::Run,
},
{
ThreadwiseGemm_4x16_t::Run,
ThreadwiseGemm_4x8_t::Run,
ThreadwiseGemm_3x8_t::Run,
ThreadwiseGemm_3x16_t::Run,
},
{
ThreadwiseGemm_3x16_t::Run,
ThreadwiseGemm_3x8_t::Run,
ThreadwiseGemm_4x8_t::Run,
ThreadwiseGemm_4x16_t::Run,
},
{
ThreadwiseGemm_2x16_t::Run,
ThreadwiseGemm_2x8_t::Run,
ThreadwiseGemm_5x8_t::Run,
ThreadwiseGemm_5x16_t::Run,
},
{
ThreadwiseGemm_1x16_t::Run,
ThreadwiseGemm_1x8_t::Run,
ThreadwiseGemm_6x8_t::Run,
ThreadwiseGemm_6x16_t::Run,
},
};
static void Run(ThreadwiseGemmParam* param, index_t mr, index_t nr)
{
index_t im = mr - 1;
index_t in = (nr >> 3) - 1;
assert(im >= 0 && im <= 5 && in >= 0 && in <= 1);
return dispatch_table[mr][nr](param);
}
};
......@@ -1371,30 +1413,33 @@ struct ThreadwiseGemmAvx2_MxN_4x24_Dispatch
static constexpr pThreadwiseGemmAvx2Run dispatch_table[4][3] = {
{
ThreadwiseGemm_4x24_t::Run,
ThreadwiseGemm_4x16_t::Run,
ThreadwiseGemm_4x8_t::Run,
ThreadwiseGemm_1x8_t::Run,
ThreadwiseGemm_1x16_t::Run,
ThreadwiseGemm_1x24_t::Run,
},
{
ThreadwiseGemm_3x24_t::Run,
ThreadwiseGemm_3x16_t::Run,
ThreadwiseGemm_3x8_t::Run,
ThreadwiseGemm_2x8_t::Run,
ThreadwiseGemm_2x16_t::Run,
ThreadwiseGemm_2x24_t::Run,
},
{
ThreadwiseGemm_2x24_t::Run,
ThreadwiseGemm_2x16_t::Run,
ThreadwiseGemm_2x8_t::Run,
ThreadwiseGemm_3x8_t::Run,
ThreadwiseGemm_3x16_t::Run,
ThreadwiseGemm_3x24_t::Run,
},
{
ThreadwiseGemm_1x24_t::Run,
ThreadwiseGemm_1x16_t::Run,
ThreadwiseGemm_1x8_t::Run,
ThreadwiseGemm_4x8_t::Run,
ThreadwiseGemm_4x16_t::Run,
ThreadwiseGemm_4x24_t::Run,
},
};
static void Run(ThreadwiseGemmParam* param, index_t mr, index_t nr)
{
return dispatch_table[mr][nr](param);
index_t im = mr - 1;
index_t in = (nr >> 3) - 1;
assert(im >= 0 && im <= 3 && in >= 0 && in <= 2);
return dispatch_table[im][in](param);
}
};
......
#ifndef CK_THREADWISE_PARAM_HPP
#define CK_THREADWISE_PARAM_HPP
#ifndef CK_THREADWISE_GEMM_PARAM_HPP
#define CK_THREADWISE_GEMM_PARAM_HPP
#include "common_header.hpp"
#include "math.hpp"
......@@ -17,7 +17,7 @@ struct ThreadwiseGemmParam
uint64_t ldb; // in unit of byte
uint64_t ldc; // in unit of byte
float alpha;
uint32_t _pack0;
int accmulate_c; // if 1, need load C and add into current fma. if 0, direct store out c result
} __attribute__((packed));
} // namespace cpu
......
......@@ -53,6 +53,42 @@ struct ThreadwiseTensorSliceTransferAvx2
{
static_assert(SliceLengths::At(Number<VectorDim>{}) % ScalarPerVector == 0,
"wrong! cannot evenly divide");
int N = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
int Hi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
int Wi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
int C = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<3>{}];
int Ho = src_desc.GetTransforms()[Number<9>{}].low_lengths_[Number<1>{}];
int Wo = src_desc.GetTransforms()[Number<9>{}].low_lengths_[Number<2>{}];
int Fy = src_desc.GetTransforms()[Number<10>{}].low_lengths_[Number<0>{}];
int Fx = src_desc.GetTransforms()[Number<10>{}].low_lengths_[Number<1>{}];
int Dy = src_desc.GetTransforms()[Number<6>{}].coefficients_[Number<0>{}];
int Sy = src_desc.GetTransforms()[Number<6>{}].coefficients_[Number<1>{}];
int Dx = src_desc.GetTransforms()[Number<7>{}].coefficients_[Number<0>{}];
int Sx = src_desc.GetTransforms()[Number<7>{}].coefficients_[Number<1>{}];
int Py = src_desc.GetTransforms()[Number<2>{}].left_pad_length_;
int Px = src_desc.GetTransforms()[Number<3>{}].left_pad_length_;
printf("N:%d, Hi:%d, Wi:%d, C:%d, Ho:%d, Wo:%d, Fy:%d, Fx:%d, Dy:%d, Sy:%d, Dx:%d, Sx:%d, "
"Py:%d, Px:%d\n",
N,
Hi,
Wi,
C,
Ho,
Wo,
Fy,
Fx,
Dy,
Sy,
Dx,
Sx,
Py,
Px);
}
void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
......@@ -87,6 +123,10 @@ struct ThreadwiseTensorSliceTransferAvx2
// std::cout<<"num_access:"<<num_access<<std::endl;
std::cout << "src hidden:" << SrcDesc::GetNumOfHiddenDimension() << std::endl;
std::cout << "dst hidden:" << DstDesc::GetNumOfHiddenDimension() << std::endl;
#if 0
static_for<0, num_access, 1>{}([&](auto idx_1d) {
using src_vector_type = ck::cpu::vector_type_maker_t<SrcData, ScalarPerVector>;
using src_vector_t = typename src_vector_type::type;
......@@ -148,6 +188,75 @@ struct ThreadwiseTensorSliceTransferAvx2
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
}
});
#endif
const auto src_slice_idx_zeros = typename uniform_sequence_gen<nDim, 0>::type{};
const auto src_slice_step = make_tensor_coordinate_step(
src_desc, to_multi_index(src_slice_idx_zeros.Modify(Number<nDim - 1>{}, Number<1>{})));
const auto dst_slice_idx_zeros = typename uniform_sequence_gen<nDim, 0>::type{};
const auto dst_slice_step = make_tensor_coordinate_step(
dst_desc, to_multi_index(dst_slice_idx_zeros.Modify(Number<nDim - 1>{}, Number<1>{})));
for(auto idx_id = 0; idx_id < num_access; idx_id++)
{
using src_vector_type = ck::cpu::vector_type_maker_t<SrcData, ScalarPerVector>;
using src_vector_t = typename src_vector_type::type;
using dst_vector_type = ck::cpu::vector_type_maker_t<DstData, ScalarPerVector>;
using dst_vector_t = typename dst_vector_type::type;
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
printf("[%s] ", is_src_valid ? "y" : "n");
print_multi_index(src_coord_.GetIndex());
printf("----");
// print_multi_index(src_coord_.GetHiddenIndex());
// printf(":%d", src_coord_.GetOffset());
// printf("\n");
// copy data from src_buf into src_vector_container
auto src_vector_container = src_vector_type{
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
auto dst_vector_container = dst_vector_type{};
// apply pointwise operation
// static_for<0, ScalarPerVector, 1>{}([&](auto i) {
// element_op_(dst_vector_container.template AsType<DstData>()(i),
// src_vector_container.template AsType<SrcData>()[i]);
// });
element_op_(dst_vector_container.template AsType<dst_vector_t>(),
src_vector_container.template AsType<src_vector_t>());
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
printf(" -> ");
print_multi_index(dst_coord_.GetIndex());
// printf(":%d", dst_coord_.GetOffset());
// printf(", src:0x%x, dst:0x%x",
// *reinterpret_cast<uint32_t*>(&src_vector_container.template AsType<src_vector_t>()),
// *reinterpret_cast<uint32_t*>(&dst_vector_container.template
// AsType<dst_vector_t>()));
printf("\n");
// copy data from dst_vector into dst_buf
dst_buf.template Update<DstInMemOp, dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>());
// move coordinate
if(idx_id != num_access - 1)
{
// constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
move_tensor_coordinate(src_desc, src_coord_, src_slice_step);
move_tensor_coordinate(dst_desc, dst_coord_, dst_slice_step);
}
}
// move coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun)
......
......@@ -121,7 +121,11 @@ template <typename... Args, typename F>
float launch_and_time_cpu_kernel(F kernel, int nrepeat, Args... args)
{
WallTimer timer;
kernel(args...);
int nwarmup = 3;
for(int i = 0; i < nwarmup; i++)
kernel(args...);
timer.Start();
for(int i = 0; i < nrepeat; i++)
......
......@@ -19,7 +19,7 @@ using InLayout = ck::tensor_layout::gemm::RowMajor; /
using WeiLayout = ck::tensor_layout::gemm::ColumnMajor; // KYXC
static constexpr bool NonTemporalStore = false;
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
using PT = ck::tensor_operation::cpu::element_wise::PassThrough;
using ThreadwiseGemmAvx2_MxN_4x24_Dispatch =
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<InType,
WeiType,
......@@ -37,53 +37,37 @@ static constexpr auto ConvFwd1x1P0 =
static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;
static constexpr auto DefaultGemmKLoop =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop;
static constexpr auto GemmKLoopOverC =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC;
static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK;
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
// clang-format off
#define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf) \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>
// clang-format on
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances = std::tuple<
//#################################################################|InDataType|WeiDataType|OutDataType|AccDataType|InElementwiseOp|WeiElementwiseOp|OutElementwiseOp|ConvForwardSp|NumDimSpatial|MPerBlock|NPerBlock|KPerBlock|ThreadwiseGemm_Dispatch
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
float,
float,
float,
float,
PassThrough,
PassThrough,
PassThrough,
ConvFwdDefault,
2,
256,
128,
64,
ThreadwiseGemmAvx2_MxN_4x24_Dispatch>,
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
float,
float,
float,
float,
PassThrough,
PassThrough,
PassThrough,
ConvFwdDefault,
2,
512,
256,
128,
ThreadwiseGemmAvx2_MxN_4x24_Dispatch>,
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
float,
float,
float,
float,
PassThrough,
PassThrough,
PassThrough,
ConvFwdDefault,
2,
1024,
144,
128,
ThreadwiseGemmAvx2_MxN_4x24_Dispatch>>;
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 120, 64, 4, 24, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 144, 128, 4, 24, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true, true, false),
// DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 192, 128, 4, 24, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 288, 128, 4, 24, true, true, false)>;
// clang-format on
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances)
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances{});
......
This diff is collapsed.
......@@ -226,6 +226,8 @@ int main(int argc, char** argv)
static constexpr ck::index_t nDim =
ck::remove_reference_t<decltype(input_desc)>::GetNumOfDimension();
input_desc.Print();
auto threadwise_transfer = threadwise_transfer_t{input_desc,
ck::make_zero_multi_index<nDim>(),
input_cblock_desc,
......
......@@ -313,14 +313,15 @@ void test_ukernel(ukenrel_t uk,
float* private_c = mat_c + tid * m * n;
ck::cpu::ThreadwiseGemmParam param;
param.p_a = mat_a;
param.p_b = mat_b;
param.p_c = private_c;
param.Kr = k;
param.lda = (std::is_same<Row, ALayout>::value ? k : m) * sizeof(FloatA);
param.ldb = (std::is_same<Row, BLayout>::value ? n : k * 8) * sizeof(FloatB);
param.ldc = n * sizeof(float);
param.alpha = alpha;
param.p_a = mat_a;
param.p_b = mat_b;
param.p_c = private_c;
param.Kr = k;
param.lda = (std::is_same<Row, ALayout>::value ? k : m) * sizeof(FloatA);
param.ldb = (std::is_same<Row, BLayout>::value ? n : k * 8) * sizeof(FloatB);
param.ldc = n * sizeof(float);
param.alpha = alpha;
param.accmulate_c = 0;
memset(private_c, 0, m * n * sizeof(float));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment