Commit afc7d431 authored by carlushuang

avx2 gemm now works for single thread

parent 07af8343
......@@ -13,21 +13,10 @@ namespace cpu {
template <typename FloatA,
typename FloatB,
typename FloatC,
typename AccDataType,
typename ABlockDesc,
typename BBlockDesc,
typename CBlockDesc,
typename ABlockSliceLengths,
typename BBlockSliceLengths,
typename CBlockSliceLengths,
typename AThreadSliceLength,
typename BThreadSliceLength,
ck::index_t AThreadLoopOverDim, // thread slice loop over on block slice. 1d is enough for
// now
ck::index_t BThreadLoopOverDim,
typename CDesc,
ck::index_t KPerBlock,
......@@ -47,24 +36,14 @@ struct BlockwiseGemmAvx2_MxN
static constexpr index_t nDimA = ABlockDesc::GetNumOfDimension();
static constexpr index_t nDimB = BBlockDesc::GetNumOfDimension();
static constexpr index_t nDimC = CBlockDesc::GetNumOfDimension();
static constexpr index_t nDimC = CDesc::GetNumOfDimension();
using IndexA = MultiIndex<nDimA>;
using IndexB = MultiIndex<nDimB>;
using IndexC = MultiIndex<nDimC>;
using ACoord = decltype(make_tensor_coordinate(ABlockDesc{}, IndexA{}));
using BCoord = decltype(make_tensor_coordinate(BBlockDesc{}, IndexB{}));
using CCoord = decltype(make_tensor_coordinate(CBlockDesc{}, IndexC{}));
#if 0
constexpr BlockwiseGemmAvx2_MxN(const ABlockDesc & a_block_desc, const IndexA& a_thread_origin,
const BBlockDesc & b_block_desc, const IndexB& b_thread_origin)
: a_thread_coord_(make_tensor_coordinate(a_block_desc, a_thread_origin)),
b_thread_coord_(make_tensor_coordinate(b_block_desc, b_thread_origin)),
{
}
#endif
using CCoord = decltype(make_tensor_coordinate(CDesc{}, IndexC{}));
template <typename TensorDesc>
constexpr auto GetLeadingElement(const TensorDesc& desc)
......@@ -84,79 +63,175 @@ struct BlockwiseGemmAvx2_MxN
}
}
template <typename ABlockBuffer, typename BBlockBuffer, typename CBlockBuffer>
ck::index_t GetALeadingElement(const ABlockDesc& a_block_desc) const
{
return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
ck::index_t GetBLeadingElement(const BBlockDesc& b_block_desc) const
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// K * N
return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
else
{
// N/8 * K * 8
return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}] *
b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
}
}
ck::index_t GetCLeadingElement(const CDesc& c_desc) const
{
return c_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
ck::index_t GetMPerBlock(const ABlockDesc& a_block_desc) const
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// M * K
return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
}
else
{
// K * M
return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
}
ck::index_t GetKPerBlock(const ABlockDesc& a_block_desc) const
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// M * K
return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
else
{
// K * M
return a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
}
}
ck::index_t GetNPerBlock(const BBlockDesc& b_block_desc) const
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// K * N
return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
else
{
// N/8 * K * 8
return b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}] *
b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
}
}
ck::index_t
GetABlockStartOffset(const ABlockDesc& a_block_desc, const index_t i_m, const index_t) const
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
return i_m * a_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
else
{
return i_m;
}
}
ck::index_t
GetBBlockStartOffset(const BBlockDesc& b_block_desc, const index_t, const index_t i_n) const
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// K * N
return i_n;
}
else
{
// N/8 * K * 8
return i_n * b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
}
ck::index_t
GetCBlockStartOffset(const CDesc& c_desc, const index_t i_m, const index_t i_n) const
{
return i_m * c_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}] + i_n;
}
template <typename ABlockBuffer, typename BBlockBuffer, typename CBuffer>
void Run(const ABlockDesc& a_block_desc,
const ABlockBuffer& a_block_buf,
const IndexA& a_origin,
const IndexA& /* a_origin */,
const BBlockDesc& b_block_desc,
const BBlockBuffer& b_block_buf,
const IndexB& b_origin,
const IndexB& /* b_origin */,
const CDesc& c_desc,
CBuffer& c_buf,
const IndexC& /* c_origin */,
const CBlockDesc& c_block_desc,
CBlockBuffer& c_block_buf,
const IndexC& c_origin) const
bool is_accumulate_c = true) const
{
auto lda = GetALeadingElement(a_block_desc) * sizeof(FloatA);
auto ldb = GetBLeadingElement(b_block_desc) * sizeof(FloatB);
auto ldc = GetCLeadingElement(c_desc) * sizeof(FloatC);
constexpr auto m_n_block_length =
ck::Sequence<ABlockSliceLengths::At(AThreadLoopOverDim),
BBlockSliceLengths::At(BThreadLoopOverDim)>{};
constexpr auto m_n_thread_length =
ck::Sequence<AThreadSliceLength::At(AThreadLoopOverDim),
BThreadSliceLength::At(BThreadLoopOverDim)>{};
// printf("lda:%d, ldb:%d, ldc:%d\n", lda, ldb, ldc);
constexpr auto m_n_access_length = m_n_block_length / m_n_thread_length;
const auto k_per_block = GetKPerBlock(a_block_desc);
const auto m_per_block = GetMPerBlock(a_block_desc);
const auto n_per_block = GetNPerBlock(b_block_desc);
const auto m_per_thread = ThreadwiseGemm_Dispatch::ThreadMaxMr;
const auto n_per_thread = ThreadwiseGemm_Dispatch::ThreadMaxNr;
constexpr auto ordered_m_n_access_length =
container_reorder_given_new2old(m_n_access_length, ThreadMNAccessOrder{});
ck::cpu::ThreadwiseGemmParam param;
param.Kr = k_per_block;
param.lda = lda;
param.ldb = ldb;
param.ldc = ldc;
param.alpha = 1.0f; // TODO
param.accmulate_c = is_accumulate_c ? 1 : 0;
if constexpr(std::is_same<ThreadMNAccessOrder, ck::Sequence<0, 1>>::value)
{
for(ck::index_t i_m = 0; i_m < m_per_block; i_m += m_per_thread)
{
auto current_mr = ck::math::min(m_per_block - i_m, m_per_thread);
param.p_a = &a_block_buf.p_data_[GetABlockStartOffset(a_block_desc, i_m, 0)];
constexpr auto a_block_idx_zeros =
typename uniform_sequence_gen<nDimA, 0>::type{}; // starting point of the block
constexpr auto b_block_idx_zeros = typename uniform_sequence_gen<nDimB, 0>::type{};
// printf("YYYY: %d, i_m:%d, current_mr:%d, %d, %p\n",__LINE__, i_m, current_mr,
// GetABlockStartOffset(a_block_desc, i_m, 0), param.p_a);fflush(stdout);
constexpr auto lda = GetLeadingElement(a_block_desc) * sizeof(FloatA);
constexpr auto ldb = GetLeadingElement(b_block_desc) * sizeof(FloatB);
constexpr auto ldc = GetLeadingElement(c_block_desc) * sizeof(FloatC);
for(ck::index_t i_n = 0; i_n < n_per_block; i_n += n_per_thread)
{
auto current_nr = ck::math::min(n_per_block - i_n, n_per_thread);
ck::cpu::ThreadwiseGemmParam param;
param.Kr = KPerBlock;
param.lda = lda;
param.ldb = ldb;
param.ldc = ldc;
param.alpha = 1.0f; // TODO
static_ford<decltype(ordered_m_n_access_length)>{}([&](auto ordered_idx) {
constexpr auto origin_m_n_idx = ordered_idx.ReorderGivenOld2New(ThreadMNAccessOrder{});
constexpr auto current_m_idx =
origin_m_n_idx.At(0) * AThreadSliceLength::At(AThreadLoopOverDim);
constexpr auto current_n_idx =
origin_m_n_idx.At(1) * BThreadSliceLength::At(BThreadLoopOverDim);
constexpr auto current_mr =
ck::math::min(m_n_block_length.At(0) - current_m_idx, m_n_thread_length.At(0));
constexpr auto current_nr =
ck::math::min(m_n_block_length.At(1) - current_n_idx, m_n_thread_length.At(1));
constexpr auto a_block_idx =
a_block_idx_zeros.Modify(AThreadLoopOverDim, current_m_idx);
constexpr auto a_block_coord =
make_tensor_coordinate(a_block_desc, to_multi_index(a_origin + a_block_idx));
constexpr auto b_block_idx =
b_block_idx_zeros.Modify(BThreadLoopOverDim, current_n_idx);
constexpr auto b_block_coord =
make_tensor_coordinate(b_block_desc, to_multi_index(b_origin + b_block_idx));
constexpr auto c_block_coord =
make_tensor_coordinate(c_block_desc, to_multi_index(c_origin + origin_m_n_idx));
param.p_a = &a_block_buf.p_data_[a_block_coord.GetOffset()];
param.p_b = &b_block_buf.p_data_[b_block_coord.GetOffset()];
param.p_c = &c_block_buf.p_data_[c_block_coord.GetOffset()];
ThreadwiseGemm_Dispatch::Run(&param, current_mr, current_nr);
});
param.p_b = &b_block_buf.p_data_[GetBBlockStartOffset(b_block_desc, 0, i_n)];
param.p_c = &c_buf.p_data_[GetCBlockStartOffset(c_desc, i_m, i_n)];
// printf("YYYY: %d, i_n:%d, current_nr:%d, %d, %p, C:%d, %p\n",__LINE__, i_n,
// current_nr, GetBBlockStartOffset(b_block_desc, 0, i_n), param.p_b,
// GetCBlockStartOffset(c_desc, i_m, i_n),
// param.p_c);fflush(stdout);
ThreadwiseGemm_Dispatch::Run(&param, current_mr, current_nr);
}
}
}
}
};
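// The Run method above tiles the m_per_block x n_per_block C block with the micro-kernel's
// maximum Mr x Nr tile and clamps the tails with min(). A minimal standalone sketch of that
// control flow (illustration only, not part of this commit; ThreadwiseGemm_Dispatch supplies
// ThreadMaxMr/ThreadMaxNr in the real code):
inline void tile_mn_sketch(ck::index_t m_per_block,
                           ck::index_t n_per_block,
                           ck::index_t thread_max_mr,
                           ck::index_t thread_max_nr)
{
    for(ck::index_t i_m = 0; i_m < m_per_block; i_m += thread_max_mr)
    {
        const ck::index_t mr = ck::math::min(m_per_block - i_m, thread_max_mr);
        for(ck::index_t i_n = 0; i_n < n_per_block; i_n += thread_max_nr)
        {
            const ck::index_t nr = ck::math::min(n_per_block - i_n, thread_max_nr);
            // Run computes the A/B/C sub-tile pointers here and calls
            // ThreadwiseGemm_Dispatch::Run(&param, mr, nr), passing the accumulate flag.
            (void)mr;
            (void)nr;
        }
    }
}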
......
......@@ -14,6 +14,19 @@ enum ConvolutionForwardSpecialization_t
OddC,
};
enum ConvolutionForwardGemmKSpecialization_t
{
DefaultGemmKLoop,
NHWC_GemmKLoopOverC, // C*Y*X is not merged; requires C % KPerBlock == 0
};
enum ConvolutionForwardBlockLoopOverSpecialization_t
{
DefaultBlockLoopOver,
LoopOver_MNK,
LoopOver_MKN,
};
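// Illustration of the new GemmK specialization (the sizes below are assumptions, not from this
// commit): DefaultGemmKLoop treats the merged C*Y*X as GemmK, while NHWC_GemmKLoopOverC keeps C
// separate and requires C % KPerBlock == 0 (checked later in IsSupportedArgument).
constexpr ck::index_t example_C = 64, example_Y = 3, example_X = 3, example_KPerBlock = 32;
static_assert(example_C * example_Y * example_X == 576, "DefaultGemmKLoop: GemmK = C*Y*X");
static_assert(example_C % example_KPerBlock == 0,
              "NHWC_GemmKLoopOverC: C must be divisible by KPerBlock");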
} // namespace device
} // namespace cpu
} // namespace tensor_operation
......
......@@ -13,6 +13,8 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_avx2.hpp"
#include "threadwise_gemm_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"
namespace ck {
namespace tensor_operation {
......@@ -23,20 +25,21 @@ namespace device {
template <typename InDataType,
typename WeiDataType,
typename OutDataType,
typename AccDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
ck::index_t NumDimSpatial,
ck::index_t MPerBlock, // "block" means data sized to fit in cache (L1/L2/L3)
ck::index_t NPerBlock,
ck::index_t KPerBlock,
typename ThreadwiseGemm_Dispatch>
// bool IsGemmMPadded,
// bool IsGemmNPadded,
// bool IsGemmKPadded>
ck::index_t MPerThread,
ck::index_t NPerThread,
bool UseALocalBuffer,
bool UseBLocalBuffer,
bool UseCLocalBuffer>
struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
: public DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
{
......@@ -60,18 +63,89 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr bool NonTemporalStore = false;
static constexpr auto GetBlockMNKAccessOrder()
{
if constexpr(BlockLoopOverSpecialization == DefaultBlockLoopOver ||
BlockLoopOverSpecialization == LoopOver_MNK)
return ck::Sequence<0, 1, 2>{};
else if constexpr(BlockLoopOverSpecialization == LoopOver_MKN)
return ck::Sequence<0, 2, 1>{};
}
using BlockMNKAccessOrder = decltype(GetBlockMNKAccessOrder());
static constexpr auto GetThreadwiseGemm_Dispatch()
{
if constexpr(MPerThread == 4 && NPerThread == 24)
{
return ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<
InDataType,
WeiDataType,
OutDataType,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
NonTemporalStore>{};
}
else if constexpr(MPerThread == 6 && NPerThread == 16)
{
return ck::cpu::ThreadwiseGemmAvx2_MxN_6x16_Dispatch<
InDataType,
WeiDataType,
OutDataType,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
NonTemporalStore>{};
}
else
{
// static_assert(false, "invalid Mr/Nr");
}
}
using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());
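// A guard one could add here (sketch, not in the original commit): only the 4x24 and 6x16
// micro-tiles are returned above, so any other (MPerThread, NPerThread) pair falls into the
// empty else branch and ThreadwiseGemm_Dispatch deduces to void, failing later at use.
// static_assert((MPerThread == 4 && NPerThread == 24) || (MPerThread == 6 && NPerThread == 16),
//               "only 4x24 and 6x16 AVX2 micro-tiles are dispatched");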
static constexpr auto GetInputBlockDescriptor()
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
}
static constexpr auto GetWeightBlockDescriptor()
{
return make_naive_tensor_descriptor_packed(make_tuple(
math::integer_divide_ceil(NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
KPerBlock,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
}
static constexpr auto GetOutputBlockDescriptor()
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
}
static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
{
ck::index_t gemm_n_padded =
math::integer_least_multiple(gemm_n, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
const auto wei_gemm_n_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(gemm_n, gemm_k));
const auto wei_gemm_n0_k_n1_grid_desc = transform_tensor_descriptor(
const auto wei_gemm_padn_k_grid_desc = transform_tensor_descriptor(
wei_gemm_n_k_grid_desc,
ck::make_tuple(ck::make_unmerge_transform(
ck::make_tuple(wei_gemm_n_k_grid_desc.GetLength(I0) /
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize)),
ck::make_pass_through_transform(wei_gemm_n_k_grid_desc.GetLength(I1))),
make_tuple(make_right_pad_transform(gemm_n, gemm_n_padded - gemm_n),
make_pass_through_transform(gemm_k)),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
const auto wei_gemm_n0_k_n1_grid_desc = transform_tensor_descriptor(
wei_gemm_padn_k_grid_desc,
ck::make_tuple(
ck::make_unmerge_transform(
ck::make_tuple(wei_gemm_padn_k_grid_desc.GetLength(I0) /
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize)),
ck::make_pass_through_transform(wei_gemm_padn_k_grid_desc.GetLength(I1))),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
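// Worked example of the two transforms above (the gemm_n value is an assumption): with
// MatrixBMinVectorSize == 8 and gemm_n == 20, the right-pad gives gemm_n_padded == 24 and the
// unmerge yields the N0 x K x N1 layout the micro-kernel consumes:
//   gemm_n_padded = integer_least_multiple(20, 8) = 24
//   N0            = 24 / 8                        = 3   ->  descriptor shape (3, gemm_k, 8)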
......@@ -409,6 +483,13 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::multiplies<ck::index_t>());
}
static index_t GetGemmN(ck::index_t K)
{
// return ck::math::integer_least_multiple(K,
// ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
return K;
}
static auto MakeABCGridDescriptor(ck::index_t N,
ck::index_t K,
ck::index_t C,
......@@ -423,7 +504,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
using namespace ck;
const index_t GemmM = GetGemmM(N, output_spatial_lengths);
const index_t GemmN = K;
const index_t GemmN = GetGemmN(K);
const index_t GemmK = GetGemmK(C, filter_spatial_lengths);
// A:
......@@ -474,13 +555,44 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
using BGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
static constexpr bool UseCLocalBuffer = true;
// static constexpr bool UseCLocalBuffer = false;
using AThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC<
InDataType,
InDataType,
AGridDesc,
decltype(GetInputBlockDescriptor()),
InElementwiseOperation,
false,
ConvForwardSpecialization,
GemmKSpecialization>;
using BThreadwiseCopy =
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC<
WeiDataType,
WeiDataType,
BGridDesc,
decltype(GetWeightBlockDescriptor()),
WeiElementwiseOperation,
false,
ConvForwardSpecialization,
GemmKSpecialization>;
using CThreadwiseCopy = ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN<
OutDataType,
OutDataType,
CGridDesc,
decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation,
!UseCLocalBuffer,
ConvForwardSpecialization,
GemmKSpecialization>;
using GridwiseGemm =
ck::cpu::GridwiseGemmAvx2_MxN<InDataType, // InDataType,
WeiDataType, // WeiDataType,
OutDataType, // OutDataType,
AccDataType, // AccDataType,
AGridDesc, // AGridDesc,
BGridDesc, // BGridDesc,
CGridDesc, // CGridDesc,
......@@ -491,8 +603,13 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
NPerBlock, // NPerBlock,
KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
ck::Sequence<0, 1, 2>, // BlockMNKAccessOrder,
AThreadwiseCopy, // AThreadwiseCopy
BThreadwiseCopy, // BThreadwiseCopy
CThreadwiseCopy, // CThreadwiseCopy
BlockMNKAccessOrder, // BlockMNKAccessOrder,
ck::Sequence<0, 1>, // ThreadMNAccessOrder
UseALocalBuffer, // UseALocalBuffer
UseBLocalBuffer, // UseBLocalBuffer
UseCLocalBuffer // UseCLocalBuffer
>;
......@@ -580,6 +697,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
}
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
const auto kernel = ck::cpu::kernel_gemm_avx_mxn<GridwiseGemm,
InDataType,
WeiDataType,
......@@ -591,21 +710,24 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
BElementwiseOperation,
CElementwiseOperation>;
float ave_time = launch_and_time_cpu_kernel(kernel,
nrepeat,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_,
arg.b_grid_desc_,
arg.c_grid_desc_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
float ave_time = 0;
if(nrepeat != 1)
ave_time = launch_and_time_cpu_kernel(kernel,
nrepeat,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_,
arg.b_grid_desc_,
arg.c_grid_desc_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
// TODO: this is for benchmark purposes, so for the last run we clear the C buffer and compute
// the result
memset(arg.p_c_grid_, 0, arg.a_grid_desc_.GetElementSpaceSize());
memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
launch_cpu_kernel(kernel,
arg.p_a_grid_,
......@@ -659,6 +781,13 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
}
}
if constexpr(GemmKSpecialization ==
ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
{
if(!(arg.Conv_C_ % KPerBlock == 0))
return false;
}
// Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
}
......@@ -748,16 +877,25 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::string GetTypeString() const override
{
auto str = std::stringstream();
auto str                 = std::stringstream();
auto string_local_buffer = [](bool is_local_buffer) {
if(is_local_buffer)
return "L";
else
return "G";
};
// clang-format off
str << "DeviceConv" << std::to_string(NumDimSpatial)
<< "DFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
<< "<"
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock
<< ">";
<< "DFwdAvx2_NHWC_KYXC"
<<"_FS"<< static_cast<int>(ConvForwardSpecialization)
<<"_KS"<< static_cast<int>(GemmKSpecialization)
<<"_BS"<< static_cast<int>(BlockLoopOverSpecialization)
<< "_BT" << MPerBlock << "x" << NPerBlock << "x" << KPerBlock
<< "_TT" << MPerThread << "x" << NPerThread
<< "_A" << string_local_buffer(UseALocalBuffer)
<< "_B" << string_local_buffer(UseBLocalBuffer)
<< "_C" << string_local_buffer(UseCLocalBuffer)
;
// clang-format on
return str.str();
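// Example of the resulting name (parameter values are assumptions, not from this commit):
// NumDimSpatial = 2, FS = 0, KS = 1, BS = 2, MPerBlock x NPerBlock x KPerBlock = 256x128x64,
// MPerThread x NPerThread = 4x24, A/B in local buffers, C written to global, gives
//   "DeviceConv2DFwdAvx2_NHWC_KYXC_FS0_KS1_BS2_BT256x128x64_TT4x24_AL_BL_CG"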
......
......@@ -7,7 +7,9 @@
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"
#include "dynamic_buffer_cpu.hpp"
#include <unistd.h>
namespace ck {
namespace cpu {
......@@ -46,7 +48,6 @@ void kernel_gemm_avx_mxn(const FloatA* __restrict__ p_a_grid,
template <typename FloatA,
typename FloatB,
typename FloatC,
typename AccDataType,
typename AGridDesc,
typename BGridDesc,
typename CGridDesc,
......@@ -57,334 +58,92 @@ template <typename FloatA,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
typename ThreadwiseGemm_Dispatch,
typename AThreadwiseCopy,
typename BThreadwiseCopy,
typename CThreadwiseCopy,
typename BlockMNKAccessOrder, // how we access gemm MNK to better fit in cache
typename ThreadMNAccessOrder, // how we access gemm MN to utilize the micro kernel
bool UseALocalBuffer,
bool UseBLocalBuffer,
bool UseCLocalBuffer // if true, will allocate a buffer and write to it in kernel, then
// copy back to block buffer. if false, will write to C directly
// copy it back to the C grid buffer (requires CThreadwiseCopy).
// if false, will write to C directly (CThreadwiseCopy is not needed)
>
struct GridwiseGemmAvx2_MxN
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
// static constexpr auto Avx2RegisterVector = 8; // 8 floats
static constexpr index_t MemAlignmentByte = 32; // 256bit
static constexpr auto GetABlockDescriptor()
static auto GetABlockDescriptor(const ck::index_t m_per_blk, const ck::index_t k_per_blk)
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// A : M, K
constexpr auto a_block_desc_m_k =
make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
auto a_block_desc_m_k =
make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, k_per_blk));
return a_block_desc_m_k;
}
else
{
// A : K, M
constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_packed(
make_tuple(KPerBlock,
auto a_block_desc_k_m = make_naive_tensor_descriptor_packed(
make_tuple(k_per_blk,
math::integer_least_multiple(
MPerBlock, ThreadwiseGemm_Dispatch::MatrixAMinVectorSize)));
m_per_blk, ThreadwiseGemm_Dispatch::MatrixAMinVectorSize)));
return a_block_desc_k_m;
}
}
static constexpr auto GetBBlockDescriptor()
static auto GetBBlockDescriptor(const ck::index_t k_per_blk, const ck::index_t n_per_blk)
{
// n_per_blk should be a multiple of 8
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// B : K, N
constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_packed(
make_tuple(KPerBlock,
math::integer_least_multiple(
NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize)));
auto b_block_desc_k_n =
make_naive_tensor_descriptor_packed(make_tuple(k_per_blk, n_per_blk));
return b_block_desc_k_n;
}
else
{
// B : N/8, K, N8
constexpr auto b_block_desc_n0_k_n1 = make_naive_tensor_descriptor_packed(make_tuple(
math::integer_divide_ceil(NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
KPerBlock,
auto b_block_desc_n0_k_n1 = make_naive_tensor_descriptor_packed(make_tuple(
math::integer_divide_ceil(n_per_blk, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
k_per_blk,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
return b_block_desc_n0_k_n1;
}
}
static constexpr auto GetABlockSliceLength()
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// A : M, K
return ck::Sequence<MPerBlock, KPerBlock>{};
}
else
{
// A : K, M
return ck::Sequence<KPerBlock, MPerBlock>{};
}
}
static constexpr auto GetBBlockSliceLength()
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// B : K, N
return ck::Sequence<KPerBlock, NPerBlock>{};
}
else
{
// B : N/8, K, N88;
return ck::Sequence<NPerBlock / ThreadwiseGemm_Dispatch::MatrixBMinVectorSize,
KPerBlock,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize>{};
}
}
static constexpr auto GetABlockDimAccessOrder() { return ck::Sequence<0, 1>{}; }
static constexpr auto GetBBlockDimAccessOrder()
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// B : K, N
return ck::Sequence<0, 1>{};
}
else
{
// B : N/8, K, N88;
return ck::Sequence<0, 1, 2>{};
}
}
static constexpr auto GetABlockMoveFwdStep()
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// A : M, K
return ck::make_multi_index(0, KPerBlock);
}
else
{
// A : K, M
return ck::make_multi_index(KPerBlock, 0);
}
}
static constexpr auto GetBBlockMoveFwdStep()
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// B : K, N
return ck::make_multi_index(KPerBlock, 0);
}
else
{
// B : N/8, K, N88;
return ck::make_multi_index(0, KPerBlock, 0);
}
}
#if 0
static constexpr auto GetAThreadDiscriptor()
{
if constexpr (std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout, ck::tensor_layout::gemm::RowMajor>::value){
// A : M, K
constexpr auto a_thread_desc_m_k = make_naive_tensor_descriptor_packed(make_tuple(ThreadwiseGemm_Dispatch::ThreadMaxMr, KPerBlock));
return a_thread_desc_m_k;
} else {
// A : K, M
constexpr auto a_thread_desc_k_m = make_naive_tensor_descriptor_packed(make_tuple(KPerBlock, ThreadwiseGemm_Dispatch::ThreadMaxMr));
return a_thread_desc_k_m;
}
}
static constexpr auto GetBThreadDescriptor()
{
if constexpr (std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout, ck::tensor_layout::gemm::RowMajor>::value){
// B : K, N
constexpr auto b_thread_desc_k_n = make_naive_tensor_descriptor_packed(make_tuple(KPerBlock, ThreadwiseGemm_Dispatch::ThreadMaxNr));
return b_thread_desc_k_n;
} else {
// B : N/8, K, N8
constexpr auto b_thread_desc_n_k_n8 = make_naive_tensor_descriptor_packed(make_tuple(math::integer_divide_ceil(ThreadwiseGemm_Dispatch::ThreadMaxNr, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize), KPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
return b_thread_desc_n_k_n8;
}
}
#endif
static constexpr auto GetAThreadSliceLength()
static auto GetCBlockDescriptor(const ck::index_t m_per_blk, const ck::index_t n_per_blk)
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// A : M, K
return ck::Sequence<ThreadwiseGemm_Dispatch::ThreadMaxMr, KPerBlock>{};
}
else
{
// A : K, M
return ck::Sequence<KPerBlock, ThreadwiseGemm_Dispatch::ThreadMaxMr>{};
}
}
static constexpr auto GetBThreadSliceLength()
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// B : K, N
return ck::Sequence<KPerBlock, ThreadwiseGemm_Dispatch::ThreadMaxNr>{};
}
else
{
// B : N/8, K, N88;
return ck::Sequence<ThreadwiseGemm_Dispatch::ThreadMaxNr /
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize,
KPerBlock,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize>{};
}
return make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, n_per_blk));
}
static constexpr auto GetAThreadMoveFwdStep()
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// A : M, K
return ck::make_multi_index(ThreadwiseGemm_Dispatch::ThreadMaxMr, 0);
}
else
{
// A : K, M
return ck::make_multi_index(0, ThreadwiseGemm_Dispatch::ThreadMaxMr);
}
}
static constexpr auto GetBThreadMoveFwdStep()
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// B : K, N
return ck::make_multi_index(0, ThreadwiseGemm_Dispatch::ThreadMaxNr);
}
else
{
// B : N/8, K, N88;
return ck::Sequence<ThreadwiseGemm_Dispatch::ThreadMaxNr /
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize,
0,
0>{};
}
}
static constexpr ck::index_t GetAThreadLoopOverDim()
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// A : M, K
return 0;
}
else
{
// A : K, M
return 1;
}
}
static constexpr ck::index_t GetBThreadLoopOverDim()
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
ck::tensor_layout::gemm::RowMajor>::value)
{
// B : K, N
return 1;
}
else
{
// B : N/8, K, N88;
return 0;
}
}
static constexpr auto GetCBlockDescriptor()
{
if constexpr(UseCLocalBuffer)
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
}
else
{
return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock)); // TODO:
}
}
static constexpr auto GetCBlockSliceLength() { return ck::Sequence<MPerBlock, NPerBlock>{}; }
static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
const BGridDesc& b_grid_desc,
const CGridDesc& c_grid_desc)
{
#if 0
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
return false;
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
return false;
// check NumPrefetch
if constexpr(NumPrefetch == 1)
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
bool is_valid = true;
const auto GemmN = c_grid_desc.GetLength(I1);
if constexpr(UseCLocalBuffer)
{
// 1-stage prefetch always supported
}
else if constexpr(NumPrefetch == 2)
{
// 2-stage prefetch currently only support even number of K0 loop
// TODO: add support for odd number of K0 loop
if(!((K0 / K0PerBlock) % 2 == 0))
{
return false;
}
if(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value && NPerBlock < GemmN)
is_valid &= false;
}
else
{
return false;
// TODO: need to check whether the C grid descriptor is a simple transform
if(GemmN % 8 != 0)
is_valid &= false;
}
// check M01, N01
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
if(!(M0 % M01 == 0 && N0 % N01 == 0))
return false;
#endif
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
return is_valid;
}
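// Examples of the new checks (numbers are assumptions): with UseCLocalBuffer == false the C
// tile is written straight to the grid, so GemmN must be a multiple of 8 (the AVX2 store
// width), e.g. GemmN = 100 fails while GemmN = 96 passes; with UseCLocalBuffer == true and the
// Sequence<0, 2, 1> (LoopOver_MKN) order, NPerBlock must cover all of GemmN, otherwise the
// per-N-block write-back would be clobbered on the next K step.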
static void Run(const FloatA* __restrict__ p_a_grid,
......@@ -397,178 +156,149 @@ struct GridwiseGemmAvx2_MxN
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op)
{
ck::index_t m_per_block;
ck::index_t n_per_block;
ck::index_t k_per_block;
if constexpr(MPerBlock == 0 && NPerBlock == 0 && KPerBlock == 0) {}
else
{
m_per_block = MPerBlock;
n_per_block = NPerBlock;
k_per_block = KPerBlock;
}
const auto M = a_grid_desc.GetLength(I0);
const auto N = b_grid_desc.GetLength(I1);
const auto K = b_grid_desc.GetLength(I0);
const ck::index_t grid_m = math::integer_divide_ceil(M, m_per_block);
const ck::index_t grid_n = math::integer_divide_ceil(N, n_per_block);
const ck::index_t grid_size = grid_m * grid_n;
constexpr auto a_block_desc = GetABlockDescriptor();
constexpr auto a_block_slice_length = GetABlockSliceLength();
constexpr auto a_block_copy_dim = decltype(a_block_slice_length)::Size();
constexpr auto a_dim_access_order = GetABlockDimAccessOrder();
constexpr auto a_block_move_step = GetABlockMoveFwdStep();
constexpr auto a_thread_slice_length = GetAThreadSliceLength();
constexpr auto a_thread_loop_over_dim = GetAThreadLoopOverDim();
constexpr auto b_block_desc = GetBBlockDescriptor();
constexpr auto b_block_slice_length = GetBBlockSliceLength();
constexpr auto b_block_copy_dim = decltype(b_block_slice_length)::Size();
constexpr auto b_dim_access_order = GetBBlockDimAccessOrder();
constexpr auto b_block_move_step = GetBBlockMoveFwdStep();
constexpr auto b_thread_slice_length = GetBThreadSliceLength();
constexpr auto b_thread_loop_over_dim = GetBThreadLoopOverDim();
constexpr auto c_block_desc = GetCBlockDescriptor();
constexpr auto c_block_slice_length = GetCBlockSliceLength();
constexpr auto c_block_move_step = ck::make_multi_index(0, NPerBlock);
auto a_threadwise_copy = ck::cpu::ThreadwiseTensorSliceTransferAvx2<
FloatA, // SrcData
FloatA, // DstData
decltype(a_grid_desc), // SrcDesc
decltype(a_block_desc), // DstDesc
AElementwiseOperation, // ElementwiseOperation
decltype(a_block_slice_length), // SliceLengths
decltype(a_dim_access_order), // DimAccessOrder
1, // VectorDim
1, // ScalarPerVector
ck::InMemoryDataOperationEnum_t::Set, // InMemoryDataOperationEnum_t
false, // SrcResetCoordinateAfterRun
true // DstResetCoordinateAfterRun
>(a_grid_desc,
ck::make_zero_multi_index<a_block_copy_dim>(),
a_block_desc,
ck::make_zero_multi_index<a_block_copy_dim>(),
AElementwiseOperation{});
auto b_threadwise_copy = ck::cpu::ThreadwiseTensorSliceTransferAvx2<
FloatB, // SrcData
FloatB, // DstData
decltype(b_grid_desc), // SrcDesc
decltype(b_block_desc), // DstDesc
BElementwiseOperation, // ElementwiseOperation
decltype(b_block_slice_length), // SliceLengths
decltype(b_dim_access_order), // DimAccessOrder
1, // VectorDim
1, // ScalarPerVector
ck::InMemoryDataOperationEnum_t::Set, // InMemoryDataOperationEnum_t
false, // SrcResetCoordinateAfterRun
true // DstResetCoordinateAfterRun
>(b_grid_desc,
ck::make_zero_multi_index<b_block_copy_dim>(),
b_block_desc,
ck::make_zero_multi_index<b_block_copy_dim>(),
BElementwiseOperation{});
auto c_threadwise_copy = ck::cpu::ThreadwiseTensorSliceTransferAvx2<
FloatC, // SrcData
FloatC, // DstData
decltype(c_block_desc), // SrcDesc
decltype(c_grid_desc), // DstDesc
BElementwiseOperation, // ElementwiseOperation
ck::Sequence<MPerBlock, NPerBlock>, // SliceLengths
ck::Sequence<0, 1>, // DimAccessOrder
1, // VectorDim
1, // ScalarPerVector
ck::InMemoryDataOperationEnum_t::Set, // InMemoryDataOperationEnum_t
true, // SrcResetCoordinateAfterRun
false // DstResetCoordinateAfterRun
>(c_block_desc,
ck::make_zero_multi_index<2>(),
c_grid_desc,
ck::make_zero_multi_index<2>(),
CElementwiseOperation{});
DeviceAlignedMemCPU a_block_mem(MPerBlock * KPerBlock * sizeof(FloatA), MemAlignmentByte);
DeviceAlignedMemCPU b_block_mem(KPerBlock * NPerBlock * sizeof(FloatB), MemAlignmentByte);
DeviceAlignedMemCPU c_block_mem(MPerBlock * NPerBlock * sizeof(FloatC), MemAlignmentByte);
auto a_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum_t::Global>(
ck::index_t m_per_block = MPerBlock;
ck::index_t n_per_block = NPerBlock;
ck::index_t k_per_block = KPerBlock;
const auto GemmM = c_grid_desc.GetLength(I0);
const auto GemmN = c_grid_desc.GetLength(I1);
const auto GemmK = a_grid_desc.GetLength(I1);
constexpr auto a_block_copy_dim = AGridDesc::GetNumOfDimension();
constexpr auto b_block_copy_dim = BGridDesc::GetNumOfDimension();
auto a_threadwise_copy = AThreadwiseCopy(a_grid_desc,
ck::make_zero_multi_index<a_block_copy_dim>(),
GetABlockDescriptor(m_per_block, k_per_block),
ck::make_zero_multi_index<a_block_copy_dim>(),
AElementwiseOperation{});
auto b_threadwise_copy = BThreadwiseCopy(b_grid_desc,
ck::make_zero_multi_index<b_block_copy_dim>(),
GetBBlockDescriptor(k_per_block, n_per_block),
ck::make_zero_multi_index<b_block_copy_dim>(),
BElementwiseOperation{});
auto c_threadwise_copy = CThreadwiseCopy(GetCBlockDescriptor(m_per_block, n_per_block),
ck::make_zero_multi_index<2>(),
c_grid_desc,
ck::make_zero_multi_index<2>(),
CElementwiseOperation{});
DeviceAlignedMemCPU a_block_mem(m_per_block * k_per_block * sizeof(FloatA),
MemAlignmentByte);
DeviceAlignedMemCPU b_block_mem(k_per_block * n_per_block * sizeof(FloatB),
MemAlignmentByte);
DeviceAlignedMemCPU c_block_mem(m_per_block * n_per_block * sizeof(FloatC),
MemAlignmentByte);
auto a_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<const FloatA*>(p_a_grid), a_grid_desc.GetElementSpaceSize());
auto b_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum_t::Global>(
auto b_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<const FloatB*>(p_b_grid), b_grid_desc.GetElementSpaceSize());
auto c_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum_t::Global>(
auto c_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<FloatC*>(p_c_grid), c_grid_desc.GetElementSpaceSize());
auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum_t::Global>(
auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<FloatA*>(a_block_mem.mpDeviceBuf),
a_block_mem.mMemSize / sizeof(FloatA));
auto b_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum_t::Global>(
auto b_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
reinterpret_cast<FloatB*>(b_block_mem.mpDeviceBuf),
b_block_mem.mMemSize / sizeof(FloatB));
auto c_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum_t::Global>(
reinterpret_cast<FloatC*>(c_block_mem.mpDeviceBuf),
c_block_mem.mMemSize / sizeof(FloatC));
auto blockwise_gemm =
BlockwiseGemmAvx2_MxN<FloatA, // FloatA,
FloatB, // FloatB,
FloatC, // FloatC,
AccDataType, // AccDataType,
decltype(a_block_desc), // ABlockDesc,
decltype(b_block_desc), // BBlockDesc,
decltype(c_block_desc), // CBlockDesc,
decltype(a_block_slice_length), // ABlockSliceLengths,
decltype(b_block_slice_length), // BBlockSliceLengths,
decltype(c_block_slice_length), // CBlockSliceLengths,
decltype(a_thread_slice_length), // AThreadSliceLength,
decltype(b_thread_slice_length), // BThreadSliceLength,
a_thread_loop_over_dim, // AThreadLoopOverDim, // thread slice
// loop over on block slice. 1d is enough
// for now
b_thread_loop_over_dim, // BThreadLoopOverDim,
KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
ThreadMNAccessOrder>{}; // ThreadMNAccessOrder // how we acces
// gemm MN to utilize micro kernel>{};
auto c_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
UseCLocalBuffer ? reinterpret_cast<FloatC*>(c_block_mem.mpDeviceBuf)
: reinterpret_cast<FloatC*>(p_c_grid),
UseCLocalBuffer ? c_block_mem.mMemSize / sizeof(FloatC)
: c_grid_desc.GetElementSpaceSize());
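// When UseCLocalBuffer is false, c_block_buf is simply a view over the C grid memory, so the
// blockwise GEMM accumulates in place and the trailing c_threadwise_copy.Run is skipped; when
// true, results land in the aligned local tile and are copied out once per block.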
auto blockwise_gemm = BlockwiseGemmAvx2_MxN<
FloatA, // FloatA,
FloatB, // FloatB,
FloatC, // FloatC,
decltype(GetABlockDescriptor(m_per_block, k_per_block)), // ABlockDesc,
decltype(GetBBlockDescriptor(k_per_block, n_per_block)), // BBlockDesc,
decltype(GetCBlockDescriptor(m_per_block, n_per_block)), // CBlockDesc,
KPerBlock, // KPerBlock,
ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
ThreadMNAccessOrder>{}; // ThreadMNAccessOrder: how we access
// gemm MN to utilize the micro kernel
// TODO: OpenMP-aware ordering
//
if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 1, 2>>::value)
{
auto a_move_k_step = ck::make_multi_index(0, k_per_block);
auto b_move_k_step = ck::make_multi_index(0, k_per_block, 0);
const ck::index_t grid_m = math::integer_divide_ceil(GemmM, m_per_block);
const ck::index_t grid_n = math::integer_divide_ceil(GemmN, n_per_block);
const ck::index_t grid_size = grid_m * grid_n;
// This version does not consider K-panel reuse; it keeps the OpenMP parallelization simple
#pragma omp parallel for
for(ck::index_t gid = 0; gid < grid_size; gid++)
{
ck::index_t i_mc = (gid / grid_n) * m_per_block;
ck::index_t i_nc = (gid % grid_n) * n_per_block;
ck::index_t mc_size = ck::math::min(M - i_mc, m_per_block);
ck::index_t nc_size = ck::math::min(N - i_nc, n_per_block);
// pack_b
b_threadwise_copy.RunGeneric(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_move_step);
if(i_nc == 0)
ck::index_t mc_size = ck::math::min(GemmM - i_mc, m_per_block);
ck::index_t nc_size =
ck::math::min(GemmN - i_nc, n_per_block); // TODO: nc_size needs to be a multiple of 8
nc_size = math::integer_least_multiple(
nc_size, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
a_threadwise_copy.SetSrcSliceOrigin(a_grid_desc, ck::make_multi_index(i_mc, 0));
b_threadwise_copy.SetSrcSliceOrigin(
b_grid_desc,
ck::make_multi_index(math::integer_divide_ceil(
i_nc, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
0,
0));
auto c_block_desc =
UseCLocalBuffer ? GetCBlockDescriptor(mc_size, nc_size) : c_grid_desc;
if constexpr(UseCLocalBuffer)
{
c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
ck::make_multi_index(i_mc, i_nc));
}
else
{
// pack_a
a_threadwise_copy.RunGeneric(
a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_move_step);
c_threadwise_copy.SetSrcSliceOrigin(c_block_desc,
ck::make_multi_index(i_mc, i_nc));
c_threadwise_copy.Run(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
}
for(ck::index_t i_kc = 0; i_kc < K; i_kc += k_per_block)
for(ck::index_t i_kc = 0; i_kc < GemmK; i_kc += k_per_block)
{
ck::index_t kc_size = ck::math::min(K - i_kc, k_per_block);
ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block);
auto a_block_desc = GetABlockDescriptor(mc_size, kc_size);
auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size);
// printf("==> i_m:%d, i_n:%d, i_k:%d, mc:%d, nc:%d, kc:%d(%d, %d)\n", i_mc,
// i_nc, i_kc, mc_size, nc_size, kc_size, KPerBlock, GemmK); fflush(stdout);
a_threadwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
b_threadwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
// for(auto i_elem = 0; i_elem < (mc_size * kc_size) ; i_elem++){
// printf("A ==> %3d : %f(0x%08x)\n", i_elem,
// (reinterpret_cast<float*>(a_block_buf.p_data_))[i_elem],
// (reinterpret_cast<uint32_t*>(a_block_buf.p_data_))[i_elem]);
//}
// for(auto i_elem = 0; i_elem < (kc_size * nc_size) ; i_elem++){
// printf("B ==> %3d : %f(0x%08x)\n", i_elem,
// (reinterpret_cast<float*>(b_block_buf.p_data_))[i_elem],
// (reinterpret_cast<uint32_t*>(b_block_buf.p_data_))[i_elem]);
// }
// printf("[%d] 2222 \n",__LINE__);
blockwise_gemm.Run(a_block_desc,
a_block_buf,
make_zero_multi_index<a_block_copy_dim>(),
......@@ -577,14 +307,108 @@ struct GridwiseGemmAvx2_MxN
make_zero_multi_index<b_block_copy_dim>(),
c_block_desc,
c_block_buf,
make_zero_multi_index<2>());
make_zero_multi_index<2>(),
i_kc != 0);
// printf("[%d] 2222 \n",__LINE__);
if((i_kc + k_per_block) < GemmK)
{
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc, a_move_k_step);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc, b_move_k_step);
}
// printf("[%d] 2222 \n",__LINE__);
// for(auto i_elem = 0; i_elem < (10) ; i_elem++){
// printf("C ==> %3d : %f(0x%08x)\n", i_elem,
// (reinterpret_cast<float*>(c_block_buf.p_data_))[i_elem],
// (reinterpret_cast<uint32_t*>(c_block_buf.p_data_))[i_elem]);
// }
}
// for(auto i_elem = 0; i_elem < (c_block_mem.mMemSize / sizeof(FloatC)) ;
// i_elem++){
// printf("C ==> %3d : %f(0x%08x)\n", i_elem,
// (reinterpret_cast<float*>(c_block_buf.p_data_))[i_elem],
// (reinterpret_cast<uint32_t*>(c_block_buf.p_data_))[i_elem]);
// }
if constexpr(UseCLocalBuffer)
c_threadwise_copy.Run(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
}
}
else if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value)
{
auto a_move_k_step = ck::make_multi_index(0, k_per_block);
auto b_move_k_step = ck::make_multi_index(
math::integer_divide_ceil(n_per_block,
ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
0,
0);
// parallelize only over the GEMM M dimension
#pragma omp parallel for
for(ck::index_t i_mc = 0; i_mc < GemmM; i_mc += m_per_block)
{
ck::index_t mc_size = ck::math::min(GemmM - i_mc, m_per_block);
a_threadwise_copy.SetSrcSliceOrigin(a_grid_desc, ck::make_multi_index(i_mc, 0));
for(ck::index_t i_kc = 0; i_kc < GemmK; i_kc += k_per_block)
{
c_threadwise_copy.RunGeneric(
c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
c_threadwise_copy.MoveDstSliceWindow(c_grid_desc, c_block_move_step);
ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block);
auto a_block_desc = GetABlockDescriptor(mc_size, kc_size);
a_threadwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
b_threadwise_copy.SetSrcSliceOrigin(b_grid_desc,
ck::make_multi_index(0, i_kc, 0));
// TODO: if the local C buffer is used, this nc loop should iterate only once
for(ck::index_t i_nc = 0; i_nc < GemmN; i_nc += n_per_block)
{
ck::index_t nc_size =
ck::math::min(GemmN - i_nc, n_per_block); // TODO: nc_size needs to be a multiple of 8
nc_size = math::integer_least_multiple(
nc_size, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size);
b_threadwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
auto c_block_desc =
UseCLocalBuffer ? GetCBlockDescriptor(mc_size, nc_size) : c_grid_desc;
if constexpr(!UseCLocalBuffer)
{
c_threadwise_copy.SetSrcSliceOrigin(c_block_desc,
ck::make_multi_index(i_mc, i_nc));
c_threadwise_copy.Run(
c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
}
blockwise_gemm.Run(a_block_desc,
a_block_buf,
make_zero_multi_index<a_block_copy_dim>(),
b_block_desc,
b_block_buf,
make_zero_multi_index<b_block_copy_dim>(),
c_block_desc,
c_block_buf,
make_zero_multi_index<2>(),
i_kc != 0);
if((i_nc + n_per_block) < GemmN)
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc, b_move_k_step);
if constexpr(UseCLocalBuffer)
{
c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
ck::make_multi_index(i_mc, i_nc));
c_threadwise_copy.Run(
c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
}
}
if((i_kc + k_per_block) < GemmK)
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc, a_move_k_step);
}
}
}
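// Rough shape of the two block loop orders implemented above (illustration only):
//
//   Sequence<0,1,2> (LoopOver_MNK):                 Sequence<0,2,1> (LoopOver_MKN):
//     omp parallel over (i_mc, i_nc) blocks           omp parallel over i_mc blocks
//       for i_kc:                                       for i_kc: pack A(i_mc, i_kc)
//         pack A(i_mc, i_kc), pack B(i_kc, i_nc)          for i_nc: pack B(i_kc, i_nc)
//         micro-GEMM, accumulate if i_kc != 0               micro-GEMM, accumulate if i_kc != 0
//       write the local C tile back once                    write the C tile back per i_nc step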
......
......@@ -7,7 +7,7 @@
#include "common_header.hpp"
#include "../../gpu/device/tensor_layout.hpp"
#include "math.hpp"
#include "threadwise_param.hpp"
#include "threadwise_gemm_param.hpp"
namespace ck {
namespace cpu {
......@@ -294,6 +294,9 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".if (m_Mr > 4)\n lea (%%rdx, %%rdi, 1), %%r8 \n .endif\n"
".if (m_Mr > 5)\n lea (%%r8, %%rdi, 1), %%r9 \n .endif\n"
"mov 60(%[m_param]), %%edi\n" // accmulate_c
"test %%edi, %%edi\n"
"je L_GemmAvx2_MxN_6x16_Store_C%=\n"
" vaddps (%%rax), %%ymm0, %%ymm0 \n"
".if (m_Nr > 8)\n vaddps 32(%%rax), %%ymm1, %%ymm1 \n .endif\n"
".if (m_Mr > 1) \n vaddps (%%rbx), %%ymm2, %%ymm2 \n .endif\n"
......@@ -307,6 +310,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".if (m_Mr > 5) \n vaddps (%%r9), %%ymm10, %%ymm10\n .endif\n"
".if (m_Mr > 5) && (m_Nr > 8)\n vaddps 32(%%r9), %%ymm11, %%ymm11\n .endif\n"
"L_GemmAvx2_MxN_6x16_Store_C%=:\n"
".if m_NTStore == 0\n"
" vmovups %%ymm0, (%%rax) \n"
".if (m_Nr > 8)\n vmovups %%ymm1, 32(%%rax)\n .endif\n"
......@@ -424,18 +428,33 @@ struct ThreadwiseGemmAvx2_MxN_6x16
};
// clang-format off
ymm0 = _mm256_loadu_ps(p_c + 0 * ldc + 0 * 8);
if constexpr ( Nr > 8) ymm1 = _mm256_loadu_ps(p_c + 0 * ldc + 1 * 8);
if constexpr (Mr > 1 ) ymm2 = _mm256_loadu_ps(p_c + 1 * ldc + 0 * 8);
if constexpr (Mr > 1 && Nr > 8) ymm3 = _mm256_loadu_ps(p_c + 1 * ldc + 1 * 8);
if constexpr (Mr > 2 ) ymm4 = _mm256_loadu_ps(p_c + 2 * ldc + 0 * 8);
if constexpr (Mr > 2 && Nr > 8) ymm5 = _mm256_loadu_ps(p_c + 2 * ldc + 1 * 8);
if constexpr (Mr > 3 ) ymm6 = _mm256_loadu_ps(p_c + 3 * ldc + 0 * 8);
if constexpr (Mr > 3 && Nr > 8) ymm7 = _mm256_loadu_ps(p_c + 3 * ldc + 1 * 8);
if constexpr (Mr > 4 ) ymm8 = _mm256_loadu_ps(p_c + 4 * ldc + 0 * 8);
if constexpr (Mr > 4 && Nr > 8) ymm9 = _mm256_loadu_ps(p_c + 4 * ldc + 1 * 8);
if constexpr (Mr > 5 ) ymm10 = _mm256_loadu_ps(p_c + 5 * ldc + 0 * 8);
if constexpr (Mr > 5 && Nr > 8) ymm11 = _mm256_loadu_ps(p_c + 5 * ldc + 1 * 8);
if(param->accmulate_c){
ymm0 = _mm256_loadu_ps(p_c + 0 * ldc + 0 * 8);
if constexpr ( Nr > 8) ymm1 = _mm256_loadu_ps(p_c + 0 * ldc + 1 * 8);
if constexpr (Mr > 1 ) ymm2 = _mm256_loadu_ps(p_c + 1 * ldc + 0 * 8);
if constexpr (Mr > 1 && Nr > 8) ymm3 = _mm256_loadu_ps(p_c + 1 * ldc + 1 * 8);
if constexpr (Mr > 2 ) ymm4 = _mm256_loadu_ps(p_c + 2 * ldc + 0 * 8);
if constexpr (Mr > 2 && Nr > 8) ymm5 = _mm256_loadu_ps(p_c + 2 * ldc + 1 * 8);
if constexpr (Mr > 3 ) ymm6 = _mm256_loadu_ps(p_c + 3 * ldc + 0 * 8);
if constexpr (Mr > 3 && Nr > 8) ymm7 = _mm256_loadu_ps(p_c + 3 * ldc + 1 * 8);
if constexpr (Mr > 4 ) ymm8 = _mm256_loadu_ps(p_c + 4 * ldc + 0 * 8);
if constexpr (Mr > 4 && Nr > 8) ymm9 = _mm256_loadu_ps(p_c + 4 * ldc + 1 * 8);
if constexpr (Mr > 5 ) ymm10 = _mm256_loadu_ps(p_c + 5 * ldc + 0 * 8);
if constexpr (Mr > 5 && Nr > 8) ymm11 = _mm256_loadu_ps(p_c + 5 * ldc + 1 * 8);
} else {
ymm0 = _mm256_xor_ps(ymm0, ymm0);
if constexpr ( Nr > 8) ymm1 = _mm256_xor_ps(ymm1, ymm1);
if constexpr (Mr > 1 ) ymm2 = _mm256_xor_ps(ymm2, ymm2);
if constexpr (Mr > 1 && Nr > 8) ymm3 = _mm256_xor_ps(ymm3, ymm3);
if constexpr (Mr > 2 ) ymm4 = _mm256_xor_ps(ymm4, ymm4);
if constexpr (Mr > 2 && Nr > 8) ymm5 = _mm256_xor_ps(ymm5, ymm5);
if constexpr (Mr > 3 ) ymm6 = _mm256_xor_ps(ymm6, ymm6);
if constexpr (Mr > 3 && Nr > 8) ymm7 = _mm256_xor_ps(ymm7, ymm7);
if constexpr (Mr > 4 ) ymm8 = _mm256_xor_ps(ymm8, ymm8);
if constexpr (Mr > 4 && Nr > 8) ymm9 = _mm256_xor_ps(ymm9, ymm9);
if constexpr (Mr > 5 ) ymm10 = _mm256_xor_ps(ymm10, ymm10);
if constexpr (Mr > 5 && Nr > 8) ymm11 = _mm256_xor_ps(ymm11, ymm11);
}
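// Scalar reference for the accmulate_c flag handled above (illustration only, not part of the
// kernel): for each (i, j) in the Mr x Nr tile,
//   c[i][j] = (param->accmulate_c ? c[i][j] : 0.0f) + sum_k a[i][k] * b[k][j];
// i.e. the xor-zeroing branch replaces the old unconditional load of C, so the first K-block
// of a tile no longer reads stale or uninitialized C values.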
while (Kr > 4){
#pragma unroll
......@@ -532,6 +551,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
}
if constexpr (NonTemporalStore) {
_mm256_stream_ps(p_c + 0 * ldc + 0 * 8, ymm0);
if constexpr ( Nr > 8) _mm256_stream_ps(p_c + 0 * ldc + 1 * 8, ymm1);
if constexpr (Mr > 1 ) _mm256_stream_ps(p_c + 1 * ldc + 0 * 8, ymm2);
if constexpr (Mr > 1 && Nr > 8) _mm256_stream_ps(p_c + 1 * ldc + 1 * 8, ymm3);
......@@ -830,19 +850,23 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".if (m_Mr > 2)\n lea (%%rbx, %%rdi, 1), %%rcx\n .endif\n"
".if (m_Mr > 3)\n lea (%%rcx, %%rdi, 1), %%rdx\n .endif\n"
// " vaddps (%%rax), %%ymm0, %%ymm0 \n"
// ".if (m_Nr > 8)\n vaddps 32(%%rax), %%ymm1, %%ymm1 \n .endif\n"
// ".if (m_Nr >16)\n vaddps 64(%%rax), %%ymm2, %%ymm2 \n .endif\n"
// ".if (m_Mr > 1) \n vaddps (%%rbx), %%ymm3, %%ymm3 \n .endif\n"
// ".if (m_Mr > 1) && (m_Nr > 8)\n vaddps 32(%%rbx), %%ymm4, %%ymm4 \n .endif\n"
// ".if (m_Mr > 1) && (m_Nr >16)\n vaddps 64(%%rbx), %%ymm5, %%ymm5 \n .endif\n"
// ".if (m_Mr > 2) \n vaddps (%%rcx), %%ymm6, %%ymm6 \n .endif\n"
// ".if (m_Mr > 2) && (m_Nr > 8)\n vaddps 32(%%rcx), %%ymm7, %%ymm7 \n .endif\n"
// ".if (m_Mr > 2) && (m_Nr >16)\n vaddps 64(%%rcx), %%ymm8, %%ymm8 \n .endif\n"
// ".if (m_Mr > 3) \n vaddps (%%rdx), %%ymm9, %%ymm9 \n .endif\n"
// ".if (m_Mr > 3) && (m_Nr > 8)\n vaddps 32(%%rdx), %%ymm10, %%ymm10\n .endif\n"
// ".if (m_Mr > 3) && (m_Nr >16)\n vaddps 64(%%rdx), %%ymm11, %%ymm11\n .endif\n"
"mov 60(%[m_param]), %%edi\n" // accmulate_c
"test %%edi, %%edi\n"
"je L_GemmAvx2_MxN_4x24_Store_C%=\n"
" vaddps (%%rax), %%ymm0, %%ymm0 \n"
".if (m_Nr > 8)\n vaddps 32(%%rax), %%ymm1, %%ymm1 \n .endif\n"
".if (m_Nr >16)\n vaddps 64(%%rax), %%ymm2, %%ymm2 \n .endif\n"
".if (m_Mr > 1) \n vaddps (%%rbx), %%ymm3, %%ymm3 \n .endif\n"
".if (m_Mr > 1) && (m_Nr > 8)\n vaddps 32(%%rbx), %%ymm4, %%ymm4 \n .endif\n"
".if (m_Mr > 1) && (m_Nr >16)\n vaddps 64(%%rbx), %%ymm5, %%ymm5 \n .endif\n"
".if (m_Mr > 2) \n vaddps (%%rcx), %%ymm6, %%ymm6 \n .endif\n"
".if (m_Mr > 2) && (m_Nr > 8)\n vaddps 32(%%rcx), %%ymm7, %%ymm7 \n .endif\n"
".if (m_Mr > 2) && (m_Nr >16)\n vaddps 64(%%rcx), %%ymm8, %%ymm8 \n .endif\n"
".if (m_Mr > 3) \n vaddps (%%rdx), %%ymm9, %%ymm9 \n .endif\n"
".if (m_Mr > 3) && (m_Nr > 8)\n vaddps 32(%%rdx), %%ymm10, %%ymm10\n .endif\n"
".if (m_Mr > 3) && (m_Nr >16)\n vaddps 64(%%rdx), %%ymm11, %%ymm11\n .endif\n"
"L_GemmAvx2_MxN_4x24_Store_C%=:\n"
".if m_NTStore == 0\n"
" vmovups %%ymm0, (%%rax) \n"
".if (m_Nr > 8)\n vmovups %%ymm1, 32(%%rax)\n .endif\n"
......@@ -960,18 +984,33 @@ struct ThreadwiseGemmAvx2_MxN_4x24
};
// clang-format off
ymm0 = _mm256_loadu_ps(p_c + 0 * ldc + 0 * 8);
if constexpr ( Nr > 8) ymm1 = _mm256_loadu_ps(p_c + 0 * ldc + 1 * 8);
if constexpr ( Nr >16) ymm2 = _mm256_loadu_ps(p_c + 0 * ldc + 2 * 8);
if constexpr (Mr > 1 ) ymm3 = _mm256_loadu_ps(p_c + 1 * ldc + 0 * 8);
if constexpr (Mr > 1 && Nr > 8) ymm4 = _mm256_loadu_ps(p_c + 1 * ldc + 1 * 8);
if constexpr (Mr > 1 && Nr >16) ymm5 = _mm256_loadu_ps(p_c + 1 * ldc + 2 * 8);
if constexpr (Mr > 2 ) ymm6 = _mm256_loadu_ps(p_c + 2 * ldc + 0 * 8);
if constexpr (Mr > 2 && Nr > 8) ymm7 = _mm256_loadu_ps(p_c + 2 * ldc + 1 * 8);
if constexpr (Mr > 2 && Nr >16) ymm8 = _mm256_loadu_ps(p_c + 2 * ldc + 2 * 8);
if constexpr (Mr > 3 ) ymm9 = _mm256_loadu_ps(p_c + 3 * ldc + 0 * 8);
if constexpr (Mr > 3 && Nr > 8) ymm10 = _mm256_loadu_ps(p_c + 3 * ldc + 1 * 8);
if constexpr (Mr > 3 && Nr >16) ymm11 = _mm256_loadu_ps(p_c + 3 * ldc + 2 * 8);
if(param->accmulate_c) {
ymm0 = _mm256_loadu_ps(p_c + 0 * ldc + 0 * 8);
if constexpr ( Nr > 8) ymm1 = _mm256_loadu_ps(p_c + 0 * ldc + 1 * 8);
if constexpr ( Nr >16) ymm2 = _mm256_loadu_ps(p_c + 0 * ldc + 2 * 8);
if constexpr (Mr > 1 ) ymm3 = _mm256_loadu_ps(p_c + 1 * ldc + 0 * 8);
if constexpr (Mr > 1 && Nr > 8) ymm4 = _mm256_loadu_ps(p_c + 1 * ldc + 1 * 8);
if constexpr (Mr > 1 && Nr >16) ymm5 = _mm256_loadu_ps(p_c + 1 * ldc + 2 * 8);
if constexpr (Mr > 2 ) ymm6 = _mm256_loadu_ps(p_c + 2 * ldc + 0 * 8);
if constexpr (Mr > 2 && Nr > 8) ymm7 = _mm256_loadu_ps(p_c + 2 * ldc + 1 * 8);
if constexpr (Mr > 2 && Nr >16) ymm8 = _mm256_loadu_ps(p_c + 2 * ldc + 2 * 8);
if constexpr (Mr > 3 ) ymm9 = _mm256_loadu_ps(p_c + 3 * ldc + 0 * 8);
if constexpr (Mr > 3 && Nr > 8) ymm10 = _mm256_loadu_ps(p_c + 3 * ldc + 1 * 8);
if constexpr (Mr > 3 && Nr >16) ymm11 = _mm256_loadu_ps(p_c + 3 * ldc + 2 * 8);
} else {
ymm0 = _mm256_xor_ps(ymm0, ymm0);
if constexpr ( Nr > 8) ymm1 = _mm256_xor_ps(ymm1, ymm1);
if constexpr ( Nr >16) ymm2 = _mm256_xor_ps(ymm2, ymm2);
if constexpr (Mr > 1 ) ymm3 = _mm256_xor_ps(ymm3, ymm3);
if constexpr (Mr > 1 && Nr > 8) ymm4 = _mm256_xor_ps(ymm4, ymm4);
if constexpr (Mr > 1 && Nr >16) ymm5 = _mm256_xor_ps(ymm5, ymm5);
if constexpr (Mr > 2 ) ymm6 = _mm256_xor_ps(ymm6, ymm6);
if constexpr (Mr > 2 && Nr > 8) ymm7 = _mm256_xor_ps(ymm7, ymm7);
if constexpr (Mr > 2 && Nr >16) ymm8 = _mm256_xor_ps(ymm8, ymm8);
if constexpr (Mr > 3 ) ymm9 = _mm256_xor_ps(ymm9, ymm9);
if constexpr (Mr > 3 && Nr > 8) ymm10 = _mm256_xor_ps(ymm10, ymm10);
if constexpr (Mr > 3 && Nr >16) ymm11 = _mm256_xor_ps(ymm11, ymm11);
}
while (Kr > 4){
#pragma unroll
......@@ -1221,33 +1260,36 @@ struct ThreadwiseGemmAvx2_MxN_6x16_Dispatch
static constexpr pThreadwiseGemmAvx2Run dispatch_table[6][2] = {
{
ThreadwiseGemm_6x16_t::Run,
ThreadwiseGemm_6x8_t::Run,
ThreadwiseGemm_1x8_t::Run,
ThreadwiseGemm_1x16_t::Run,
},
{
ThreadwiseGemm_5x16_t::Run,
ThreadwiseGemm_5x8_t::Run,
ThreadwiseGemm_2x8_t::Run,
ThreadwiseGemm_2x16_t::Run,
},
{
ThreadwiseGemm_4x16_t::Run,
ThreadwiseGemm_4x8_t::Run,
ThreadwiseGemm_3x8_t::Run,
ThreadwiseGemm_3x16_t::Run,
},
{
ThreadwiseGemm_3x16_t::Run,
ThreadwiseGemm_3x8_t::Run,
ThreadwiseGemm_4x8_t::Run,
ThreadwiseGemm_4x16_t::Run,
},
{
ThreadwiseGemm_2x16_t::Run,
ThreadwiseGemm_2x8_t::Run,
ThreadwiseGemm_5x8_t::Run,
ThreadwiseGemm_5x16_t::Run,
},
{
ThreadwiseGemm_1x16_t::Run,
ThreadwiseGemm_1x8_t::Run,
ThreadwiseGemm_6x8_t::Run,
ThreadwiseGemm_6x16_t::Run,
},
};
static void Run(ThreadwiseGemmParam* param, index_t mr, index_t nr)
{
index_t im = mr - 1;
index_t in = (nr >> 3) - 1;
assert(im >= 0 && im <= 5 && in >= 0 && in <= 1);
return dispatch_table[im][in](param);
}
};
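// The table is now indexed by zero-based tile indices instead of raw (mr, nr): im = mr - 1 and
// in = (nr >> 3) - 1, and the rows/columns were reordered to match, e.g. (assumed call values)
//   mr = 6, nr = 16 -> dispatch_table[5][1] -> ThreadwiseGemm_6x16_t::Run
//   mr = 1, nr =  8 -> dispatch_table[0][0] -> ThreadwiseGemm_1x8_t::Run
// The same indexing is used by the 4x24 dispatch below.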
......@@ -1371,30 +1413,33 @@ struct ThreadwiseGemmAvx2_MxN_4x24_Dispatch
static constexpr pThreadwiseGemmAvx2Run dispatch_table[4][3] = {
{
ThreadwiseGemm_4x24_t::Run,
ThreadwiseGemm_4x16_t::Run,
ThreadwiseGemm_4x8_t::Run,
ThreadwiseGemm_1x8_t::Run,
ThreadwiseGemm_1x16_t::Run,
ThreadwiseGemm_1x24_t::Run,
},
{
ThreadwiseGemm_3x24_t::Run,
ThreadwiseGemm_3x16_t::Run,
ThreadwiseGemm_3x8_t::Run,
ThreadwiseGemm_2x8_t::Run,
ThreadwiseGemm_2x16_t::Run,
ThreadwiseGemm_2x24_t::Run,
},
{
ThreadwiseGemm_2x24_t::Run,
ThreadwiseGemm_2x16_t::Run,
ThreadwiseGemm_2x8_t::Run,
ThreadwiseGemm_3x8_t::Run,
ThreadwiseGemm_3x16_t::Run,
ThreadwiseGemm_3x24_t::Run,
},
{
ThreadwiseGemm_1x24_t::Run,
ThreadwiseGemm_1x16_t::Run,
ThreadwiseGemm_1x8_t::Run,
ThreadwiseGemm_4x8_t::Run,
ThreadwiseGemm_4x16_t::Run,
ThreadwiseGemm_4x24_t::Run,
},
};
static void Run(ThreadwiseGemmParam* param, index_t mr, index_t nr)
{
return dispatch_table[mr][nr](param);
index_t im = mr - 1;
index_t in = (nr >> 3) - 1;
assert(im >= 0 && im <= 3 && in >= 0 && in <= 2);
return dispatch_table[im][in](param);
}
};
......
#ifndef CK_THREADWISE_PARAM_HPP
#define CK_THREADWISE_PARAM_HPP
#ifndef CK_THREADWISE_GEMM_PARAM_HPP
#define CK_THREADWISE_GEMM_PARAM_HPP
#include "common_header.hpp"
#include "math.hpp"
......@@ -17,7 +17,7 @@ struct ThreadwiseGemmParam
uint64_t ldb; // in unit of byte
uint64_t ldc; // in unit of byte
float alpha;
uint32_t _pack0;
int accmulate_c; // if 1, load C and accumulate the FMA result into it; if 0, store the C result directly
} __attribute__((packed));
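// Sketch of the packed layout this struct is assumed to have: the fields before lda are
// inferred from the param.p_a / p_b / p_c / Kr assignments elsewhere in this diff (they are not
// shown in the hunk, so treat their exact order as an assumption). The micro-kernel asm reads
// accmulate_c at byte offset 60 ("mov 60(%[m_param]), %%edi"), which matches this layout.
struct ThreadwiseGemmParamLayoutSketch
{
    const void* p_a;  // offset  0
    const void* p_b;  // offset  8
    void* p_c;        // offset 16
    uint64_t Kr;      // offset 24
    uint64_t lda;     // offset 32, in unit of byte
    uint64_t ldb;     // offset 40, in unit of byte
    uint64_t ldc;     // offset 48, in unit of byte
    float alpha;      // offset 56
    int accmulate_c;  // offset 60, read by the asm at 60(%[m_param])
} __attribute__((packed));
static_assert(sizeof(ThreadwiseGemmParamLayoutSketch) == 64,
              "packed layout: accmulate_c starts at byte 60 and ends the struct at byte 64");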
} // namespace cpu
......
......@@ -53,6 +53,42 @@ struct ThreadwiseTensorSliceTransferAvx2
{
static_assert(SliceLengths::At(Number<VectorDim>{}) % ScalarPerVector == 0,
"wrong! cannot evenly divide");
int N = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
int Hi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
int Wi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
int C = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<3>{}];
int Ho = src_desc.GetTransforms()[Number<9>{}].low_lengths_[Number<1>{}];
int Wo = src_desc.GetTransforms()[Number<9>{}].low_lengths_[Number<2>{}];
int Fy = src_desc.GetTransforms()[Number<10>{}].low_lengths_[Number<0>{}];
int Fx = src_desc.GetTransforms()[Number<10>{}].low_lengths_[Number<1>{}];
int Dy = src_desc.GetTransforms()[Number<6>{}].coefficients_[Number<0>{}];
int Sy = src_desc.GetTransforms()[Number<6>{}].coefficients_[Number<1>{}];
int Dx = src_desc.GetTransforms()[Number<7>{}].coefficients_[Number<0>{}];
int Sx = src_desc.GetTransforms()[Number<7>{}].coefficients_[Number<1>{}];
int Py = src_desc.GetTransforms()[Number<2>{}].left_pad_length_;
int Px = src_desc.GetTransforms()[Number<3>{}].left_pad_length_;
printf("N:%d, Hi:%d, Wi:%d, C:%d, Ho:%d, Wo:%d, Fy:%d, Fx:%d, Dy:%d, Sy:%d, Dx:%d, Sx:%d, "
"Py:%d, Px:%d\n",
N,
Hi,
Wi,
C,
Ho,
Wo,
Fy,
Fx,
Dy,
Sy,
Dx,
Sx,
Py,
Px);
}
void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
......@@ -87,6 +123,10 @@ struct ThreadwiseTensorSliceTransferAvx2
// std::cout<<"num_access:"<<num_access<<std::endl;
std::cout << "src hidden:" << SrcDesc::GetNumOfHiddenDimension() << std::endl;
std::cout << "dst hidden:" << DstDesc::GetNumOfHiddenDimension() << std::endl;
#if 0
static_for<0, num_access, 1>{}([&](auto idx_1d) {
using src_vector_type = ck::cpu::vector_type_maker_t<SrcData, ScalarPerVector>;
using src_vector_t = typename src_vector_type::type;
......@@ -148,6 +188,75 @@ struct ThreadwiseTensorSliceTransferAvx2
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
}
});
#endif
const auto src_slice_idx_zeros = typename uniform_sequence_gen<nDim, 0>::type{};
const auto src_slice_step = make_tensor_coordinate_step(
src_desc, to_multi_index(src_slice_idx_zeros.Modify(Number<nDim - 1>{}, Number<1>{})));
const auto dst_slice_idx_zeros = typename uniform_sequence_gen<nDim, 0>::type{};
const auto dst_slice_step = make_tensor_coordinate_step(
dst_desc, to_multi_index(dst_slice_idx_zeros.Modify(Number<nDim - 1>{}, Number<1>{})));
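// both steps advance the coordinate by one element along the innermost (fastest-varying) dimension,
// which is how the per-access loop below walks the slice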
for(auto idx_id = 0; idx_id < num_access; idx_id++)
{
using src_vector_type = ck::cpu::vector_type_maker_t<SrcData, ScalarPerVector>;
using src_vector_t = typename src_vector_type::type;
using dst_vector_type = ck::cpu::vector_type_maker_t<DstData, ScalarPerVector>;
using dst_vector_t = typename dst_vector_type::type;
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
printf("[%s] ", is_src_valid ? "y" : "n");
print_multi_index(src_coord_.GetIndex());
printf("----");
// print_multi_index(src_coord_.GetHiddenIndex());
// printf(":%d", src_coord_.GetOffset());
// printf("\n");
// copy data from src_buf into src_vector_container
auto src_vector_container = src_vector_type{
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
auto dst_vector_container = dst_vector_type{};
// apply pointwise operation
// static_for<0, ScalarPerVector, 1>{}([&](auto i) {
// element_op_(dst_vector_container.template AsType<DstData>()(i),
// src_vector_container.template AsType<SrcData>()[i]);
// });
element_op_(dst_vector_container.template AsType<dst_vector_t>(),
src_vector_container.template AsType<src_vector_t>());
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
printf(" -> ");
print_multi_index(dst_coord_.GetIndex());
// printf(":%d", dst_coord_.GetOffset());
// printf(", src:0x%x, dst:0x%x",
// *reinterpret_cast<uint32_t*>(&src_vector_container.template AsType<src_vector_t>()),
// *reinterpret_cast<uint32_t*>(&dst_vector_container.template
// AsType<dst_vector_t>()));
printf("\n");
// copy data from dst_vector into dst_buf
dst_buf.template Update<DstInMemOp, dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>());
// move coordinate
if(idx_id != num_access - 1)
{
// constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
move_tensor_coordinate(src_desc, src_coord_, src_slice_step);
move_tensor_coordinate(dst_desc, dst_coord_, dst_slice_step);
}
}
// move coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun)
......
#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_AVX2_SPECIALIZED_HPP
#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_AVX2_SPECIALIZED_HPP
#include "common_header.hpp"
#include "data_type_cpu.hpp"
#include "../../gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_space_filling_curve.hpp"
#include "dynamic_buffer_cpu.hpp"
#include <immintrin.h>
#include "convolution_forward_specialization_cpu.hpp"
#include <immintrin.h>
namespace ck {
namespace cpu {
namespace avx2_util {
inline void memcpy32_avx2(void* dst, const void* src, const ck::index_t n)
{
// 16-8-4-2-1 pattern
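// copy 16 floats per iteration with two 256-bit unaligned load/store pairs, then finish the tail with
// 8-, 4-, 2- and 1-wide copies so any n is covered without a scalar loop; n counts 32-bit elements, not bytes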
ck::index_t i_n = n;
float* p_dst = reinterpret_cast<float*>(dst);
const float* p_src = reinterpret_cast<const float*>(src);
while(i_n >= 16)
{
_mm256_storeu_ps(p_dst + 0, _mm256_loadu_ps(p_src + 0));
_mm256_storeu_ps(p_dst + 8, _mm256_loadu_ps(p_src + 8));
p_dst += 16;
p_src += 16;
i_n -= 16;
}
if(i_n & 8)
{
_mm256_storeu_ps(p_dst, _mm256_loadu_ps(p_src));
p_dst += 8;
p_src += 8;
}
if(i_n & 4)
{
_mm_storeu_ps(p_dst, _mm_loadu_ps(p_src));
p_dst += 4;
p_src += 4;
}
if(i_n & 2)
{
_mm_storeu_si64(p_dst, _mm_loadu_si64(p_src));
p_dst += 2;
p_src += 2;
}
if(i_n & 1)
{
*p_dst = *p_src;
}
}
inline void memset32_avx2(void* dst, const int32_t value, const ck::index_t n)
{
// 16-8-4-2-1 pattern
ck::index_t i_n = n;
float* p_dst = reinterpret_cast<float*>(dst);
__m256 ymm = _mm256_set1_ps(*reinterpret_cast<const float*>(&value));
__m128 xmm = _mm_set1_ps(*reinterpret_cast<const float*>(&value));
while(i_n >= 16)
{
_mm256_storeu_ps(p_dst + 0, ymm);
_mm256_storeu_ps(p_dst + 8, ymm);
p_dst += 16;
i_n -= 16;
}
if(i_n & 8)
{
_mm256_storeu_ps(p_dst, ymm);
p_dst += 8;
}
if(i_n & 4)
{
_mm_storeu_ps(p_dst, xmm);
p_dst += 4;
}
if(i_n & 2)
{
_mm_storeu_si64(p_dst, _mm_castps_si128(xmm)); // storeu_si64 expects an integer vector, so reinterpret the float lanes
p_dst += 2;
}
if(i_n & 1)
{
*p_dst = *reinterpret_cast<const float*>(&value);
}
}
inline void
transpose8x8_avx2(void* dst, ck::index_t stride_dst, const void* src, ck::index_t stride_src)
{
// TODO: use vinsertf128 for better port usage. vperm2f128 is slow
__m256 r0, r1, r2, r3, r4, r5, r6, r7;
__m256 t0, t1, t2, t3, t4, t5, t6, t7;
float* p_dst = reinterpret_cast<float*>(dst);
const float* p_src = reinterpret_cast<const float*>(src);
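// 8x8 float transpose in three stages:
//  1) unpacklo/unpackhi interleave the 32-bit lanes of adjacent row pairs within each 128-bit half,
//  2) shuffle_ps assembles 4x4 sub-blocks from the interleaved pairs,
//  3) permute2f128 exchanges the 128-bit halves, so column j of the source becomes row j of the destination.
// stride_dst and stride_src are in units of floats, not bytes.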
r0 = _mm256_loadu_ps(p_src + 0 * stride_src);
r1 = _mm256_loadu_ps(p_src + 1 * stride_src);
r2 = _mm256_loadu_ps(p_src + 2 * stride_src);
r3 = _mm256_loadu_ps(p_src + 3 * stride_src);
r4 = _mm256_loadu_ps(p_src + 4 * stride_src);
r5 = _mm256_loadu_ps(p_src + 5 * stride_src);
r6 = _mm256_loadu_ps(p_src + 6 * stride_src);
r7 = _mm256_loadu_ps(p_src + 7 * stride_src);
t0 = _mm256_unpacklo_ps(r0, r1);
t1 = _mm256_unpackhi_ps(r0, r1);
t2 = _mm256_unpacklo_ps(r2, r3);
t3 = _mm256_unpackhi_ps(r2, r3);
t4 = _mm256_unpacklo_ps(r4, r5);
t5 = _mm256_unpackhi_ps(r4, r5);
t6 = _mm256_unpacklo_ps(r6, r7);
t7 = _mm256_unpackhi_ps(r6, r7);
r0 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(1, 0, 1, 0));
r1 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(3, 2, 3, 2));
r2 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(1, 0, 1, 0));
r3 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(3, 2, 3, 2));
r4 = _mm256_shuffle_ps(t4, t6, _MM_SHUFFLE(1, 0, 1, 0));
r5 = _mm256_shuffle_ps(t4, t6, _MM_SHUFFLE(3, 2, 3, 2));
r6 = _mm256_shuffle_ps(t5, t7, _MM_SHUFFLE(1, 0, 1, 0));
r7 = _mm256_shuffle_ps(t5, t7, _MM_SHUFFLE(3, 2, 3, 2));
t0 = _mm256_permute2f128_ps(r0, r4, 0x20);
t1 = _mm256_permute2f128_ps(r1, r5, 0x20);
t2 = _mm256_permute2f128_ps(r2, r6, 0x20);
t3 = _mm256_permute2f128_ps(r3, r7, 0x20);
t4 = _mm256_permute2f128_ps(r0, r4, 0x31);
t5 = _mm256_permute2f128_ps(r1, r5, 0x31);
t6 = _mm256_permute2f128_ps(r2, r6, 0x31);
t7 = _mm256_permute2f128_ps(r3, r7, 0x31);
_mm256_storeu_ps(p_dst + 0 * stride_dst, t0);
_mm256_storeu_ps(p_dst + 1 * stride_dst, t1);
_mm256_storeu_ps(p_dst + 2 * stride_dst, t2);
_mm256_storeu_ps(p_dst + 3 * stride_dst, t3);
_mm256_storeu_ps(p_dst + 4 * stride_dst, t4);
_mm256_storeu_ps(p_dst + 5 * stride_dst, t5);
_mm256_storeu_ps(p_dst + 6 * stride_dst, t6);
_mm256_storeu_ps(p_dst + 7 * stride_dst, t7);
}
} // namespace avx2_util
using ConvolutionForwardSpecialization_t =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t;
using ConvolutionForwardGemmKSpecialization_t =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t;
// assume input -> a matrix
// assume input -> MC * KC
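// this transfer appears to gather an MC x KC tile of the NHWC input (im2col-style) into the packed
// A buffer consumed by the AVX2 micro-kernel; the specializations below shortcut the index math for
// 1x1 filters and for the gemm-k-over-C loop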
template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename ElementwiseOperation,
bool BypassTransfer,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
{
static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
using Index = MultiIndex<nDim>;
constexpr ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC(
const SrcDesc& src_desc,
const Index&,
const DstDesc&,
const Index&,
const ElementwiseOperation& element_op)
: element_op_(element_op)
{
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
N = 1;
Hi = 1;
Wi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}]; // gemm_m
C = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}]; // gemm_k
Ho = 1;
Wo = Wi;
Fy = 1;
Fx = 1;
Dy = 1;
Sy = 1;
Dx = 1;
Sx = 1;
Py = 0;
Px = 0;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
N = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
Hi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
Wi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
C = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<3>{}];
Ho = src_desc.GetTransforms()[Number<2>{}].GetUpperLengths()[Number<0>{}];
Wo = src_desc.GetTransforms()[Number<3>{}].GetUpperLengths()[Number<0>{}];
Fy = 1;
Fx = 1;
Dy = 1;
Sy = src_desc.GetTransforms()[Number<2>{}].coefficients_[Number<0>{}];
Dx = 1;
Sx = src_desc.GetTransforms()[Number<3>{}].coefficients_[Number<0>{}];
Py = 0;
Px = 0;
}
else
{
N = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
Hi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
Wi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
C = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<3>{}];
Ho = src_desc.GetTransforms()[Number<9>{}].low_lengths_[Number<1>{}];
Wo = src_desc.GetTransforms()[Number<9>{}].low_lengths_[Number<2>{}];
Fy = src_desc.GetTransforms()[Number<10>{}].low_lengths_[Number<0>{}];
Fx = src_desc.GetTransforms()[Number<10>{}].low_lengths_[Number<1>{}];
Dy = src_desc.GetTransforms()[Number<6>{}].coefficients_[Number<0>{}];
Sy = src_desc.GetTransforms()[Number<6>{}].coefficients_[Number<1>{}];
Dx = src_desc.GetTransforms()[Number<7>{}].coefficients_[Number<0>{}];
Sx = src_desc.GetTransforms()[Number<7>{}].coefficients_[Number<1>{}];
Py = src_desc.GetTransforms()[Number<2>{}].left_pad_length_;
Px = src_desc.GetTransforms()[Number<3>{}].left_pad_length_;
}
// ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
// iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
input_offset_acc_wi = Sx * C;
input_offset_ovf_wi_acc_hi = Sy * Wi * C - Wo * Sx * C;
input_offset_ovf_hi_acc_n = Hi * Wi * C - Ho * Sy * Wi * C;
// input_offset_acc_c = 1;
input_offset_ovf_c_acc_x = Dx * C - C;
input_offset_ovf_x_acc_y = Dy * Wi * C - Fx * Dx * C;
src_offset = -Py * Wi * C - Px * C;
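// derivation of the pointer increments above (units of floats, NHWC layout, using the index
// equations in the comment):
//   iwo -> iwo+1          : iwi += Sx                 => +Sx*C
//   iwo wraps, iho -> +1  : ihi += Sy, iwi -= Wo*Sx   => +Sy*Wi*C - Wo*Sx*C
//   iho wraps, n   -> +1  : next image                => +Hi*Wi*C - Ho*Sy*Wi*C
//   c wraps,   ix  -> +1  : iwi += Dx, c -= C         => +Dx*C - C
//   ix wraps,  iy  -> +1  : ihi += Dy, iwi -= Fx*Dx   => +Dy*Wi*C - Fx*Dx*C
// src_offset starts at (hi, wi) = (-Py, -Px), hence -Py*Wi*C - Px*C.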
i_n = 0;
i_c = 0;
i_hi = -Py;
i_wi = -Px;
i_ho = 0;
i_wo = 0;
i_y = 0;
i_x = 0;
i_gemm_k = 0;
#if 0
printf("N:%d, Hi:%d, Wi:%d, C:%d, Ho:%d, Wo:%d, Fy:%d, Fx:%d, Dy:%d, Sy:%d, Dx:%d, Sx:%d, "
"Py:%d, Px:%d\n",
N,
Hi,
Wi,
C,
Ho,
Wo,
Fy,
Fx,
Dy,
Sy,
Dx,
Sx,
Py,
Px);
#endif
}
void SetSrcSliceOrigin(const SrcDesc&, const Index& src_slice_origin_idx)
{
ck::index_t idx_m = src_slice_origin_idx[Number<0>{}];
ck::index_t idx_k = src_slice_origin_idx[Number<1>{}];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
i_wi = idx_m;
i_c = idx_k;
src_offset = i_wi * C + i_c;
// printf("src_offset:%d, i_wi:%d, i_c:%d\n", src_offset, i_wi, i_c);
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
i_wo = idx_m % Wo;
i_ho = (idx_m / Wo) % Ho;
i_n = (idx_m / Wo) / Ho;
// ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
// iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
i_c = idx_k;
i_x = 0;
i_y = 0;
i_hi = i_ho * Sy;
i_wi = i_wo * Sx;
src_offset = i_n * Hi * Wi * C + i_hi * Wi * C + i_wi * C + i_c;
i_gemm_k = idx_k;
}
else
{
i_wo = idx_m % Wo;
i_ho = (idx_m / Wo) % Ho;
i_n = (idx_m / Wo) / Ho;
// ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
// iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
if(idx_k == 0)
{
i_c = 0;
i_x = 0;
i_y = 0;
i_hi = i_ho * Sy - Py;
i_wi = i_wo * Sx - Px;
}
else
{
i_c = idx_k % C;
i_x = (idx_k / C) % Fx;
i_y = (idx_k / C) / Fx;
i_hi = i_ho * Sy + i_y * Dy - Py;
i_wi = i_wo * Sx + i_x * Dx - Px;
}
src_offset = i_n * Hi * Wi * C + i_hi * Wi * C + i_wi * C + i_c;
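// idx_k is the flattened filter-window position, idx_k = (i_y*Fx + i_x)*C + i_c (y-major), which is
// why the div/mod decomposition above is used when idx_k != 0; src_offset then follows the NHWC
// layout ((n*Hi + hi)*Wi + wi)*C + c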
i_gemm_k = idx_k;
// printf("[%d] i_wo:%d, i_ho:%d, i_wi:%d, i_hi:%d, src_offset:%d\n",
// __LINE__, i_wo, i_ho, i_wi, i_hi, src_offset);
}
}
void SetDstSliceOrigin(const DstDesc&, const Index&) {}
template <typename SrcBuffer, typename DstBuffer>
void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
if constexpr(BypassTransfer)
{
float* p_src = reinterpret_cast<float*>(src_buf.p_data_) + src_offset;
dst_buf.p_data_ = p_src;
}
else
{
const ck::index_t m_per_block =
dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
const ck::index_t k_per_block =
dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
const float* p_src = reinterpret_cast<const float*>(src_buf.p_data_) + src_offset;
float* p_dst = reinterpret_cast<float*>(dst_buf.p_data_);
// printf("src offset:%d, k_per_block:%d, m_per_block:%d\n", src_offset, k_per_block,
// m_per_block);
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
ck::index_t i_m_itr = m_per_block;
// standard 8-4-2-1 pattern
while(i_m_itr >= 8)
{
avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 2 * k_per_block, p_src + 2 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 3 * k_per_block, p_src + 3 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 4 * k_per_block, p_src + 4 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 5 * k_per_block, p_src + 5 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 6 * k_per_block, p_src + 6 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 7 * k_per_block, p_src + 7 * C, k_per_block);
i_m_itr -= 8;
p_dst += 8 * k_per_block;
p_src += 8 * C;
}
if(i_m_itr & 4)
{
avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 2 * k_per_block, p_src + 2 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 3 * k_per_block, p_src + 3 * C, k_per_block);
p_dst += 4 * k_per_block;
p_src += 4 * C;
}
if(i_m_itr & 2)
{
avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
avx2_util::memcpy32_avx2(p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block);
p_dst += 2 * k_per_block;
p_src += 2 * C;
}
if(i_m_itr & 1)
{
avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
}
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
ck::index_t i_m_itr = m_per_block;
ck::index_t i_wo_itr = i_wo;
ck::index_t i_ho_itr = i_ho;
while(i_m_itr > 0)
{
avx2_util::memcpy32_avx2(p_dst, p_src, k_per_block);
p_dst += k_per_block;
i_wo_itr++;
p_src += input_offset_acc_wi;
if(i_wo_itr >= Wo)
{
i_wo_itr = 0;
i_ho_itr++;
p_src += input_offset_ovf_wi_acc_hi;
}
if(i_ho_itr >= Ho)
{
i_ho_itr = 0;
// i_n++;
p_src += input_offset_ovf_hi_acc_n;
}
i_m_itr--;
}
}
else
{
// ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
// iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
if constexpr(GemmKSpecialization ==
ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
{
// c % k_per_block == 0, so every time k_per_block here is the same
ck::index_t i_m_itr = m_per_block;
ck::index_t i_wo_itr = i_wo;
ck::index_t i_ho_itr = i_ho;
ck::index_t i_wi_itr = i_wi;
ck::index_t i_hi_itr = i_hi;
// printf("[%d] i_m_itr:%d, i_wo_itr:%d, i_ho_itr:%d, i_wi_itr:%d, i_hi_itr:%d,
// src_offset:%d, input_offset_acc_wi:%d,
// input_offset_ovf_wi_acc_hi:%d,input_offset_ovf_hi_acc_n:%d, %p(%p)\n",
// __LINE__, i_m_itr, i_wo_itr, i_ho_itr, i_wi_itr, i_hi_itr,
// src_offset, input_offset_acc_wi, input_offset_ovf_wi_acc_hi,
// input_offset_ovf_hi_acc_n, src_buf.p_data_, p_src);
// printf("%p %p %p, %d, %x, %p\n",src_buf.p_data_, reinterpret_cast<const
// float*>(src_buf.p_data_) + 1, reinterpret_cast<const float*>(src_buf.p_data_)
// + ck::index_t(-1),
// sizeof(src_offset), *reinterpret_cast<uint32_t*>(&src_offset),
// reinterpret_cast<const float*>(src_buf.p_data_) + (-1088));
while(i_m_itr > 0)
{
// printf("[%d] i_m_itr:%d, i_wo_itr:%d, i_ho_itr:%d, i_wi_itr:%d,
// i_hi_itr:%d, src_offset:%d -> %p\n",
// __LINE__, i_m_itr, i_wo_itr, i_ho_itr, i_wi_itr, i_hi_itr, src_offset,
// p_src);
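// reinterpreting the signed index as uint32_t folds the "index < 0" and "index >= bound" checks
// into one unsigned compare (negative values wrap to large unsigned numbers); this is how the
// padding region is detected below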
if((*reinterpret_cast<uint32_t*>(&i_hi_itr) < Hi) &&
(*reinterpret_cast<uint32_t*>(&i_wi_itr) < Wi))
avx2_util::memcpy32_avx2(p_dst, p_src, k_per_block);
else
avx2_util::memset32_avx2(p_dst, 0, k_per_block);
p_dst += k_per_block;
i_wo_itr++;
i_wi_itr += Sx;
p_src += input_offset_acc_wi;
if(i_wo_itr >= Wo)
{
i_wo_itr = 0;
i_wi_itr -= Wo * Sx;
i_ho_itr++;
i_hi_itr += Sy;
p_src += input_offset_ovf_wi_acc_hi;
}
if(i_ho_itr >= Ho)
{
i_ho_itr = 0;
i_hi_itr -= Ho * Sy;
// i_n++;
p_src += input_offset_ovf_hi_acc_n;
}
i_m_itr--;
}
// printf("[%d] \n", __LINE__);
}
else
{
ck::index_t i_m_itr = m_per_block;
ck::index_t i_wo_itr = i_wo;
ck::index_t i_ho_itr = i_ho;
ck::index_t i_wi_itr = i_wi;
ck::index_t i_hi_itr = i_hi;
// ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
// iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
while(i_m_itr > 0)
{
/*** go along Gemm K ***/
const float* p_src_k = p_src;
float* p_dst_k = p_dst;
ck::index_t i_wi_itr_k = i_wi_itr;
ck::index_t i_hi_itr_k = i_hi_itr;
ck::index_t i_c_itr_k = i_c;
ck::index_t i_y_itr_k = i_y;
ck::index_t i_x_itr_k = i_x;
ck::index_t i_k_itr = k_per_block;
while(i_k_itr > 0)
{
// copy up to the next C boundary, but never more than the k remaining in this tile
ck::index_t current_k_block = ck::math::min(C - i_c_itr_k, i_k_itr);
if((*reinterpret_cast<uint32_t*>(&i_hi_itr_k) < Hi) &&
(*reinterpret_cast<uint32_t*>(&i_wi_itr_k) < Wi))
avx2_util::memcpy32_avx2(p_dst_k, p_src_k, current_k_block);
else
avx2_util::memset32_avx2(p_dst_k, 0, current_k_block);
p_dst_k += current_k_block;
p_src_k += current_k_block;
i_c_itr_k += current_k_block;
if(i_c_itr_k >= C)
{
i_c_itr_k = 0;
i_x_itr_k++;
i_wi_itr_k += Dx;
p_src_k += input_offset_ovf_c_acc_x;
}
if(i_x_itr_k >= Fx)
{
i_x_itr_k = 0;
i_y_itr_k++;
i_hi_itr_k += Dy;
p_src_k += input_offset_ovf_x_acc_y;
}
i_k_itr -= current_k_block;
}
/*** go along Gemm K ***/
p_dst += k_per_block;
i_wo_itr++;
i_wi_itr += Sx;
p_src += input_offset_acc_wi;
if(i_wo_itr >= Wo)
{
i_wo_itr = 0;
i_wi_itr -= Wo * Sx;
i_ho_itr++;
i_hi_itr += Sy;
p_src += input_offset_ovf_wi_acc_hi;
}
if(i_ho_itr >= Ho)
{
i_ho_itr = 0;
i_hi_itr -= Ho * Sy;
// i_n++;
p_src += input_offset_ovf_hi_acc_n;
}
i_m_itr--;
}
}
}
}
}
void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& src_slice_origin_step_idx)
{
ck::index_t move_k = src_slice_origin_step_idx[Number<1>{}];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
{
// printf(" => move_k:%d, src offset:%d\n", move_k, src_offset);
i_c += move_k;
src_offset += move_k;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
{
i_c += move_k;
src_offset += move_k;
}
else
{
if constexpr(GemmKSpecialization ==
ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
{
// c % k_per_block == 0, so every time k_per_block here is the same
// ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
// iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
// printf("222222 C:%d, src_offset:%d, i_c:%d, i_x:%d\n", C, src_offset, i_c, i_x);
// fflush(stdout);
// TODO: branch seems weird
i_c += move_k;
src_offset += move_k;
// printf("3333[%d] src_offset:%d\n", __LINE__, src_offset);
if(i_c >= C)
{
i_c = 0;
i_x++;
i_wi += Dx;
src_offset += Dx * C - C;
// printf("3333[%d] src_offset:%d\n", __LINE__, src_offset);
}
if(i_x >= Fx)
{
i_x = 0;
i_y++;
i_wi = i_wi - Fx * Dx;
i_hi += Dy;
src_offset += Dy * Wi * C - Fx * Dx * C;
// printf("3333[%d] src_offset:%d\n", __LINE__, src_offset);
}
// printf("inp move:%d, i_c:%d, i_hi:%d, i_wi:%d src_offset:%d\n", move_k, i_c,
// i_hi, i_wi, src_offset); fflush(stdout);
}
else
{
i_gemm_k += move_k;
i_c = i_gemm_k % C;
i_x = (i_gemm_k / C) % Fx;
i_y = (i_gemm_k / C) / Fx;
i_hi = i_ho * Sy + i_y * Dy - Py;
i_wi = i_wo * Sx + i_x * Dx - Px;
src_offset = i_n * Hi * Wi * C + i_hi * Wi * C + i_wi * C + i_c;
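// for the default gemm-k loop the window position is recomputed from the absolute i_gemm_k
// (same decomposition as in SetSrcSliceOrigin) rather than advanced incrementally, since
// move_k may cross several (y, x) filter taps at once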
}
}
}
void MoveDstSliceWindow(const DstDesc&, const Index&) {}
private:
const ElementwiseOperation element_op_;
ck::index_t i_n;
ck::index_t i_c;
ck::index_t i_hi;
ck::index_t i_wi;
ck::index_t i_ho;
ck::index_t i_wo;
ck::index_t i_y;
ck::index_t i_x;
ck::index_t i_gemm_k;
ck::index_t N;
// ck::index_t K;
ck::index_t C;
ck::index_t Hi;
ck::index_t Wi;
ck::index_t Ho;
ck::index_t Wo;
ck::index_t Sy;
ck::index_t Sx;
ck::index_t Dy;
ck::index_t Dx;
ck::index_t Py;
ck::index_t Px;
ck::index_t Fy;
ck::index_t Fx;
intptr_t input_offset_acc_wi;
intptr_t input_offset_ovf_wi_acc_hi;
intptr_t input_offset_ovf_hi_acc_n;
// intptr_t input_offset_acc_c;
intptr_t input_offset_ovf_c_acc_x;
intptr_t input_offset_ovf_x_acc_y;
intptr_t src_offset; // keep this signed and pointer-sized (intptr_t), since the offset can be negative
};
template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename ElementwiseOperation,
bool BypassTransfer,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC
{
static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
using Index = MultiIndex<nDim>;
// using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
// using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
constexpr ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC(
const SrcDesc& src_desc,
const Index& src_slice_origin,
const DstDesc& dst_desc,
const Index& dst_slice_origin,
const ElementwiseOperation& element_op)
: element_op_(element_op)
{
GemmN1 = src_desc.GetTransforms()[Number<3>{}].GetUpperLengths()[Number<1>{}];
GemmN = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
GemmK = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
void SetSrcSliceOrigin(const SrcDesc&, const Index& src_slice_origin_idx)
{
ck::index_t idx_n0 = src_slice_origin_idx[Number<0>{}];
ck::index_t idx_k = src_slice_origin_idx[Number<1>{}];
ck::index_t idx_n1 = src_slice_origin_idx[Number<2>{}];
i_gemm_n = idx_n0 * GemmN1 + idx_n1;
// i_gemm_k = idx_k;
src_offset = idx_n0 * GemmK * GemmN1 + idx_k + idx_n1 * GemmN1; // Note we transpose here
// printf("xxxx i_gemm_n:%d, i_gemm_k:%d, src_offset:%d\n", i_gemm_n, i_gemm_k,
// src_offset);
}
void SetDstSliceOrigin(const DstDesc&, const Index&) {}
template <typename SrcBuffer, typename DstBuffer>
void Run(const SrcDesc&, const SrcBuffer& src_buf, const DstDesc& dst_desc, DstBuffer& dst_buf)
{
if constexpr(BypassTransfer)
{
// TODO: weight NHWC not support this
}
else
{
const ck::index_t n_per_block =
dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}] *
dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
const ck::index_t k_per_block =
dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
// printf(" >>>> %d, %d, %d -> %d(%dx%d), %d\n", GemmN, GemmK, GemmN1, n_per_block,
// dst_desc.GetTransforms()[Number<0>{}]
// .GetUpperLengths()[Number<0>{}],
// dst_desc.GetTransforms()[Number<0>{}]
// .GetUpperLengths()[Number<2>{}],
// k_per_block);
const float* p_src = reinterpret_cast<const float*>(src_buf.p_data_) + src_offset;
float* p_dst = reinterpret_cast<float*>(dst_buf.p_data_);
// n * k -> n0 * k * n1, n1 = 8, n0 = n/8
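// i.e. B is repacked so that 8 consecutive n-values are contiguous for each k, matching the 8-wide
// (one ymm) N dimension of the micro-kernel; transpose8x8_avx2 handles full 8x8 blocks and the
// scalar paths below handle the K and N tails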
for(index_t i_n_itr = 0; i_n_itr < n_per_block; i_n_itr += 8)
{
ck::index_t current_n_8 = ck::math::min(GemmN - (i_n_itr + i_gemm_n), 8);
ck::index_t i_k_itr = k_per_block;
if(current_n_8 == 8)
{
const float* p_src_k = p_src;
float* p_dst_k = p_dst;
while(i_k_itr >= 8)
{
avx2_util::transpose8x8_avx2(p_dst_k, 8, p_src_k, GemmK);
p_dst_k += 8 * 8;
p_src_k += 8;
i_k_itr -= 8;
}
if(i_k_itr & 4)
{
p_dst_k[0 * 8 + 0] = p_src_k[0 * GemmK + 0];
p_dst_k[0 * 8 + 1] = p_src_k[1 * GemmK + 0];
p_dst_k[0 * 8 + 2] = p_src_k[2 * GemmK + 0];
p_dst_k[0 * 8 + 3] = p_src_k[3 * GemmK + 0];
p_dst_k[0 * 8 + 4] = p_src_k[4 * GemmK + 0];
p_dst_k[0 * 8 + 5] = p_src_k[5 * GemmK + 0];
p_dst_k[0 * 8 + 6] = p_src_k[6 * GemmK + 0];
p_dst_k[0 * 8 + 7] = p_src_k[7 * GemmK + 0];
p_dst_k[1 * 8 + 0] = p_src_k[0 * GemmK + 1];
p_dst_k[1 * 8 + 1] = p_src_k[1 * GemmK + 1];
p_dst_k[1 * 8 + 2] = p_src_k[2 * GemmK + 1];
p_dst_k[1 * 8 + 3] = p_src_k[3 * GemmK + 1];
p_dst_k[1 * 8 + 4] = p_src_k[4 * GemmK + 1];
p_dst_k[1 * 8 + 5] = p_src_k[5 * GemmK + 1];
p_dst_k[1 * 8 + 6] = p_src_k[6 * GemmK + 1];
p_dst_k[1 * 8 + 7] = p_src_k[7 * GemmK + 1];
p_dst_k[2 * 8 + 0] = p_src_k[0 * GemmK + 2];
p_dst_k[2 * 8 + 1] = p_src_k[1 * GemmK + 2];
p_dst_k[2 * 8 + 2] = p_src_k[2 * GemmK + 2];
p_dst_k[2 * 8 + 3] = p_src_k[3 * GemmK + 2];
p_dst_k[2 * 8 + 4] = p_src_k[4 * GemmK + 2];
p_dst_k[2 * 8 + 5] = p_src_k[5 * GemmK + 2];
p_dst_k[2 * 8 + 6] = p_src_k[6 * GemmK + 2];
p_dst_k[2 * 8 + 7] = p_src_k[7 * GemmK + 2];
p_dst_k[3 * 8 + 0] = p_src_k[0 * GemmK + 3];
p_dst_k[3 * 8 + 1] = p_src_k[1 * GemmK + 3];
p_dst_k[3 * 8 + 2] = p_src_k[2 * GemmK + 3];
p_dst_k[3 * 8 + 3] = p_src_k[3 * GemmK + 3];
p_dst_k[3 * 8 + 4] = p_src_k[4 * GemmK + 3];
p_dst_k[3 * 8 + 5] = p_src_k[5 * GemmK + 3];
p_dst_k[3 * 8 + 6] = p_src_k[6 * GemmK + 3];
p_dst_k[3 * 8 + 7] = p_src_k[7 * GemmK + 3];
p_dst_k += 4 * 8;
p_src_k += 4;
}
if(i_k_itr & 2)
{
p_dst_k[0 * 8 + 0] = p_src_k[0 * GemmK + 0];
p_dst_k[0 * 8 + 1] = p_src_k[1 * GemmK + 0];
p_dst_k[0 * 8 + 2] = p_src_k[2 * GemmK + 0];
p_dst_k[0 * 8 + 3] = p_src_k[3 * GemmK + 0];
p_dst_k[0 * 8 + 4] = p_src_k[4 * GemmK + 0];
p_dst_k[0 * 8 + 5] = p_src_k[5 * GemmK + 0];
p_dst_k[0 * 8 + 6] = p_src_k[6 * GemmK + 0];
p_dst_k[0 * 8 + 7] = p_src_k[7 * GemmK + 0];
p_dst_k[1 * 8 + 0] = p_src_k[0 * GemmK + 1];
p_dst_k[1 * 8 + 1] = p_src_k[1 * GemmK + 1];
p_dst_k[1 * 8 + 2] = p_src_k[2 * GemmK + 1];
p_dst_k[1 * 8 + 3] = p_src_k[3 * GemmK + 1];
p_dst_k[1 * 8 + 4] = p_src_k[4 * GemmK + 1];
p_dst_k[1 * 8 + 5] = p_src_k[5 * GemmK + 1];
p_dst_k[1 * 8 + 6] = p_src_k[6 * GemmK + 1];
p_dst_k[1 * 8 + 7] = p_src_k[7 * GemmK + 1];
p_dst_k += 2 * 8;
p_src_k += 2;
}
if(i_k_itr & 1)
{
p_dst_k[0 * 8 + 0] = p_src_k[0 * GemmK + 0];
p_dst_k[0 * 8 + 1] = p_src_k[1 * GemmK + 0];
p_dst_k[0 * 8 + 2] = p_src_k[2 * GemmK + 0];
p_dst_k[0 * 8 + 3] = p_src_k[3 * GemmK + 0];
p_dst_k[0 * 8 + 4] = p_src_k[4 * GemmK + 0];
p_dst_k[0 * 8 + 5] = p_src_k[5 * GemmK + 0];
p_dst_k[0 * 8 + 6] = p_src_k[6 * GemmK + 0];
p_dst_k[0 * 8 + 7] = p_src_k[7 * GemmK + 0];
}
}
else
{
const float* p_src_k = p_src;
float* p_dst_k = p_dst;
for(index_t i_sub_n = 0; i_sub_n < 8; i_sub_n++)
{
for(index_t i_sub_k = 0; i_sub_k < k_per_block; i_sub_k++)
{
ck::index_t i_current_n_itr = i_n_itr + i_sub_n + i_gemm_n;
float v =
i_current_n_itr < GemmN ? p_src_k[i_sub_n * GemmK + i_sub_k] : .0f;
p_dst_k[i_sub_k * 8 + i_sub_n] = v;
}
}
}
p_dst += 8 * k_per_block;
p_src += 8 * GemmK;
}
}
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& src_slice_origin_step_idx)
{
ck::index_t move_k = src_slice_origin_step_idx[Number<1>{}];
ck::index_t move_n0 = src_slice_origin_step_idx[Number<0>{}];
// i_gemm_k += move_k;
// printf("wei move:%d\n", move_k); fflush(stdout);
src_offset += move_k + move_n0 * GemmK * GemmN1;
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
void MoveDstSliceWindow(const DstDesc&, const Index&) {}
private:
const ElementwiseOperation element_op_;
ck::index_t i_gemm_n;
// ck::index_t i_gemm_k;
// ck::index_t GemmN0;
ck::index_t GemmN1;
ck::index_t GemmN;
ck::index_t GemmK;
intptr_t src_offset;
};
template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename ElementwiseOperation,
bool BypassTransfer,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
{
static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
using Index = MultiIndex<nDim>;
constexpr ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN(
const SrcDesc& src_desc,
const Index&,
const DstDesc& dst_desc,
const Index&,
const ElementwiseOperation& element_op)
: element_op_(element_op)
{
DstGemmM = dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
DstGemmN = dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
src_offset = 0;
dst_offset = 0;
}
void SetSrcSliceOrigin(const SrcDesc&, const Index& src_slice_origin_idx)
{
if constexpr(BypassTransfer)
{
auto i_src_gemm_m = src_slice_origin_idx[Number<0>{}];
auto i_src_gemm_n = src_slice_origin_idx[Number<1>{}];
src_offset = i_src_gemm_m * DstGemmN + i_src_gemm_n;
}
}
void SetDstSliceOrigin(const DstDesc&, const Index& dst_slice_origin_idx)
{
i_dst_gemm_m = dst_slice_origin_idx[Number<0>{}];
i_dst_gemm_n = dst_slice_origin_idx[Number<1>{}];
dst_offset = i_dst_gemm_m * DstGemmN + i_dst_gemm_n;
}
template <typename SrcBuffer, typename DstBuffer>
void
Run(const SrcDesc& src_desc, SrcBuffer& src_buf, const DstDesc& dst_desc, DstBuffer& dst_buf)
{
if constexpr(BypassTransfer)
{
src_buf.p_data_ = reinterpret_cast<float*>(dst_buf.p_data_) + src_offset;
}
else
{
const ck::index_t m_per_block =
src_desc.GetTransforms()[Number<0>{}]
.GetUpperLengths()[Number<0>{}]; // must be multiple of 8
const ck::index_t n_per_block =
src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
const ck::index_t current_n = ck::math::min(DstGemmN - i_dst_gemm_n, n_per_block);
const float* p_src = reinterpret_cast<float*>(src_buf.p_data_) + src_offset;
float* p_dst = reinterpret_cast<float*>(dst_buf.p_data_) + dst_offset;
ck::index_t i_m_itr = m_per_block;
// printf("xxxx %d, current_n:%d, DstGemmN:%d, n_per_block:%d\n",__LINE__, current_n,
// DstGemmN, n_per_block);fflush(stdout);
// standard 8-4-2-1 pattern
while(i_m_itr >= 8)
{
avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 2 * DstGemmN, p_src + 2 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 3 * DstGemmN, p_src + 3 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 4 * DstGemmN, p_src + 4 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 5 * DstGemmN, p_src + 5 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 6 * DstGemmN, p_src + 6 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 7 * DstGemmN, p_src + 7 * n_per_block, current_n);
i_m_itr -= 8;
p_dst += 8 * DstGemmN;
p_src += 8 * n_per_block;
}
if(i_m_itr & 4)
{
avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 2 * DstGemmN, p_src + 2 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 3 * DstGemmN, p_src + 3 * n_per_block, current_n);
p_dst += 4 * DstGemmN;
p_src += 4 * n_per_block;
}
if(i_m_itr & 2)
{
avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n);
p_dst += 2 * DstGemmN;
p_src += 2 * n_per_block;
}
if(i_m_itr & 1)
{
avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
}
// printf("xxxx %d\n",__LINE__);fflush(stdout);
}
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
void MoveSrcSliceWindow(const SrcDesc&, const Index&) {}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
void MoveDstSliceWindow(const DstDesc&, const Index&) {}
private:
const ElementwiseOperation element_op_;
ck::index_t i_dst_gemm_m;
ck::index_t i_dst_gemm_n;
ck::index_t DstGemmM;
ck::index_t DstGemmN;
intptr_t src_offset;
intptr_t dst_offset;
};
} // namespace cpu
} // namespace ck
#endif
......@@ -121,7 +121,11 @@ template <typename... Args, typename F>
float launch_and_time_cpu_kernel(F kernel, int nrepeat, Args... args)
{
WallTimer timer;
kernel(args...);
int nwarmup = 3;
for(int i = 0; i < nwarmup; i++)
kernel(args...);
timer.Start();
for(int i = 0; i < nrepeat; i++)
......
......@@ -19,7 +19,7 @@ using InLayout = ck::tensor_layout::gemm::RowMajor; /
using WeiLayout = ck::tensor_layout::gemm::ColumnMajor; // KYXC
static constexpr bool NonTemporalStore = false;
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
using PT = ck::tensor_operation::cpu::element_wise::PassThrough;
using ThreadwiseGemmAvx2_MxN_4x24_Dispatch =
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<InType,
WeiType,
......@@ -37,53 +37,37 @@ static constexpr auto ConvFwd1x1P0 =
static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;
static constexpr auto DefaultGemmKLoop =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop;
static constexpr auto GemmKLoopOverC =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC;
static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK;
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
// clang-format off
#define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf) \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>
// clang-format on
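// the macro above expands to eight instances per tile configuration:
// {ConvFwdDefault, ConvFwd1x1P0} x {GemmKLoopOverC, DefaultGemmKLoop} x {LoopOver_MNK, LoopOver_MKN}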
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances = std::tuple<
//#################################################################|InDataType|WeiDataType|OutDataType|AccDataType|InElementwiseOp|WeiElementwiseOp|OutElementwiseOp|ConvForwardSp|NumDimSpatial|MPerBlock|NPerBlock|KPerBlock|ThreadwiseGemm_Dispatch
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
float,
float,
float,
float,
PassThrough,
PassThrough,
PassThrough,
ConvFwdDefault,
2,
256,
128,
64,
ThreadwiseGemmAvx2_MxN_4x24_Dispatch>,
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
float,
float,
float,
float,
PassThrough,
PassThrough,
PassThrough,
ConvFwdDefault,
2,
512,
256,
128,
ThreadwiseGemmAvx2_MxN_4x24_Dispatch>,
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
float,
float,
float,
float,
PassThrough,
PassThrough,
PassThrough,
ConvFwdDefault,
2,
1024,
144,
128,
ThreadwiseGemmAvx2_MxN_4x24_Dispatch>>;
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 120, 64, 4, 24, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 144, 128, 4, 24, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true, true, false),
// DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 192, 128, 4, 24, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 288, 128, 4, 24, true, true, false)>;
// clang-format on
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances)
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances{});
......
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "tensor_layout.hpp"
#include "device_tensor.hpp"
#include "device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation_cpu.hpp"
#include "reference_conv_fwd.hpp"
#include "element_wise_operation_cpu.hpp"
#include "dynamic_buffer_cpu.hpp"
#define AVX2_DATA_ALIGNMENT 32
namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
namespace device_conv2d_fwd_avx2_instance {
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
} // namespace device_conv2d_fwd_avx2_instance
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
using InElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
template <typename T>
static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
{
float max_diff = 1e-6;
for(int i = 0; i < ref.mData.size(); ++i)
{
float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
if(max_diff < diff)
{
return false;
}
}
return true;
}
int main(int argc, char* argv[])
{
int data_type = 0;
int init_method = 0;
// Conv shape
ck::index_t N = 128;
ck::index_t K = 256;
ck::index_t C = 192;
ck::index_t Y = 3;
ck::index_t X = 3;
ck::index_t Hi = 71;
ck::index_t Wi = 71;
ck::index_t conv_stride_h = 2;
ck::index_t conv_stride_w = 2;
ck::index_t conv_dilation_h = 1;
ck::index_t conv_dilation_w = 1;
ck::index_t in_left_pad_h = 1;
ck::index_t in_left_pad_w = 1;
ck::index_t in_right_pad_h = 1;
ck::index_t in_right_pad_w = 1;
if(argc == 1)
{
data_type = 1;
init_method = 1;
}
else if(argc == 3)
{
data_type = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
}
else if(argc == 18)
{
data_type = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
N = std::stoi(argv[3]);
K = std::stoi(argv[4]);
C = std::stoi(argv[5]);
Y = std::stoi(argv[6]);
X = std::stoi(argv[7]);
Hi = std::stoi(argv[8]);
Wi = std::stoi(argv[9]);
conv_stride_h = std::stoi(argv[10]);
conv_stride_w = std::stoi(argv[11]);
conv_dilation_h = std::stoi(argv[12]);
conv_dilation_w = std::stoi(argv[13]);
in_left_pad_h = std::stoi(argv[14]);
in_left_pad_w = std::stoi(argv[15]);
in_right_pad_h = std::stoi(argv[16]);
in_right_pad_w = std::stoi(argv[17]);
}
else
{
printf("arg1: data type (0=fp32, 1=fp16)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3 to 17: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx\n");
exit(1);
}
auto Run = [&](auto input_type, auto wei_type, auto out_type, auto acc_type) {
using InDataType = decltype(input_type);
using WeiDataType = decltype(wei_type);
using OutDataType = decltype(out_type);
using AccDataType = decltype(acc_type);
using ReferenceConvBwdInstance =
ck::tensor_operation::host::ReferenceConvBwdData<InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementOp,
WeiElementOp,
OutElementOp>;
const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1;
const ck::index_t XEff = (X - 1) * conv_dilation_w + 1;
const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const std::vector<ck::index_t> input_spatial_lengths{{Hi, Wi}};
const std::vector<ck::index_t> filter_spatial_lengths{{Y, X}};
const std::vector<ck::index_t> output_spatial_lengths{{Ho, Wo}};
const std::vector<ck::index_t> conv_filter_strides{{conv_stride_h, conv_stride_w}};
const std::vector<ck::index_t> conv_filter_dilations{{conv_dilation_h, conv_dilation_w}};
const std::vector<ck::index_t> input_left_pads{{in_left_pad_h, in_left_pad_w}};
const std::vector<ck::index_t> input_right_pads{{in_right_pad_h, in_right_pad_w}};
auto f_host_tensor_descriptor = [](std::size_t N_,
std::size_t C_,
std::size_t H_,
std::size_t W_) {
return HostTensorDescriptor(std::vector<std::size_t>({N_, C_, H_, W_}),
std::vector<std::size_t>({C_ * H_ * W_, 1, W_ * C_, C_}));
};
Tensor<OutDataType> out_n_ho_wo_k(f_host_tensor_descriptor(N, K, Ho, Wo));
Tensor<WeiDataType> wei_k_y_x_c(f_host_tensor_descriptor(K, C, Y, X));
Tensor<InDataType> in_n_hi_wi_c_host_result(f_host_tensor_descriptor(N, C, Hi, Wi));
Tensor<InDataType> in_n_hi_wi_c_device_result(f_host_tensor_descriptor(N, C, Hi, Wi));
std::cout << "in (N, C, Hi, Wi): " << in_n_hi_wi_c_host_result.mDesc << std::endl;
std::cout << "wei(K, C, Y, X): " << wei_k_y_x_c.mDesc << std::endl;
std::cout << "out(N, K, Ho, Wo): " << out_n_ho_wo_k.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
out_n_ho_wo_k.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
wei_k_y_x_c.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
break;
case 2:
out_n_ho_wo_k.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
wei_k_y_x_c.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
break;
default:
out_n_ho_wo_k.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
wei_k_y_x_c.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
}
DeviceAlignedMemCPU in_device_buf(sizeof(InDataType) *
in_n_hi_wi_c_device_result.mDesc.GetElementSpace(),
AVX2_DATA_ALIGNMENT);
DeviceAlignedMemCPU wei_device_buf(
sizeof(WeiDataType) * wei_k_y_x_c.mDesc.GetElementSpace(), AVX2_DATA_ALIGNMENT);
DeviceAlignedMemCPU out_device_buf(
sizeof(OutDataType) * out_n_ho_wo_k.mDesc.GetElementSpace(), AVX2_DATA_ALIGNMENT);
out_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
wei_device_buf.ToDevice(wei_k_y_x_c.mData.data());
// reset input to zero
in_n_hi_wi_c_device_result.GenerateTensorValue(GeneratorTensor_1<InDataType>{0});
in_device_buf.ToDevice(in_n_hi_wi_c_device_result.mData.data());
// get host result
{
auto ref_conv = ReferenceConvBwdInstance{};
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in_n_hi_wi_c_host_result,
wei_k_y_x_c,
out_n_ho_wo_k,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
InElementOp{},
WeiElementOp{},
OutElementOp{});
ref_invoker.Run(ref_argument);
}
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
using DeviceConvFwdNoOpPtr =
ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>;
// add device Conv instances
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, float> &&
ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, float>)
{
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(conv_ptrs);
}
if(conv_ptrs.size() <= 0)
{
throw std::runtime_error("wrong! no device Conv instance found");
}
// profile device Conv instances
bool success = true;
for(auto& conv_ptr : conv_ptrs)
{
auto argument_ptr = conv_ptr->MakeArgumentPointer(
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
InElementOp{},
WeiElementOp{},
OutElementOp{});
if(conv_ptr->IsSupportedArgument(argument_ptr.get()))
{
auto invoker_ptr = conv_ptr->MakeInvokerPointer();
invoker_ptr->Run(argument_ptr.get(), 1);
in_device_buf.FromDevice(in_n_hi_wi_c_device_result.mData.data());
if(!check_out(in_n_hi_wi_c_host_result, in_n_hi_wi_c_device_result))
{
std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl;
success = false;
}
else
{
std::cout << "Pass Info: " << conv_ptr->GetTypeString() << std::endl;
}
}
else
{
std::cout << "Not support Info: " << conv_ptr->GetTypeString() << std::endl;
}
}
if(success)
{
std::cout << "test conv2d fwd cpu : Pass" << std::endl;
return 0;
}
else
{
std::cout << "test conv2d fwd cpu: Fail " << std::endl;
return -1;
}
};
if(data_type == 0)
{
return Run(F32(), F32(), F32(), F32());
}
else if(data_type == 1)
{
return Run(F16(), F16(), F16(), F32());
}
else
{
return 1;
}
}
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "tensor_layout.hpp"
#include "device_tensor.hpp"
#include "device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation_cpu.hpp"
#include "reference_conv_fwd.hpp"
#include "element_wise_operation_cpu.hpp"
#include "dynamic_buffer_cpu.hpp"
#include <omp.h>
#define AVX2_DATA_ALIGNMENT 32
using F32 = float;
using F16 = ck::half_t;
namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
namespace device_conv2d_fwd_avx2_instance {
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
} // namespace device_conv2d_fwd_avx2_instance
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
using InElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
template <typename T>
static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
{
int error_count = 0;
float max_diff = 1e-6;
for(int i = 0; i < ref.mData.size(); ++i)
{
float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
if(max_diff < diff)
{
error_count++;
printf("idx:%3d, ref:%f, res:%f (diff:%f)\n",
i,
double(ref.mData[i]),
double(result.mData[i]),
diff);
}
}
return error_count == 0;
}
float calculate_gflops() { return 0.0f; } // placeholder, currently unused; return 0 to avoid falling off a non-void function
int main(int argc, char* argv[])
{
int data_type = 0;
int init_method = 0;
// Conv shape
ck::index_t N = 2;
ck::index_t K = 256;
ck::index_t C = 192;
ck::index_t Y = 3;
ck::index_t X = 3;
ck::index_t Hi = 71;
ck::index_t Wi = 71;
ck::index_t conv_stride_h = 1;
ck::index_t conv_stride_w = 1;
ck::index_t conv_dilation_h = 1;
ck::index_t conv_dilation_w = 1;
ck::index_t in_left_pad_h = 1;
ck::index_t in_left_pad_w = 1;
ck::index_t in_right_pad_h = 1;
ck::index_t in_right_pad_w = 1;
if(argc == 1)
{
data_type = 0;
init_method = 1;
}
else if(argc == 3)
{
data_type = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
}
else if(argc == 18)
{
data_type = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
N = std::stoi(argv[3]);
K = std::stoi(argv[4]);
C = std::stoi(argv[5]);
Y = std::stoi(argv[6]);
X = std::stoi(argv[7]);
Hi = std::stoi(argv[8]);
Wi = std::stoi(argv[9]);
conv_stride_h = std::stoi(argv[10]);
conv_stride_w = std::stoi(argv[11]);
conv_dilation_h = std::stoi(argv[12]);
conv_dilation_w = std::stoi(argv[13]);
in_left_pad_h = std::stoi(argv[14]);
in_left_pad_w = std::stoi(argv[15]);
in_right_pad_h = std::stoi(argv[16]);
in_right_pad_w = std::stoi(argv[17]);
}
else
{
printf("arg1: data type (0=fp32, 1=fp16)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3 to 17: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx\n");
exit(1);
}
auto Run = [&](auto input_type, auto wei_type, auto out_type) {
using InDataType = decltype(input_type);
using WeiDataType = decltype(wei_type);
using OutDataType = decltype(out_type);
using ReferenceConvFwdInstance = ck::tensor_operation::host::ReferenceConvFwd<InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>;
const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1;
const ck::index_t XEff = (X - 1) * conv_dilation_w + 1;
const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const std::vector<ck::index_t> input_spatial_lengths{{Hi, Wi}};
const std::vector<ck::index_t> filter_spatial_lengths{{Y, X}};
const std::vector<ck::index_t> output_spatial_lengths{{Ho, Wo}};
const std::vector<ck::index_t> conv_filter_strides{{conv_stride_h, conv_stride_w}};
const std::vector<ck::index_t> conv_filter_dilations{{conv_dilation_h, conv_dilation_w}};
const std::vector<ck::index_t> input_left_pads{{in_left_pad_h, in_left_pad_w}};
const std::vector<ck::index_t> input_right_pads{{in_right_pad_h, in_right_pad_w}};
auto f_host_tensor_descriptor = [](std::size_t N_,
std::size_t C_,
std::size_t H_,
std::size_t W_) {
return HostTensorDescriptor(std::vector<std::size_t>({N_, C_, H_, W_}),
std::vector<std::size_t>({C_ * H_ * W_, 1, W_ * C_, C_}));
};
Tensor<InDataType> in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi));
Tensor<WeiDataType> wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X));
Tensor<OutDataType> out_n_k_ho_wo_host_result(f_host_tensor_descriptor(N, K, Ho, Wo));
Tensor<OutDataType> out_n_k_ho_wo_device_result(f_host_tensor_descriptor(N, K, Ho, Wo));
std::cout << "in (N, C, Hi, Wi): " << in_n_c_hi_wi.mDesc << std::endl;
std::cout << "wei(K, C, Y, X): " << wei_k_c_y_x.mDesc << std::endl;
std::cout << "out(N, K, Ho, Wo): " << out_n_k_ho_wo_host_result.mDesc << std::endl;
std::cout << "LPad(H, W):" << in_left_pad_h << "," << in_left_pad_w
<< ", RPad(H, W):" << in_right_pad_h << "," << in_right_pad_w
<< ", Stride(H, W):" << conv_stride_h << ", " << conv_stride_w
<< ", Dilation(H, W):" << conv_dilation_h << ", " << conv_dilation_w
<< ", Threads:" << omp_get_max_threads() << std::endl;
switch(init_method)
{
case 0: break;
case 1:
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
// in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{});
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
// wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{});
break;
case 2:
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{});
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{});
break;
case 3:
#define PACK_32(v24, v16, v8, v0) \
(((v24 & 0xff) << 24) | ((v16 & 0xff) << 16) | ((v8 & 0xff) << 8) | ((v0 & 0xff) << 0))
for(auto i_n = 0; i_n < N; i_n++)
{
for(auto i_c = 0; i_c < C; i_c++)
{
for(auto i_hi = 0; i_hi < Hi; i_hi++)
{
for(auto i_wi = 0; i_wi < Wi; i_wi++)
{
uint32_t v = PACK_32(i_n, i_c, i_hi, i_wi);
in_n_c_hi_wi(i_n, i_c, i_hi, i_wi) = *reinterpret_cast<float*>(&v);
}
}
}
}
for(auto i_k = 0; i_k < K; i_k++)
{
for(auto i_c = 0; i_c < C; i_c++)
{
for(auto i_y = 0; i_y < Y; i_y++)
{
for(auto i_x = 0; i_x < X; i_x++)
{
uint32_t v = PACK_32(i_k, i_c, i_y, i_x);
wei_k_c_y_x(i_k, i_c, i_y, i_x) = *reinterpret_cast<float*>(&v);
}
}
}
}
break;
default:
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{0, 1});
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-1, 1});
}
DeviceAlignedMemCPU in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace(),
AVX2_DATA_ALIGNMENT);
DeviceAlignedMemCPU wei_device_buf(
sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace(), AVX2_DATA_ALIGNMENT);
DeviceAlignedMemCPU out_device_buf(sizeof(OutDataType) *
out_n_k_ho_wo_host_result.mDesc.GetElementSpace(),
AVX2_DATA_ALIGNMENT);
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
// get host result
{
auto ref_conv = ReferenceConvFwdInstance{};
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi,
wei_k_c_y_x,
out_n_k_ho_wo_host_result,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
InElementOp{},
WeiElementOp{},
OutElementOp{});
ref_invoker.Run(ref_argument);
}
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
using DeviceConvFwdNoOpPtr = ck::tensor_operation::cpu::device::
DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>;
// add device Conv instances
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, float> &&
ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, float>)
{
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(conv_ptrs);
}
if(conv_ptrs.size() <= 0)
{
throw std::runtime_error("wrong! no device Conv instance found");
}
// profile device Conv instances
bool success = true;
double fastest_kernel_time = std::numeric_limits<double>::max();
std::string fastest_kernel_name = "";
double fastest_kernel_gflops = 0;
for(auto& conv_ptr : conv_ptrs)
{
auto argument_ptr = conv_ptr->MakeArgumentPointer(
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
InElementOp{},
WeiElementOp{},
OutElementOp{});
if(conv_ptr->IsSupportedArgument(argument_ptr.get()))
{
auto invoker_ptr = conv_ptr->MakeInvokerPointer();
double time = invoker_ptr->Run(argument_ptr.get(), 10);
double total_flop = static_cast<double>(2) * N * C * Ho * Wo * K * Y * X;
double gflops = (total_flop * 1e-6) / time;
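// time is reported in milliseconds, so GFLOPS = total_flop / (time * 1e6) = total_flop * 1e-6 / time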
out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data());
if(!check_out(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result))
{
std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl;
success = false;
}
else
{
std::cout << "Pass Info: " << conv_ptr->GetTypeString() << ", Time:" << time
<< "ms, Gflops:" << gflops << std::endl;
if(time < fastest_kernel_time)
{
fastest_kernel_time = time;
fastest_kernel_name = conv_ptr->GetTypeString();
fastest_kernel_gflops = gflops;
}
}
}
else
{
std::cout << "Not support Info: " << conv_ptr->GetTypeString() << std::endl;
}
}
if(fastest_kernel_time != std::numeric_limits<double>::max())
{
std::cout << " fastest:" << fastest_kernel_name << ", time:" << fastest_kernel_time
<< "ms, Gflops:" << fastest_kernel_gflops << std::endl;
}
return 0;
// if(success)
// {
// std::cout << "test conv2d fwd cpu : Pass" << std::endl;
// return 0;
// }
// else
// {
// std::cout << "test conv2d fwd cpu: Fail " << std::endl;
// return -1;
// }
};
if(data_type == 0)
{
return Run(F32(), F32(), F32());
}
else
{
return 1;
}
}
......@@ -226,6 +226,8 @@ int main(int argc, char** argv)
static constexpr ck::index_t nDim =
ck::remove_reference_t<decltype(input_desc)>::GetNumOfDimension();
input_desc.Print();
auto threadwise_transfer = threadwise_transfer_t{input_desc,
ck::make_zero_multi_index<nDim>(),
input_cblock_desc,
......
......@@ -313,14 +313,15 @@ void test_ukernel(ukenrel_t uk,
float* private_c = mat_c + tid * m * n;
ck::cpu::ThreadwiseGemmParam param;
param.p_a = mat_a;
param.p_b = mat_b;
param.p_c = private_c;
param.Kr = k;
param.lda = (std::is_same<Row, ALayout>::value ? k : m) * sizeof(FloatA);
param.ldb = (std::is_same<Row, BLayout>::value ? n : k * 8) * sizeof(FloatB);
param.ldc = n * sizeof(float);
param.alpha = alpha;
param.p_a = mat_a;
param.p_b = mat_b;
param.p_c = private_c;
param.Kr = k;
param.lda = (std::is_same<Row, ALayout>::value ? k : m) * sizeof(FloatA);
param.ldb = (std::is_same<Row, BLayout>::value ? n : k * 8) * sizeof(FloatB);
param.ldc = n * sizeof(float);
param.alpha = alpha;
param.accmulate_c = 0;
memset(private_c, 0, m * n * sizeof(float));
......