"docs/vscode:/vscode.git/clone" did not exist on "68a35543d5ab91722babf1d26105a5c4eda46a41"
Commit 8ce9fe57 authored by carlushuang's avatar carlushuang
Browse files

remove useless comment, add several new config for multi thread

parent b8ba0239
...@@ -213,9 +213,6 @@ struct BlockwiseGemmAvx2_MxN ...@@ -213,9 +213,6 @@ struct BlockwiseGemmAvx2_MxN
auto current_mr = ck::math::min(m_per_block - i_m, m_per_thread); auto current_mr = ck::math::min(m_per_block - i_m, m_per_thread);
param.p_a = &a_block_buf.p_data_[GetABlockStartOffset(a_block_desc, i_m, 0)]; param.p_a = &a_block_buf.p_data_[GetABlockStartOffset(a_block_desc, i_m, 0)];
// printf("YYYY: %d, i_m:%d, current_mr:%d, %d, %p\n",__LINE__, i_m, current_mr,
// GetABlockStartOffset(a_block_desc, i_m, 0), param.p_a);fflush(stdout);
for(ck::index_t i_n = 0; i_n < n_per_block; i_n += n_per_thread) for(ck::index_t i_n = 0; i_n < n_per_block; i_n += n_per_thread)
{ {
auto current_nr = ck::math::min(n_per_block - i_n, n_per_thread); auto current_nr = ck::math::min(n_per_block - i_n, n_per_thread);
...@@ -223,11 +220,6 @@ struct BlockwiseGemmAvx2_MxN ...@@ -223,11 +220,6 @@ struct BlockwiseGemmAvx2_MxN
param.p_b = &b_block_buf.p_data_[GetBBlockStartOffset(b_block_desc, 0, i_n)]; param.p_b = &b_block_buf.p_data_[GetBBlockStartOffset(b_block_desc, 0, i_n)];
param.p_c = &c_buf.p_data_[GetCBlockStartOffset(c_desc, i_m, i_n)]; param.p_c = &c_buf.p_data_[GetCBlockStartOffset(c_desc, i_m, i_n)];
// printf("YYYY: %d, i_n:%d, current_nr:%d, %d, %p, C:%d, %p\n",__LINE__, i_n,
// current_nr, GetBBlockStartOffset(b_block_desc, 0, i_n), param.p_b,
// GetCBlockStartOffset(c_desc, i_m, i_n),
// param.p_c);fflush(stdout);
ThreadwiseGemm_Dispatch::Run(&param, current_mr, current_nr); ThreadwiseGemm_Dispatch::Run(&param, current_mr, current_nr);
} }
} }
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <utility> #include <utility>
#include <unistd.h> #include <unistd.h>
#include <omp.h> #include <omp.h>
#include <pthread.h>
namespace ck { namespace ck {
namespace cpu { namespace cpu {
...@@ -193,6 +194,23 @@ struct GridwiseGemmAvx2_MxN ...@@ -193,6 +194,23 @@ struct GridwiseGemmAvx2_MxN
int total_threads = omp_get_max_threads(); int total_threads = omp_get_max_threads();
#if 0
if(total_threads > 1){
#pragma omp parallel
{
int tid = omp_get_thread_num();
cpu_set_t set;
CPU_ZERO(&set);
CPU_SET(tid, &set);
if (sched_setaffinity(0, sizeof(set), &set) == -1) {
throw std::runtime_error("wrong! fail to set thread affinity");
}
}
}
#endif
// TODO: openmp aware ordering // TODO: openmp aware ordering
// //
if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 1, 2>>::value) if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 1, 2>>::value)
...@@ -234,7 +252,8 @@ struct GridwiseGemmAvx2_MxN ...@@ -234,7 +252,8 @@ struct GridwiseGemmAvx2_MxN
MemAlignmentByte); MemAlignmentByte);
DeviceAlignedMemCPU b_block_mem(k_per_block * n_per_block * sizeof(FloatB), DeviceAlignedMemCPU b_block_mem(k_per_block * n_per_block * sizeof(FloatB),
MemAlignmentByte); MemAlignmentByte);
DeviceAlignedMemCPU c_block_mem(m_per_block * n_per_block * sizeof(FloatC), DeviceAlignedMemCPU c_block_mem(
UseCLocalBuffer ? (m_per_block * n_per_block * sizeof(FloatC)) : 0,
MemAlignmentByte); MemAlignmentByte);
auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>( auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
...@@ -298,26 +317,9 @@ struct GridwiseGemmAvx2_MxN ...@@ -298,26 +317,9 @@ struct GridwiseGemmAvx2_MxN
auto a_block_desc = GetABlockDescriptor(mc_size, kc_size); auto a_block_desc = GetABlockDescriptor(mc_size, kc_size);
auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size); auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size);
// printf("[tid:%d]==> i_m:%d, i_n:%d, i_k:%d, mc:%d, nc:%d, kc:%d(%d,
// %d)\n", tid, i_mc,
// i_nc, i_kc, mc_size, nc_size, kc_size, KPerBlock, GemmK); fflush(stdout);
a_threadwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf); a_threadwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
b_threadwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf); b_threadwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
// for(auto i_elem = 0; i_elem < (mc_size * kc_size) ; i_elem++){
// printf("A ==> %3d : %f(0x%08x)\n", i_elem,
// (reinterpret_cast<float*>(a_block_buf.p_data_))[i_elem],
// (reinterpret_cast<uint32_t*>(a_block_buf.p_data_))[i_elem]);
//}
// for(auto i_elem = 0; i_elem < (kc_size * nc_size) ; i_elem++){
// printf("B ==> %3d : %f(0x%08x)\n", i_elem,
// (reinterpret_cast<float*>(b_block_buf.p_data_))[i_elem],
// (reinterpret_cast<uint32_t*>(b_block_buf.p_data_))[i_elem]);
// }
// printf("[%d] 2222 \n",__LINE__);
blockwise_gemm.Run(a_block_desc, blockwise_gemm.Run(a_block_desc,
a_block_buf, a_block_buf,
make_zero_multi_index<a_block_copy_dim>(), make_zero_multi_index<a_block_copy_dim>(),
...@@ -329,28 +331,13 @@ struct GridwiseGemmAvx2_MxN ...@@ -329,28 +331,13 @@ struct GridwiseGemmAvx2_MxN
make_zero_multi_index<2>(), make_zero_multi_index<2>(),
i_kc != 0); i_kc != 0);
// printf("[%d] 2222 \n",__LINE__);
if((i_kc + k_per_block) < GemmK) if((i_kc + k_per_block) < GemmK)
{ {
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc, a_move_k_step); a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc, a_move_k_step);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc, b_move_k_step); b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc, b_move_k_step);
} }
// printf("[%d] 2222 \n",__LINE__);
// for(auto i_elem = 0; i_elem < (10) ; i_elem++){
// printf("C ==> %3d : %f(0x%08x)\n", i_elem,
// (reinterpret_cast<float*>(c_block_buf.p_data_))[i_elem],
// (reinterpret_cast<uint32_t*>(c_block_buf.p_data_))[i_elem]);
// }
} }
// for(auto i_elem = 0; i_elem < (c_block_mem.mMemSize / sizeof(FloatC)) ;
// i_elem++){
// printf("C ==> %3d : %f(0x%08x)\n", i_elem,
// (reinterpret_cast<float*>(c_block_buf.p_data_))[i_elem],
// (reinterpret_cast<uint32_t*>(c_block_buf.p_data_))[i_elem]);
// }
if constexpr(UseCLocalBuffer) if constexpr(UseCLocalBuffer)
c_threadwise_copy.Run(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf); c_threadwise_copy.Run(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
} }
...@@ -396,7 +383,8 @@ struct GridwiseGemmAvx2_MxN ...@@ -396,7 +383,8 @@ struct GridwiseGemmAvx2_MxN
MemAlignmentByte); MemAlignmentByte);
DeviceAlignedMemCPU b_block_mem(k_per_block * n_per_block * sizeof(FloatB), DeviceAlignedMemCPU b_block_mem(k_per_block * n_per_block * sizeof(FloatB),
MemAlignmentByte); MemAlignmentByte);
DeviceAlignedMemCPU c_block_mem(m_per_block * n_per_block * sizeof(FloatC), DeviceAlignedMemCPU c_block_mem(
UseCLocalBuffer ? (m_per_block * n_per_block * sizeof(FloatC)) : 0,
MemAlignmentByte); MemAlignmentByte);
auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>( auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
......
...@@ -349,9 +349,6 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC ...@@ -349,9 +349,6 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
src_offset = i_n * Hi * Wi * C + i_hi * Wi * C + i_wi * C + i_c; src_offset = i_n * Hi * Wi * C + i_hi * Wi * C + i_wi * C + i_c;
i_gemm_k = idx_k; i_gemm_k = idx_k;
// printf("[%d] i_wo:%d, i_ho:%d, i_wi:%d, i_hi:%d, src_offset:%d\n",
// __LINE__, i_wo, i_ho, i_wi, i_hi, src_offset);
} }
} }
...@@ -447,7 +444,6 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC ...@@ -447,7 +444,6 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
if(i_ho_itr >= Ho) if(i_ho_itr >= Ho)
{ {
i_ho_itr = 0; i_ho_itr = 0;
// i_n++;
p_src += input_offset_ovf_hi_acc_n; p_src += input_offset_ovf_hi_acc_n;
} }
...@@ -468,26 +464,8 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC ...@@ -468,26 +464,8 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
ck::index_t i_wi_itr = i_wi; ck::index_t i_wi_itr = i_wi;
ck::index_t i_hi_itr = i_hi; ck::index_t i_hi_itr = i_hi;
// printf("[%d] i_m_itr:%d, i_wo_itr:%d, i_ho_itr:%d, i_wi_itr:%d, i_hi_itr:%d,
// src_offset:%d, input_offset_acc_wi:%d,
// input_offset_ovf_wi_acc_hi:%d,input_offset_ovf_hi_acc_n:%d, %p(%p)\n",
// __LINE__, i_m_itr, i_wo_itr, i_ho_itr, i_wi_itr, i_hi_itr,
// src_offset, input_offset_acc_wi, input_offset_ovf_wi_acc_hi,
// input_offset_ovf_hi_acc_n, src_buf.p_data_, p_src);
// printf("%p %p %p, %d, %x, %p\n",src_buf.p_data_, reinterpret_cast<const
// float*>(src_buf.p_data_) + 1, reinterpret_cast<const float*>(src_buf.p_data_)
// + ck::index_t(-1),
// sizeof(src_offset), *reinterpret_cast<uint32_t*>(&src_offset),
// reinterpret_cast<const float*>(src_buf.p_data_) + (-1088));
while(i_m_itr > 0) while(i_m_itr > 0)
{ {
// printf("[%d] i_m_itr:%d, i_wo_itr:%d, i_ho_itr:%d, i_wi_itr:%d,
// i_hi_itr:%d, src_offset:%d -> %p\n",
// __LINE__, i_m_itr, i_wo_itr, i_ho_itr, i_wi_itr, i_hi_itr, src_offset,
// p_src);
if((*reinterpret_cast<uint32_t*>(&i_hi_itr) < Hi) && if((*reinterpret_cast<uint32_t*>(&i_hi_itr) < Hi) &&
(*reinterpret_cast<uint32_t*>(&i_wi_itr) < Wi)) (*reinterpret_cast<uint32_t*>(&i_wi_itr) < Wi))
avx2_util::memcpy32_avx2(p_dst, p_src, k_per_block); avx2_util::memcpy32_avx2(p_dst, p_src, k_per_block);
...@@ -512,14 +490,11 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC ...@@ -512,14 +490,11 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
{ {
i_ho_itr = 0; i_ho_itr = 0;
i_hi_itr -= Ho * Sy; i_hi_itr -= Ho * Sy;
// i_n++;
p_src += input_offset_ovf_hi_acc_n; p_src += input_offset_ovf_hi_acc_n;
} }
i_m_itr--; i_m_itr--;
} }
// printf("[%d] \n", __LINE__);
} }
else else
{ {
...@@ -538,7 +513,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC ...@@ -538,7 +513,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
ck::index_t i_wi_itr_k = i_wi_itr; ck::index_t i_wi_itr_k = i_wi_itr;
ck::index_t i_hi_itr_k = i_hi_itr; ck::index_t i_hi_itr_k = i_hi_itr;
ck::index_t i_c_itr_k = i_c; ck::index_t i_c_itr_k = i_c;
ck::index_t i_y_itr_k = i_y; // ck::index_t i_y_itr_k = i_y;
ck::index_t i_x_itr_k = i_x; ck::index_t i_x_itr_k = i_x;
ck::index_t i_k_itr = k_per_block; ck::index_t i_k_itr = k_per_block;
...@@ -566,7 +541,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC ...@@ -566,7 +541,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
if(i_x_itr_k >= Fx) if(i_x_itr_k >= Fx)
{ {
i_x_itr_k = 0; i_x_itr_k = 0;
i_y_itr_k++; // i_y_itr_k++;
i_wi_itr_k -= Dx * Fx; i_wi_itr_k -= Dx * Fx;
i_hi_itr_k += Dy; i_hi_itr_k += Dy;
p_src_k += input_offset_ovf_x_acc_y; p_src_k += input_offset_ovf_x_acc_y;
...@@ -594,7 +569,6 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC ...@@ -594,7 +569,6 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
{ {
i_ho_itr = 0; i_ho_itr = 0;
i_hi_itr -= Ho * Sy; i_hi_itr -= Ho * Sy;
// i_n++;
p_src += input_offset_ovf_hi_acc_n; p_src += input_offset_ovf_hi_acc_n;
} }
...@@ -626,40 +600,27 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC ...@@ -626,40 +600,27 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
if constexpr(GemmKSpecialization == if constexpr(GemmKSpecialization ==
ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC) ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
{ {
// c % k_per_block == 0, so every time k_per_block here is the same
// ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
// iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
// printf("222222 C:%d, src_offset:%d, i_c:%d, i_x:%d\n", C, src_offset, i_c, i_x);
// fflush(stdout);
// TODO: branch seems weird // TODO: branch seems weird
i_c += move_k; i_c += move_k;
src_offset += move_k; src_offset += move_k;
// printf("3333[%d] src_offset:%d\n", __LINE__, src_offset);
if(i_c >= C) if(i_c >= C)
{ {
i_c = 0; i_c = 0;
i_x++; i_x++;
i_wi += Dx; i_wi += Dx;
src_offset += Dx * C - C; src_offset += Dx * C - C;
// printf("3333[%d] src_offset:%d\n", __LINE__, src_offset);
} }
if(i_x >= Fx) if(i_x >= Fx)
{ {
i_x = 0; i_x = 0;
i_y++; // i_y++;
i_wi = i_wi - Fx * Dx; i_wi = i_wi - Fx * Dx;
i_hi += Dy; i_hi += Dy;
src_offset += Dy * Wi * C - Fx * Dx * C; src_offset += Dy * Wi * C - Fx * Dx * C;
// printf("3333[%d] src_offset:%d\n", __LINE__, src_offset);
} }
// printf("inp move:%d, i_c:%d, i_hi:%d, i_wi:%d src_offset:%d\n", move_k, i_c,
// i_hi, i_wi, src_offset); fflush(stdout);
} }
else else
{ {
......
...@@ -28,6 +28,12 @@ DeviceMem::~DeviceMem() { hipGetErrorString(hipFree(mpDeviceBuf)); } ...@@ -28,6 +28,12 @@ DeviceMem::~DeviceMem() { hipGetErrorString(hipFree(mpDeviceBuf)); }
DeviceAlignedMemCPU::DeviceAlignedMemCPU(std::size_t mem_size, std::size_t alignment) DeviceAlignedMemCPU::DeviceAlignedMemCPU(std::size_t mem_size, std::size_t alignment)
: mMemSize(mem_size), mAlignment(alignment) : mMemSize(mem_size), mAlignment(alignment)
{ {
if(mem_size == 0)
{
mpDeviceBuf = nullptr;
}
else
{
assert(!(alignment == 0 || (alignment & (alignment - 1)))); // check pow of 2 assert(!(alignment == 0 || (alignment & (alignment - 1)))); // check pow of 2
void* p1; void* p1;
...@@ -39,6 +45,7 @@ DeviceAlignedMemCPU::DeviceAlignedMemCPU(std::size_t mem_size, std::size_t align ...@@ -39,6 +45,7 @@ DeviceAlignedMemCPU::DeviceAlignedMemCPU(std::size_t mem_size, std::size_t align
p2 = reinterpret_cast<void**>((reinterpret_cast<size_t>(p1) + offset) & ~(alignment - 1)); p2 = reinterpret_cast<void**>((reinterpret_cast<size_t>(p1) + offset) & ~(alignment - 1));
p2[-1] = p1; p2[-1] = p1;
mpDeviceBuf = reinterpret_cast<void*>(p2); mpDeviceBuf = reinterpret_cast<void*>(p2);
}
} }
void* DeviceAlignedMemCPU::GetDeviceBuffer() { return mpDeviceBuf; } void* DeviceAlignedMemCPU::GetDeviceBuffer() { return mpDeviceBuf; }
...@@ -51,7 +58,11 @@ void DeviceAlignedMemCPU::FromDevice(void* p) { memcpy(p, mpDeviceBuf, mMemSize) ...@@ -51,7 +58,11 @@ void DeviceAlignedMemCPU::FromDevice(void* p) { memcpy(p, mpDeviceBuf, mMemSize)
void DeviceAlignedMemCPU::SetZero() { memset(mpDeviceBuf, 0, mMemSize); } void DeviceAlignedMemCPU::SetZero() { memset(mpDeviceBuf, 0, mMemSize); }
DeviceAlignedMemCPU::~DeviceAlignedMemCPU() { free((reinterpret_cast<void**>(mpDeviceBuf))[-1]); } DeviceAlignedMemCPU::~DeviceAlignedMemCPU()
{
if(mpDeviceBuf != nullptr)
free((reinterpret_cast<void**>(mpDeviceBuf))[-1]);
}
struct KernelTimerImpl struct KernelTimerImpl
{ {
......
...@@ -55,30 +55,81 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver ...@@ -55,30 +55,81 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \ DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf> DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>
// clang-format on // clang-format on
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances = std::tuple< using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances = std::tuple<
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true, true, false),
// DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 144, 128, 4, 24, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true, true, false),
// DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 288, 128, 4, 24, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true, true, false), DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true, true, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true, true, false)>; DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true, true, false)>;
// clang-format on // clang-format on
// use these instances in a single-thread environment when gemm_n is not a multiple of 8
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true, true, true)>;
// clang-format on
// use these instances in a multi-thread environment (a local C buffer is needed to avoid cache
// coherence issues, although sometimes having no local C buffer is better...)
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 48, 24, 128, 4, 24, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 72, 16, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 72, 32, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 96, 32, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 96, 64, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 120, 32, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 120, 64, 128, 6, 16, true, true, true),
// DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true, true, true)>;
// clang-format on
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances) void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances{}); instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances{});
} }
// Hands a default-constructed device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_instances
// tuple (the configurations built with c_local_buf = true) to
// add_device_operation_instances together with the caller's `instances` vector.
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c(
    std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
    using instance_tuple = device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_instances;
    namespace ck_device  = ck::tensor_operation::device;

    ck_device::add_device_operation_instances(instances, instance_tuple{});
}
// Hands a default-constructed device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_instances
// tuple (the configuration list declared for multi-thread use) to
// add_device_operation_instances together with the caller's `instances` vector.
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
    std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
    using instance_tuple = device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_instances;
    namespace ck_device  = ck::tensor_operation::device;

    ck_device::add_device_operation_instances(instances, instance_tuple{});
}
} // namespace device_conv2d_fwd_avx2_instance } // namespace device_conv2d_fwd_avx2_instance
} // namespace device } // namespace device
} // namespace cpu } // namespace cpu
......
...@@ -18,6 +18,12 @@ namespace device_conv2d_fwd_avx2_instance { ...@@ -18,6 +18,12 @@ namespace device_conv2d_fwd_avx2_instance {
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk( void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances); std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
// Adds the AVX2 NHWC/KYXC/NHWK forward-convolution instances that are built with a
// local C buffer (c_local_buf = true) to `instances`.
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
// Adds the AVX2 NHWC/KYXC/NHWK forward-convolution instances intended for a
// multi-thread environment to `instances`.
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
} // namespace device_conv2d_fwd_avx2_instance } // namespace device_conv2d_fwd_avx2_instance
} // namespace device } // namespace device
} // namespace cpu } // namespace cpu
......
...@@ -26,6 +26,12 @@ using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough; ...@@ -26,6 +26,12 @@ using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk( void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances); std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
// Forward declaration: adds the instances built with a local C buffer to `instances`.
// main() selects this set when only one OpenMP thread is available and K % 8 != 0.
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
// Forward declaration: adds the multi-thread instance set to `instances`.
// main() selects this set (followed by the default set) when omp_get_max_threads() > 1.
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
} // namespace device_conv2d_fwd_avx2_instance } // namespace device_conv2d_fwd_avx2_instance
} // namespace device } // namespace device
} // namespace cpu } // namespace cpu
...@@ -300,8 +306,22 @@ int main(int argc, char* argv[]) ...@@ -300,8 +306,22 @@ int main(int argc, char* argv[])
ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> && ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, float>) ck::is_same_v<ck::remove_cv_t<OutDataType>, float>)
{ {
if(omp_get_max_threads() > 1)
{
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(conv_ptrs);
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(conv_ptrs);
}
else
{
if(K % 8 == 0)
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance:: ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(conv_ptrs); add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(conv_ptrs);
else
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c(conv_ptrs);
}
} }
if(conv_ptrs.size() <= 0) if(conv_ptrs.size() <= 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment