"docs/vscode:/vscode.git/clone" did not exist on "d9f71ab3c3cc162226ec1c9945fef1a5faf4c512"
Commit ad09ebdb authored by carlushuang's avatar carlushuang
Browse files

add kyxck8

parent d6d37ea9
......@@ -484,8 +484,10 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
while(i_m_itr > 0)
{
if((*reinterpret_cast<uint32_t*>(&i_hi_itr) < Hi) &&
(*reinterpret_cast<uint32_t*>(&i_wi_itr) < Wi))
if((*reinterpret_cast<uint32_t*>(&i_hi_itr) <
*reinterpret_cast<uint32_t*>(&Hi)) &&
(*reinterpret_cast<uint32_t*>(&i_wi_itr) <
*reinterpret_cast<uint32_t*>(&Wi)))
avx2_util::memcpy32_avx2(p_dst, p_src, k_per_block, element_op_);
else
avx2_util::memset32_avx2(p_dst, 0, k_per_block);
......@@ -543,8 +545,10 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
// printf("current_k_block_along_c:%d, i_c_itr_k:%d, k_per_block:%d\n",
// current_k_block_along_c, i_c_itr_k,k_per_block); fflush(stdout);
if((*reinterpret_cast<uint32_t*>(&i_hi_itr_k) < Hi) &&
(*reinterpret_cast<uint32_t*>(&i_wi_itr_k) < Wi))
if((*reinterpret_cast<uint32_t*>(&i_hi_itr_k) <
*reinterpret_cast<uint32_t*>(&Hi)) &&
(*reinterpret_cast<uint32_t*>(&i_wi_itr_k) <
*reinterpret_cast<uint32_t*>(&Wi)))
avx2_util::memcpy32_avx2(
p_dst_k, p_src_k, current_k_block_along_c, element_op_);
else
......@@ -715,7 +719,7 @@ template <typename SrcData,
bool BypassTransfer,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC
struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXC
{
static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
using Index = MultiIndex<nDim>;
......@@ -723,7 +727,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC
// using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
// using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
constexpr ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC(
constexpr ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXC(
const SrcDesc& src_desc,
const Index& src_slice_origin,
const DstDesc& dst_desc,
......@@ -927,6 +931,190 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC
intptr_t src_offset;
};
template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename ElementwiseOperation,
bool BypassTransfer,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8
{
static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
using Index = MultiIndex<nDim>;
constexpr ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8(
const SrcDesc& src_desc,
const Index& src_slice_origin,
const DstDesc& dst_desc,
const Index& dst_slice_origin,
const ElementwiseOperation& element_op)
: element_op_(element_op)
{
GemmN1 =
src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}]; // Need to be 8
GemmN = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
GemmK = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
void SetSrcSliceOrigin(const SrcDesc&, const Index& src_slice_origin_idx)
{
ck::index_t idx_n0 = src_slice_origin_idx[Number<0>{}];
ck::index_t idx_k = src_slice_origin_idx[Number<1>{}];
ck::index_t idx_n1 = src_slice_origin_idx[Number<2>{}];
src_offset = idx_n0 * GemmK * GemmN1 + idx_k * GemmN1 + idx_n1;
// printf("xxxx i_gemm_n:%d, i_gemm_k:%d, src_offset:%d\n", i_gemm_n, i_gemm_k,
// src_offset);
}
void SetDstSliceOrigin(const DstDesc&, const Index&) {}
template <typename SrcBuffer, typename DstBuffer, typename SliceLengths>
void RunRead(const SrcDesc&,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf,
const SliceLengths& slice_length)
{
if constexpr(BypassTransfer) {}
else
{
const ck::index_t n0_per_block = slice_length[Number<0>{}];
const ck::index_t k_n1_per_block =
slice_length[Number<1>{}] * slice_length[Number<2>{}];
const ck::index_t SrcStride_K_N1 = GemmK * slice_length[Number<2>{}];
// printf(" >>>> %d, %d, %d -> %d(%dx%d), %d\n", GemmN, GemmK, GemmN1, n_per_block,
// dst_desc.GetTransforms()[Number<0>{}]
// .GetUpperLengths()[Number<0>{}],
// dst_desc.GetTransforms()[Number<0>{}]
// .GetUpperLengths()[Number<2>{}],
// k_per_block);
const float* p_src = reinterpret_cast<const float*>(src_buf.p_data_) + src_offset;
float* p_dst = reinterpret_cast<float*>(dst_buf.p_data_);
// n0 * k * n1
index_t i_n0_itr = n0_per_block;
while(i_n0_itr >= 8)
{
avx2_util::memcpy32_avx2(p_dst + 0 * k_n1_per_block,
p_src + 0 * SrcStride_K_N1,
k_n1_per_block,
element_op_);
avx2_util::memcpy32_avx2(p_dst + 1 * k_n1_per_block,
p_src + 1 * SrcStride_K_N1,
k_n1_per_block,
element_op_);
avx2_util::memcpy32_avx2(p_dst + 2 * k_n1_per_block,
p_src + 2 * SrcStride_K_N1,
k_n1_per_block,
element_op_);
avx2_util::memcpy32_avx2(p_dst + 3 * k_n1_per_block,
p_src + 3 * SrcStride_K_N1,
k_n1_per_block,
element_op_);
avx2_util::memcpy32_avx2(p_dst + 4 * k_n1_per_block,
p_src + 4 * SrcStride_K_N1,
k_n1_per_block,
element_op_);
avx2_util::memcpy32_avx2(p_dst + 5 * k_n1_per_block,
p_src + 5 * SrcStride_K_N1,
k_n1_per_block,
element_op_);
avx2_util::memcpy32_avx2(p_dst + 6 * k_n1_per_block,
p_src + 6 * SrcStride_K_N1,
k_n1_per_block,
element_op_);
avx2_util::memcpy32_avx2(p_dst + 7 * k_n1_per_block,
p_src + 7 * SrcStride_K_N1,
k_n1_per_block,
element_op_);
i_n0_itr -= 8;
p_dst += 8 * k_n1_per_block;
p_src += 8 * SrcStride_K_N1;
}
if(i_n0_itr & 4)
{
avx2_util::memcpy32_avx2(p_dst + 0 * k_n1_per_block,
p_src + 0 * SrcStride_K_N1,
k_n1_per_block,
element_op_);
avx2_util::memcpy32_avx2(p_dst + 1 * k_n1_per_block,
p_src + 1 * SrcStride_K_N1,
k_n1_per_block,
element_op_);
avx2_util::memcpy32_avx2(p_dst + 2 * k_n1_per_block,
p_src + 2 * SrcStride_K_N1,
k_n1_per_block,
element_op_);
avx2_util::memcpy32_avx2(p_dst + 3 * k_n1_per_block,
p_src + 3 * SrcStride_K_N1,
k_n1_per_block,
element_op_);
p_dst += 4 * k_n1_per_block;
p_src += 4 * SrcStride_K_N1;
}
if(i_n0_itr & 2)
{
avx2_util::memcpy32_avx2(p_dst + 0 * k_n1_per_block,
p_src + 0 * SrcStride_K_N1,
k_n1_per_block,
element_op_);
avx2_util::memcpy32_avx2(p_dst + 1 * k_n1_per_block,
p_src + 1 * SrcStride_K_N1,
k_n1_per_block,
element_op_);
p_dst += 2 * k_n1_per_block;
p_src += 2 * SrcStride_K_N1;
}
if(i_n0_itr & 1)
{
avx2_util::memcpy32_avx2(p_dst + 0 * k_n1_per_block,
p_src + 0 * SrcStride_K_N1,
k_n1_per_block,
element_op_);
}
}
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& src_slice_origin_step_idx)
{
ck::index_t move_n0 = src_slice_origin_step_idx[Number<0>{}];
ck::index_t move_k = src_slice_origin_step_idx[Number<1>{}];
ck::index_t move_n1 = src_slice_origin_step_idx[Number<2>{}];
// i_gemm_k += move_k;
// printf("wei move:%d\n", move_k); fflush(stdout);
src_offset += move_n0 * GemmK * GemmN1 + move_k * GemmN1 + move_n1;
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
void MoveDstSliceWindow(const DstDesc&, const Index&) {}
private:
const ElementwiseOperation element_op_;
ck::index_t i_gemm_n;
// ck::index_t i_gemm_k;
// ck::index_t GemmN0;
ck::index_t GemmN1;
ck::index_t GemmN;
ck::index_t GemmK;
intptr_t src_offset;
};
template <typename SrcData,
typename DstData,
typename SrcDesc,
......
# device_conv2d_fwd_cpu_instance
set(DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE
device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp
device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_instance.cpp
)
add_library(device_conv2d_fwd_cpu_instance SHARED ${DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE})
target_compile_features(device_conv2d_fwd_cpu_instance PUBLIC)
......
#include <stdlib.h>
#include "convolution_forward_specialization_cpu.hpp"
#include "config.hpp"
#include "device_convnd_fwd_avx2_nhwc_kyxck8_nhwk.hpp"
#include "element_wise_operation_cpu.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
namespace device_conv2d_fwd_avx2_instance {
using InType = float;
using WeiType = float;
using OutType = float;
using AccType = float;
using InLayout = ck::tensor_layout::gemm::RowMajor; // NHWC
using WeiLayout = ck::tensor_layout::gemm::ColumnMajor; // KYXCK8
static constexpr bool NonTemporalStore = false;
using PT = ck::tensor_operation::cpu::element_wise::PassThrough;
using Relu = ck::tensor_operation::cpu::element_wise::Relu;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Default;
static constexpr auto ConvFwd1x1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;
static constexpr auto DefaultGemmKLoop =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop;
static constexpr auto GemmKLoopOverC =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC;
static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK;
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
// clang-format off
#define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf) \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>
// clang-format on
using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true, true, false)>;
// clang-format on
// use this in single thread, but gemm_n is not multiple of 8
using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true, true, true)>;
// clang-format on
// use this in multi thread environment (need local C buffer to avoid cache coherence, although some
// time no local c is better...)
using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 48, 24, 128, 4, 24, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 72, 16, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 72, 32, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 96, 32, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 96, 64, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 120, 32, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 120, 64, 128, 6, 16, true, true, true),
// DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true, true, true)>;
// clang-format on
using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_relu_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, false)>;
// clang-format on
// use this in single thread, but gemm_n is not multiple of 8
using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_relu_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, true)>;
// clang-format on
// use this in multi thread environment (need local C buffer to avoid cache coherence, although some
// time no local c is better...)
using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_relu_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 48, 24, 128, 4, 24, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 72, 16, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 72, 32, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 96, 32, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 96, 64, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 120, 32, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 120, 64, 128, 6, 16, true, true, true),
// DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, true)>;
// clang-format on
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk(
std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c(
std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt(
std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_relu_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_relu_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_relu_instances{});
}
} // namespace device_conv2d_fwd_avx2_instance
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
......@@ -33,6 +33,24 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
} // namespace device_conv2d_fwd_avx2_instance
} // namespace device
} // namespace cpu
......
......@@ -16,7 +16,11 @@
#define TEST_FUSION_PASSTHROUGH 0
#define TEST_FUSION_RELU 1
#define TEST_FUSION TEST_FUSION_RELU
#define TEST_FUSION TEST_FUSION_PASSTHROUGH
#define TEST_LAYOUT_NHWC_KYXC_NHWK 0
#define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1
#define TEST_LAYOUT TEST_LAYOUT_NHWC_KYXCK8_NHWK
using F32 = float;
using F16 = ck::half_t;
......@@ -48,6 +52,24 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt_relu(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
} // namespace device_conv2d_fwd_avx2_instance
} // namespace device
} // namespace cpu
......@@ -115,6 +137,31 @@ check_out(const Tensor<T>& ref, const Tensor<T>& result, double nrms, int per_pi
float calculate_gflops() {}
template <typename T>
void transpose_kyxc_2_kyxc8k(Tensor<T>& dst,
const Tensor<T>& src,
ck::index_t K,
ck::index_t Y,
ck::index_t X,
ck::index_t C)
{
ck::index_t batch = K / 8;
ck::index_t row = 8;
ck::index_t col = C * Y * X;
for(auto i_b = 0; i_b < batch; i_b++)
{
for(auto i_r = 0; i_r < row; i_r++)
{
for(auto i_c = 0; i_c < col; i_c++)
{
ck::index_t src_idx = i_b * row * col + i_r * col + i_c;
ck::index_t dst_idx = i_b * col * row + i_c * row + i_r;
dst.mData[dst_idx] = src.mData[src_idx];
}
}
}
}
int main(int argc, char* argv[])
{
int data_type = 0;
......@@ -213,6 +260,10 @@ int main(int argc, char* argv[])
Tensor<InDataType> in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi));
Tensor<WeiDataType> wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X));
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
Tensor<WeiDataType> wei_k_c_y_x_k8(
f_host_tensor_descriptor(K, C, Y, X)); // TODO: This is only to hold data
#endif
Tensor<OutDataType> out_n_k_ho_wo_host_result(f_host_tensor_descriptor(N, K, Ho, Wo));
Tensor<OutDataType> out_n_k_ho_wo_device_result(f_host_tensor_descriptor(N, K, Ho, Wo));
......@@ -296,8 +347,13 @@ int main(int argc, char* argv[])
AVX2_DATA_ALIGNMENT);
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXC_NHWK
wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
transpose_kyxc_2_kyxc8k(wei_k_c_y_x_k8, wei_k_c_y_x, K, Y, X, C);
wei_device_buf.ToDevice(wei_k_c_y_x_k8.mData.data());
#endif
// get host result
{
auto ref_conv = ReferenceConvFwdInstance{};
......@@ -334,6 +390,7 @@ int main(int argc, char* argv[])
ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, float>)
{
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXC_NHWK
#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
if(omp_get_max_threads() > 1)
{
......@@ -369,6 +426,45 @@ int main(int argc, char* argv[])
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(conv_ptrs);
}
#endif
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
if(omp_get_max_threads() > 1)
{
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt(conv_ptrs);
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk(conv_ptrs);
}
else
{
if(K % 8 == 0)
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk(conv_ptrs);
else
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c(conv_ptrs);
}
#endif
#if TEST_FUSION == TEST_FUSION_RELU
if(omp_get_max_threads() > 1)
{
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt_relu(conv_ptrs);
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_relu(conv_ptrs);
}
else
{
if(K % 8 == 0)
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_relu(conv_ptrs);
else
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c_relu(conv_ptrs);
}
#endif
#endif
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment