Commit f9cf57d4 authored by carlushuang's avatar carlushuang
Browse files

support YXCK filter

parent 71254ddd
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
#define TEST_LAYOUT_NHWC_KYXC_NHWK 0 #define TEST_LAYOUT_NHWC_KYXC_NHWK 0
#define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1 #define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1
#define TEST_LAYOUT TEST_LAYOUT_NHWC_KYXCK8_NHWK #define TEST_LAYOUT_NHWC_YXCK_NHWK 1
#define TEST_LAYOUT TEST_LAYOUT_NHWC_KYXC_NHWK
using F32 = float; using F32 = float;
using F16 = ck::half_t; using F16 = ck::half_t;
...@@ -30,6 +31,7 @@ namespace device_conv2d_fwd_bias_activation_add_avx2_instance { ...@@ -30,6 +31,7 @@ namespace device_conv2d_fwd_bias_activation_add_avx2_instance {
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough; using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::cpu::element_wise::AddReluAdd; using AddReluAdd = ck::tensor_operation::cpu::element_wise::AddReluAdd;
// ------------------ nhwc-kyxc-nhwk
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk( void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>& std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
instances); instances);
...@@ -42,6 +44,7 @@ void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_mt( ...@@ -42,6 +44,7 @@ void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>& std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
instances); instances);
// ------------------ nhwc-kcyxk8-nhwk
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk( void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>& std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
instances); instances);
...@@ -54,6 +57,19 @@ void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_mt( ...@@ -54,6 +57,19 @@ void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>& std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
instances); instances);
// ------------------ nhwc-yxck-nhwk
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
instances);
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
instances);
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
instances);
} // namespace device_conv2d_fwd_bias_activation_add_avx2_instance } // namespace device_conv2d_fwd_bias_activation_add_avx2_instance
} // namespace device } // namespace device
} // namespace cpu } // namespace cpu
...@@ -141,6 +157,31 @@ void transpose_kyxc_2_kyxc8k(Tensor<T>& dst, ...@@ -141,6 +157,31 @@ void transpose_kyxc_2_kyxc8k(Tensor<T>& dst,
} }
} }
template <typename T>
void transpose_kyxc_2_yxck(Tensor<T>& dst,
const Tensor<T>& src,
ck::index_t K,
ck::index_t Y,
ck::index_t X,
ck::index_t C)
{
ck::index_t batch = 1;
ck::index_t row = K;
ck::index_t col = C * Y * X;
for(auto i_b = 0; i_b < batch; i_b++)
{
for(auto i_r = 0; i_r < row; i_r++)
{
for(auto i_c = 0; i_c < col; i_c++)
{
ck::index_t src_idx = i_b * row * col + i_r * col + i_c;
ck::index_t dst_idx = i_b * col * row + i_c * row + i_r;
dst.mData[dst_idx] = src.mData[src_idx];
}
}
}
}
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
int data_type = 0; int data_type = 0;
...@@ -243,6 +284,10 @@ int main(int argc, char* argv[]) ...@@ -243,6 +284,10 @@ int main(int argc, char* argv[])
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK #if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
Tensor<WeiDataType> wei_k_c_y_x_k8( Tensor<WeiDataType> wei_k_c_y_x_k8(
f_host_tensor_descriptor(K, C, Y, X)); // TODO: This is only to hold data f_host_tensor_descriptor(K, C, Y, X)); // TODO: This is only to hold data
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK
Tensor<WeiDataType> wei_y_x_c_k(
f_host_tensor_descriptor(K, C, Y, X)); // TODO: This is only to hold data
#endif #endif
Tensor<OutDataType> out_n_k_ho_wo_host_result(f_host_tensor_descriptor(N, K, Ho, Wo)); Tensor<OutDataType> out_n_k_ho_wo_host_result(f_host_tensor_descriptor(N, K, Ho, Wo));
Tensor<OutDataType> out_n_k_ho_wo_device_result(f_host_tensor_descriptor(N, K, Ho, Wo)); Tensor<OutDataType> out_n_k_ho_wo_device_result(f_host_tensor_descriptor(N, K, Ho, Wo));
...@@ -319,6 +364,10 @@ int main(int argc, char* argv[]) ...@@ -319,6 +364,10 @@ int main(int argc, char* argv[])
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK #if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
transpose_kyxc_2_kyxc8k(wei_k_c_y_x_k8, wei_k_c_y_x, K, Y, X, C); transpose_kyxc_2_kyxc8k(wei_k_c_y_x_k8, wei_k_c_y_x, K, Y, X, C);
wei_device_buf.ToDevice(wei_k_c_y_x_k8.mData.data()); wei_device_buf.ToDevice(wei_k_c_y_x_k8.mData.data());
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK
transpose_kyxc_2_yxck(wei_y_x_c_k, wei_k_c_y_x, K, Y, X, C);
wei_device_buf.ToDevice(wei_y_x_c_k.mData.data());
#endif #endif
bias_device_buf.ToDevice(bias.mData.data()); bias_device_buf.ToDevice(bias.mData.data());
resi_device_buf.ToDevice(residual.mData.data()); resi_device_buf.ToDevice(residual.mData.data());
...@@ -404,6 +453,30 @@ int main(int argc, char* argv[]) ...@@ -404,6 +453,30 @@ int main(int argc, char* argv[])
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_local_c( add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_local_c(
conv_ptrs); conv_ptrs);
} }
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK
if(omp_get_max_threads() > 1)
{
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_mt(conv_ptrs);
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk(conv_ptrs);
}
else
{
if(K % 8 == 0)
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk(
conv_ptrs);
else
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c(
conv_ptrs);
}
#endif #endif
} }
......
...@@ -199,8 +199,6 @@ struct BlockwiseGemmAvx2_MxN ...@@ -199,8 +199,6 @@ struct BlockwiseGemmAvx2_MxN
auto ldb = GetBLeadingElement(b_block_desc) * sizeof(FloatB); auto ldb = GetBLeadingElement(b_block_desc) * sizeof(FloatB);
auto ldc = GetCLeadingElement(c_desc) * sizeof(FloatC); auto ldc = GetCLeadingElement(c_desc) * sizeof(FloatC);
// printf("lda:%d, ldb:%d, ldc:%d\n", lda, ldb, ldc);
const auto k_per_block = a_slice_length[Number<1>{}]; const auto k_per_block = a_slice_length[Number<1>{}];
const auto m_per_block = c_slice_length[Number<0>{}]; const auto m_per_block = c_slice_length[Number<0>{}];
const auto n_per_block = c_slice_length[Number<1>{}]; const auto n_per_block = c_slice_length[Number<1>{}];
...@@ -215,8 +213,16 @@ struct BlockwiseGemmAvx2_MxN ...@@ -215,8 +213,16 @@ struct BlockwiseGemmAvx2_MxN
param.alpha = 1.0f; // TODO param.alpha = 1.0f; // TODO
param.accmulate_c = is_accumulate_c ? 1 : 0; param.accmulate_c = is_accumulate_c ? 1 : 0;
// printf("xxx lda:%u, ldb:%u, ldc:%u, mpb:%u, npb:%u, kpb:%u\n", lda, ldb, ldc, // printf("xxx lda:%u, ldb:%u, ldc:%u, mpb:%u, npb:%u, kpb:%u, mpt:%u, npt:%u\n",
// m_per_block, n_per_block, k_per_block); // lda,
// ldb,
// ldc,
// m_per_block,
// n_per_block,
// k_per_block,
// m_per_thread,
// n_per_thread);
// fflush(stdout);
if constexpr(std::is_same<ThreadMNAccessOrder, ck::Sequence<0, 1>>::value) if constexpr(std::is_same<ThreadMNAccessOrder, ck::Sequence<0, 1>>::value)
{ {
......
...@@ -81,12 +81,8 @@ struct ThreadwiseGemmAvx2_MxN_6x16 ...@@ -81,12 +81,8 @@ struct ThreadwiseGemmAvx2_MxN_6x16
"movq (%[m_param]), %%rax\n" // p_a "movq (%[m_param]), %%rax\n" // p_a
"movq 8(%[m_param]), %%rbx\n" // p_b "movq 8(%[m_param]), %%rbx\n" // p_b
"movq 24(%[m_param]), %%rsi\n" // Kr "movq 24(%[m_param]), %%rsi\n" // Kr
".if m_TransA != 0\n"
"movq 32(%[m_param]), %%rcx\n" // lda "movq 32(%[m_param]), %%rcx\n" // lda
".endif\n"
".if m_TransB == 0\n"
"movq 40(%[m_param]), %%rdx\n" // ldb "movq 40(%[m_param]), %%rdx\n" // ldb
".endif\n"
".macro vbroadcastss_%= r_base, r_stride, i_scale, i_offset, ymm\n" ".macro vbroadcastss_%= r_base, r_stride, i_scale, i_offset, ymm\n"
".if \\i_scale != 0\n" ".if \\i_scale != 0\n"
...@@ -120,10 +116,14 @@ struct ThreadwiseGemmAvx2_MxN_6x16 ...@@ -120,10 +116,14 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".endif\n" ".endif\n"
".endm\n" ".endm\n"
".macro vbroadcast_a%= i_k, i_m, ymm\n" // A in rax(r8, r9), lda in rcx ".macro vbroadcast_a%= i_k, i_m, ymm\n" // A in rax(r8), lda in rcx
".if m_ABytes == 4\n" ".if m_ABytes == 4\n"
".if m_TransA == 0\n" ".if m_TransA == 0\n"
"vbroadcastss_%= %%rax, 0, 0, ((\\i_m + \\i_k * m_Mr) * m_ABytes), \\ymm\n" ".if (\\i_k == 0) || (\\i_k == 1) || (\\i_k == 2)\n"
"vbroadcastss_%= %%rax, %%rcx, \\i_k, (\\i_m * m_ABytes), \\ymm\n"
".else\n"
"vbroadcastss_%= %%r8, %%rcx, (\\i_k-3), (\\i_m * m_ABytes), \\ymm\n"
".endif\n"
".else\n" ".else\n"
".if (\\i_m == 0) || (\\i_m == 1) || (\\i_m == 2)\n" ".if (\\i_m == 0) || (\\i_m == 1) || (\\i_m == 2)\n"
"vbroadcastss_%= %%rax, %%rcx, \\i_m, (\\i_k * m_ABytes), \\ymm\n" "vbroadcastss_%= %%rax, %%rcx, \\i_m, (\\i_k * m_ABytes), \\ymm\n"
...@@ -133,7 +133,11 @@ struct ThreadwiseGemmAvx2_MxN_6x16 ...@@ -133,7 +133,11 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".endif\n" ".endif\n"
".else\n" ".else\n"
".if m_TransA == 0\n" ".if m_TransA == 0\n"
"vpbroadcastw_%= %%rax, 0, 0, ((\\i_m + \\i_k * m_Mr) * m_ABytes), %%xmm15\n" ".if (\\i_k == 0) || (\\i_k == 1) || (\\i_k == 2)\n"
"vpbroadcastw_%= %%rax, %%rcx, \\i_k, (\\i_m * m_ABytes), %%xmm15\n"
".else\n"
"vpbroadcastw_%= %%rax, %%rcx, (\\i_k-3), (\\i_m * m_ABytes), %%xmm15\n"
".endif\n"
".else\n" ".else\n"
".if (\\i_m == 0) || (\\i_m == 1) || (\\i_m == 2)\n" ".if (\\i_m == 0) || (\\i_m == 1) || (\\i_m == 2)\n"
"vpbroadcastw_%= %%rax, %%rcx, \\i_m, (\\i_k * m_ABytes), %%xmm15\n" "vpbroadcastw_%= %%rax, %%rcx, \\i_m, (\\i_k * m_ABytes), %%xmm15\n"
...@@ -145,18 +149,26 @@ struct ThreadwiseGemmAvx2_MxN_6x16 ...@@ -145,18 +149,26 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".endif\n" ".endif\n"
".endm\n" ".endm\n"
".macro vload_b%= i_k, i_n, ymm\n" // B in rbx, lda in rdx, i_n should be 0, 1 ".macro vload_b%= i_k, i_n, ymm\n" // B in rbx(r9), lda in rdx, i_n should be 0, 1
".if m_BBytes == 4\n" ".if m_BBytes == 4\n"
".if m_TransB == 0\n" ".if m_TransB == 0\n"
"vmovups_%= %%rbx, %%rdx, \\i_n, (\\i_k*m_BBytes*8), \\ymm\n" "vmovups_%= %%rbx, %%rdx, \\i_n, (\\i_k*m_BBytes*8), \\ymm\n"
".else\n" ".else\n"
"vmovups_%= %%rbx, 0, 0, ((\\i_k*m_Nr + \\i_n*8)*m_BBytes), \\ymm\n" ".if (\\i_k == 0) || (\\i_k == 1) || (\\i_k == 2)\n"
"vmovups_%= %%rbx, %%rdx, \\i_k, (\\i_n*m_BBytes*8), \\ymm\n"
".else\n"
"vmovups_%= %%r9, %%rdx, (\\i_k-3), (\\i_n*m_BBytes*8), \\ymm\n"
".endif\n"
".endif\n" ".endif\n"
".else\n" ".else\n"
".if m_TransB == 0\n" ".if m_TransB == 0\n"
"vcvtph2ps_%= %%rbx, %%rdx, \\i_n, (\\i_k*m_BBytes*8), \\ymm\n" "vcvtph2ps_%= %%rbx, %%rdx, \\i_n, (\\i_k*m_BBytes*8), \\ymm\n"
".else\n" ".else\n"
"vcvtph2ps_%= %%rbx, 0, 0, ((\\i_k*m_Nr + \\i_n*8)*m_BBytes), \\ymm\n" ".if (\\i_k == 0) || (\\i_k == 1) || (\\i_k == 2)\n"
"vcvtph2ps_%= %%rbx, %%rdx, \\i_k, (\\i_n*m_BBytes*8), \\ymm\n"
".else\n"
"vcvtph2ps_%= %%r9, %%rdx, (\\i_k-3), (\\i_n*m_BBytes*8), \\ymm\n"
".endif\n"
".endif\n" ".endif\n"
".endif\n" ".endif\n"
".endm\n" ".endm\n"
...@@ -179,6 +191,13 @@ struct ThreadwiseGemmAvx2_MxN_6x16 ...@@ -179,6 +191,13 @@ struct ThreadwiseGemmAvx2_MxN_6x16
"lea (%%rcx, %%rcx, 2), %%r9\n" "lea (%%rcx, %%rcx, 2), %%r9\n"
"lea (%%rax, %%r9), %%r8\n" "lea (%%rax, %%r9), %%r8\n"
".endif\n" ".endif\n"
".else\n"
"lea (%%rcx, %%rcx, 2), %%r9\n"
"lea (%%rax, %%r9), %%r8\n"
".endif\n"
".if m_TransB != 0\n"
"lea (%%rdx, %%rdx, 2), %%rdi\n"
"lea (%%rbx, %%rdi), %%r9\n"
".endif\n" ".endif\n"
"cmp $4, %%rsi\n" "cmp $4, %%rsi\n"
...@@ -214,10 +233,12 @@ struct ThreadwiseGemmAvx2_MxN_6x16 ...@@ -214,10 +233,12 @@ struct ThreadwiseGemmAvx2_MxN_6x16
" lea 4*m_ABytes(%%rax), %%rax\n" " lea 4*m_ABytes(%%rax), %%rax\n"
".if m_Mr > 3\n lea 4*m_ABytes(%%r8), %%r8\n .endif\n" ".if m_Mr > 3\n lea 4*m_ABytes(%%r8), %%r8\n .endif\n"
".else\n" ".else\n"
" lea m_Mr * 4 * m_ABytes(%%rax), %%rax\n" " lea (%%rax, %%rcx, 4), %%rax\n"
" lea (%%r8, %%rcx, 4), %%r8\n"
".endif\n" ".endif\n"
".if m_TransB != 0\n" ".if m_TransB != 0\n"
" lea m_Nr * 4 * m_BBytes(%%rbx), %%rbx\n" " lea (%%rbx, %%rdx, 4), %%rbx\n"
" lea (%%r9, %%rdx, 4), %%r9\n"
".else\n" ".else\n"
" lea 8 * 4 * m_BBytes(%%rbx), %%rbx\n" " lea 8 * 4 * m_BBytes(%%rbx), %%rbx\n"
".endif\n" ".endif\n"
...@@ -256,10 +277,12 @@ struct ThreadwiseGemmAvx2_MxN_6x16 ...@@ -256,10 +277,12 @@ struct ThreadwiseGemmAvx2_MxN_6x16
" lea m_ABytes(%%rax), %%rax\n" " lea m_ABytes(%%rax), %%rax\n"
".if m_Mr > 3\n lea m_ABytes(%%r8), %%r8\n .endif\n" ".if m_Mr > 3\n lea m_ABytes(%%r8), %%r8\n .endif\n"
".else\n" ".else\n"
" lea m_Mr * m_ABytes(%%rax), %%rax\n" " lea (%%rax, %%rcx, 1), %%rax\n"
" lea (%%r8, %%rcx, 1), %%r8\n"
".endif\n" ".endif\n"
".if m_TransB != 0\n" ".if m_TransB != 0\n"
" lea m_Nr * m_BBytes(%%rbx), %%rbx\n" " lea (%%rbx, %%rdx, 1), %%rbx\n"
" lea (%%r9, %%rdx, 1), %%r9\n"
".else\n" ".else\n"
" lea 8*m_BBytes(%%rbx), %%rbx\n" " lea 8*m_BBytes(%%rbx), %%rbx\n"
".endif\n" ".endif\n"
...@@ -381,7 +404,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16 ...@@ -381,7 +404,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
} }
else else
{ {
ymm = _mm256_broadcast_ss(p_a + i_k * Mr + i_m); ymm = _mm256_broadcast_ss(p_a + i_k * lda + i_m);
} }
} }
else else
...@@ -396,7 +419,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16 ...@@ -396,7 +419,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
} }
else else
{ {
ymm = _mm256_cvtph_ps(_mm_set1_epi16(*(p_a + i_k * Mr + i_m))); ymm = _mm256_cvtph_ps(_mm_set1_epi16(*(p_a + i_k * lda + i_m)));
} }
} }
}; };
...@@ -406,7 +429,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16 ...@@ -406,7 +429,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
{ {
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value) if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
{ {
ymm = _mm256_loadu_ps(p_b + i_k * Nr + i_n * 8); ymm = _mm256_loadu_ps(p_b + i_k * ldb + i_n * 8);
} }
else else
{ {
...@@ -418,7 +441,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16 ...@@ -418,7 +441,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value) if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
{ {
ymm = _mm256_cvtph_ps(_mm_loadu_si128( ymm = _mm256_cvtph_ps(_mm_loadu_si128(
reinterpret_cast<__m128i const*>(p_b + i_k * Nr + i_n * 8))); reinterpret_cast<__m128i const*>(p_b + i_k * ldb + i_n * 8)));
} }
else else
{ {
...@@ -488,10 +511,10 @@ struct ThreadwiseGemmAvx2_MxN_6x16 ...@@ -488,10 +511,10 @@ struct ThreadwiseGemmAvx2_MxN_6x16
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value){ if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value){
p_a += 4; p_a += 4;
} else{ } else{
p_a += Mr * 4; p_a += lda * 4;
} }
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value){ if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value){
p_b += Nr * 4; p_b += ldb * 4;
}else{ }else{
p_b += 4 * 8; p_b += 4 * 8;
} }
...@@ -525,10 +548,10 @@ struct ThreadwiseGemmAvx2_MxN_6x16 ...@@ -525,10 +548,10 @@ struct ThreadwiseGemmAvx2_MxN_6x16
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value){ if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value){
p_a += 1; p_a += 1;
} else{ } else{
p_a += Mr * 1; p_a += lda * 1;
} }
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value){ if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value){
p_b += Nr * 1; p_b += ldb * 1;
}else{ }else{
p_b += 1 * 8; p_b += 1 * 8;
} }
...@@ -641,12 +664,8 @@ struct ThreadwiseGemmAvx2_MxN_4x24 ...@@ -641,12 +664,8 @@ struct ThreadwiseGemmAvx2_MxN_4x24
"movq (%[m_param]), %%rax\n" // p_a "movq (%[m_param]), %%rax\n" // p_a
"movq 8(%[m_param]), %%rbx\n" // p_b "movq 8(%[m_param]), %%rbx\n" // p_b
"movq 24(%[m_param]), %%rsi\n" // Kr "movq 24(%[m_param]), %%rsi\n" // Kr
".if m_TransA != 0\n"
"movq 32(%[m_param]), %%rcx\n" // lda "movq 32(%[m_param]), %%rcx\n" // lda
".endif\n"
".if m_TransB == 0\n"
"movq 40(%[m_param]), %%rdx\n" // ldb "movq 40(%[m_param]), %%rdx\n" // ldb
".endif\n"
".macro vbroadcastss_%= r_base, r_stride, i_scale, i_offset, ymm\n" ".macro vbroadcastss_%= r_base, r_stride, i_scale, i_offset, ymm\n"
".if \\i_scale != 0\n" ".if \\i_scale != 0\n"
...@@ -683,7 +702,11 @@ struct ThreadwiseGemmAvx2_MxN_4x24 ...@@ -683,7 +702,11 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".macro vbroadcast_a%= i_k, i_m, ymm\n" // A in rax(r8), lda in rcx ".macro vbroadcast_a%= i_k, i_m, ymm\n" // A in rax(r8), lda in rcx
".if m_ABytes == 4\n" ".if m_ABytes == 4\n"
".if m_TransA == 0\n" ".if m_TransA == 0\n"
"vbroadcastss_%= %%rax, 0, 0, ((\\i_m + \\i_k * m_Mr) * m_ABytes), \\ymm\n" ".if (\\i_k == 0) || (\\i_k == 1)\n"
"vbroadcastss_%= %%rax, %%rcx, \\i_k, (\\i_m * m_ABytes), \\ymm\n"
".else\n"
"vbroadcastss_%= %%r8, %%rcx, (\\i_k-2), (\\i_m * m_ABytes), \\ymm\n"
".endif\n"
".else\n" ".else\n"
".if (\\i_m == 0) || (\\i_m == 1)\n" ".if (\\i_m == 0) || (\\i_m == 1)\n"
"vbroadcastss_%= %%rax, %%rcx, \\i_m, (\\i_k * m_ABytes), \\ymm\n" "vbroadcastss_%= %%rax, %%rcx, \\i_m, (\\i_k * m_ABytes), \\ymm\n"
...@@ -693,7 +716,11 @@ struct ThreadwiseGemmAvx2_MxN_4x24 ...@@ -693,7 +716,11 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".endif\n" ".endif\n"
".else\n" ".else\n"
".if m_TransA == 0\n" ".if m_TransA == 0\n"
"vpbroadcastw_%= %%rax, 0, 0, ((\\i_m + \\i_k * m_Mr) * m_ABytes), %%xmm15\n" ".if (\\i_k == 0) || (\\i_k == 1)\n"
"vpbroadcastw_%= %%rax, %%rcx, \\i_k, (\\i_m * m_ABytes), %%xmm15\n"
".else\n"
"vpbroadcastw_%= %%r8, %%rcx, (\\i_k-2), (\\i_m * m_ABytes), %%xmm15\n"
".endif\n"
".else\n" ".else\n"
".if (\\i_m == 0) || (\\i_m == 1)\n" ".if (\\i_m == 0) || (\\i_m == 1)\n"
"vpbroadcastw_%= %%rax, %%rcx, \\i_m, (\\i_k * m_ABytes), %%xmm15\n" "vpbroadcastw_%= %%rax, %%rcx, \\i_m, (\\i_k * m_ABytes), %%xmm15\n"
...@@ -710,13 +737,21 @@ struct ThreadwiseGemmAvx2_MxN_4x24 ...@@ -710,13 +737,21 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".if m_TransB == 0\n" ".if m_TransB == 0\n"
"vmovups_%= %%rbx, %%rdx, \\i_n, (\\i_k*m_BBytes*8), \\ymm\n" "vmovups_%= %%rbx, %%rdx, \\i_n, (\\i_k*m_BBytes*8), \\ymm\n"
".else\n" ".else\n"
"vmovups_%= %%rbx, 0, 0, ((\\i_k*m_Nr + \\i_n*8)*m_BBytes), \\ymm\n" ".if (\\i_k == 0) || (\\i_k == 1)\n"
"vmovups_%= %%rbx, %%rdx, \\i_k, (\\i_n*8*m_BBytes), \\ymm\n"
".else\n"
"vmovups_%= %%rdi, %%rdx, (\\i_k-2), (\\i_n*8*m_BBytes), \\ymm\n"
".endif\n"
".endif\n" ".endif\n"
".else\n" ".else\n"
".if m_TransB == 0\n" ".if m_TransB == 0\n"
"vcvtph2ps_%= %%rbx, %%rdx, \\i_n, (\\i_k*m_BBytes*8), \\ymm\n" "vcvtph2ps_%= %%rbx, %%rdx, \\i_n, (\\i_k*m_BBytes*8), \\ymm\n"
".else\n" ".else\n"
"vcvtph2ps_%= %%rbx, 0, 0, ((\\i_k*m_Nr + \\i_n*8)*m_BBytes), \\ymm\n" ".if (\\i_k == 0) || (\\i_k == 1)\n"
"vcvtph2ps_%= %%rbx, %%rdx, \\i_k, (\\i_n*8*m_BBytes), \\ymm\n"
".else\n"
"vcvtph2ps_%= %%rdi, %%rdx, (\\i_k-2), (\\i_n*8*m_BBytes), \\ymm\n"
".endif\n"
".endif\n" ".endif\n"
".endif\n" ".endif\n"
".endm\n" ".endm\n"
...@@ -738,6 +773,11 @@ struct ThreadwiseGemmAvx2_MxN_4x24 ...@@ -738,6 +773,11 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".if m_Mr > 2\n" ".if m_Mr > 2\n"
"lea (%%rax, %%rcx, 2), %%r8\n" "lea (%%rax, %%rcx, 2), %%r8\n"
".endif\n" ".endif\n"
".else\n"
"lea (%%rax, %%rcx, 2), %%r8\n"
".endif\n"
".if m_TransB != 0\n"
"lea (%%rbx, %%rdx, 2), %%rdi\n"
".endif\n" ".endif\n"
"cmp $4, %%rsi\n" "cmp $4, %%rsi\n"
...@@ -773,10 +813,12 @@ struct ThreadwiseGemmAvx2_MxN_4x24 ...@@ -773,10 +813,12 @@ struct ThreadwiseGemmAvx2_MxN_4x24
" lea 4*m_ABytes(%%rax), %%rax\n" " lea 4*m_ABytes(%%rax), %%rax\n"
".if m_Mr > 2\n lea 4*m_ABytes(%%r8), %%r8\n .endif\n" ".if m_Mr > 2\n lea 4*m_ABytes(%%r8), %%r8\n .endif\n"
".else\n" ".else\n"
" lea m_Mr * 4 * m_ABytes(%%rax), %%rax\n" " lea (%%rax, %%rcx, 4), %%rax\n"
" lea (%%r8, %%rcx, 4), %%r8\n"
".endif\n" ".endif\n"
".if m_TransB != 0\n" ".if m_TransB != 0\n"
" lea m_Nr * 4 * m_BBytes(%%rbx), %%rbx\n" " lea (%%rbx, %%rdx, 4), %%rbx\n"
" lea (%%rdi, %%rdx, 4), %%rdi\n"
".else\n" ".else\n"
" lea 8 * 4 * m_BBytes(%%rbx), %%rbx\n" " lea 8 * 4 * m_BBytes(%%rbx), %%rbx\n"
".endif\n" ".endif\n"
...@@ -815,10 +857,12 @@ struct ThreadwiseGemmAvx2_MxN_4x24 ...@@ -815,10 +857,12 @@ struct ThreadwiseGemmAvx2_MxN_4x24
" lea m_ABytes(%%rax), %%rax\n" " lea m_ABytes(%%rax), %%rax\n"
".if m_Mr > 3\n lea m_ABytes(%%r8), %%r8\n .endif\n" ".if m_Mr > 3\n lea m_ABytes(%%r8), %%r8\n .endif\n"
".else\n" ".else\n"
" lea m_Mr * m_ABytes(%%rax), %%rax\n" " lea (%%rax, %%rcx, 1), %%rax\n"
" lea (%%r8, %%rcx, 1), %%r8\n"
".endif\n" ".endif\n"
".if m_TransB != 0\n" ".if m_TransB != 0\n"
" lea m_Nr * m_BBytes(%%rbx), %%rbx\n" " lea (%%rbx, %%rdx, 1), %%rbx\n"
" lea (%%rdi, %%rdx, 1), %%rdi\n"
".else\n" ".else\n"
" lea 8*m_BBytes(%%rbx), %%rbx\n" " lea 8*m_BBytes(%%rbx), %%rbx\n"
".endif\n" ".endif\n"
...@@ -937,7 +981,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24 ...@@ -937,7 +981,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
} }
else else
{ {
ymm = _mm256_broadcast_ss(p_a + i_k * Mr + i_m); ymm = _mm256_broadcast_ss(p_a + i_k * lda + i_m);
} }
} }
else else
...@@ -952,7 +996,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24 ...@@ -952,7 +996,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
} }
else else
{ {
ymm = _mm256_cvtph_ps(_mm_set1_epi16(*(p_a + i_k * Mr + i_m))); ymm = _mm256_cvtph_ps(_mm_set1_epi16(*(p_a + i_k * lda + i_m)));
} }
} }
}; };
...@@ -962,7 +1006,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24 ...@@ -962,7 +1006,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
{ {
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value) if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
{ {
ymm = _mm256_loadu_ps(p_b + i_k * Nr + i_n * 8); ymm = _mm256_loadu_ps(p_b + i_k * ldb + i_n * 8);
} }
else else
{ {
...@@ -974,7 +1018,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24 ...@@ -974,7 +1018,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value) if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
{ {
ymm = _mm256_cvtph_ps(_mm_loadu_si128( ymm = _mm256_cvtph_ps(_mm_loadu_si128(
reinterpret_cast<__m128i const*>(p_b + i_k * Nr + i_n * 8))); reinterpret_cast<__m128i const*>(p_b + i_k * ldb + i_n * 8)));
} }
else else
{ {
...@@ -1044,10 +1088,10 @@ struct ThreadwiseGemmAvx2_MxN_4x24 ...@@ -1044,10 +1088,10 @@ struct ThreadwiseGemmAvx2_MxN_4x24
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value){ if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value){
p_a += 4; p_a += 4;
} else{ } else{
p_a += Mr * 4; p_a += lda * 4;
} }
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value){ if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value){
p_b += Nr * 4; p_b += ldb * 4;
}else{ }else{
p_b += 4 * 8; p_b += 4 * 8;
} }
...@@ -1081,10 +1125,10 @@ struct ThreadwiseGemmAvx2_MxN_4x24 ...@@ -1081,10 +1125,10 @@ struct ThreadwiseGemmAvx2_MxN_4x24
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value){ if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value){
p_a += 1; p_a += 1;
} else{ } else{
p_a += Mr * 1; p_a += lda * 1;
} }
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value){ if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value){
p_b += Nr * 1; p_b += ldb * 1;
}else{ }else{
p_b += 1 * 8; p_b += 1 * 8;
} }
......
...@@ -1277,6 +1277,138 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8 ...@@ -1277,6 +1277,138 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8
intptr_t src_offset; intptr_t src_offset;
}; };
template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename ElementwiseOperation,
bool BypassTransfer,
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_YXCK
{
static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
using Index = MultiIndex<nDim>;
constexpr ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_YXCK(
const SrcDesc& src_desc,
const Index&,
const DstDesc&,
const Index&,
const ElementwiseOperation& element_op)
: element_op_(element_op)
{
GemmK = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
GemmN = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
}
void SetSrcSliceOrigin(const SrcDesc&, const Index& src_slice_origin_idx)
{
ck::index_t idx_k = src_slice_origin_idx[Number<0>{}];
ck::index_t idx_n = src_slice_origin_idx[Number<1>{}];
src_offset = idx_k * GemmN + idx_n;
}
void SetDstSliceOrigin(const DstDesc&, const Index&) {}
template <typename SrcBuffer, typename DstBuffer, typename SliceLengths>
void RunRead(const SrcDesc&,
SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf,
const SliceLengths& slice_length)
{
if constexpr(BypassTransfer)
{
dst_buf.p_data_ = reinterpret_cast<float*>(src_buf.p_data_) + src_offset;
}
else
{
const ck::index_t k_per_block = slice_length[Number<0>{}];
const ck::index_t n_per_block = slice_length[Number<1>{}];
const float* p_src = reinterpret_cast<const float*>(src_buf.p_data_) + src_offset;
float* p_dst = reinterpret_cast<float*>(dst_buf.p_data_);
// k * n
index_t i_k_itr = k_per_block;
while(i_k_itr >= 8)
{
avx2_util::memcpy32_avx2(
p_dst + 0 * n_per_block, p_src + 0 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 1 * n_per_block, p_src + 1 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 2 * n_per_block, p_src + 2 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 3 * n_per_block, p_src + 3 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 4 * n_per_block, p_src + 4 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 5 * n_per_block, p_src + 5 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 6 * n_per_block, p_src + 6 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 7 * n_per_block, p_src + 7 * GemmN, n_per_block, element_op_);
i_k_itr -= 8;
p_dst += 8 * n_per_block;
p_src += 8 * GemmN;
}
if(i_k_itr & 4)
{
avx2_util::memcpy32_avx2(
p_dst + 0 * n_per_block, p_src + 0 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 1 * n_per_block, p_src + 1 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 2 * n_per_block, p_src + 2 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 3 * n_per_block, p_src + 3 * GemmN, n_per_block, element_op_);
p_dst += 4 * n_per_block;
p_src += 4 * GemmN;
}
if(i_k_itr & 2)
{
avx2_util::memcpy32_avx2(
p_dst + 0 * n_per_block, p_src + 0 * GemmN, n_per_block, element_op_);
avx2_util::memcpy32_avx2(
p_dst + 1 * n_per_block, p_src + 1 * GemmN, n_per_block, element_op_);
p_dst += 2 * n_per_block;
p_src += 2 * GemmN;
}
if(i_k_itr & 1)
{
avx2_util::memcpy32_avx2(
p_dst + 0 * n_per_block, p_src + 0 * GemmN, n_per_block, element_op_);
}
}
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& src_slice_origin_step_idx)
{
ck::index_t move_k = src_slice_origin_step_idx[Number<0>{}];
ck::index_t move_n = src_slice_origin_step_idx[Number<1>{}];
src_offset += move_k * GemmN + move_n;
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
void MoveDstSliceWindow(const DstDesc&, const Index&) {}
private:
const ElementwiseOperation element_op_;
ck::index_t GemmN;
ck::index_t GemmK;
intptr_t src_offset;
};
template <typename SrcData, template <typename SrcData,
typename DstData, typename DstData,
typename SrcDesc, typename SrcDesc,
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
set(DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE set(DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE
device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp
device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_instance.cpp device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_instance.cpp
device_conv2d_fwd_avx2_nhwc_yxck_nhwk_instance.cpp
) )
add_library(device_conv2d_fwd_cpu_instance SHARED ${DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE}) add_library(device_conv2d_fwd_cpu_instance SHARED ${DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE})
target_compile_features(device_conv2d_fwd_cpu_instance PUBLIC) target_compile_features(device_conv2d_fwd_cpu_instance PUBLIC)
......
#include <stdlib.h>
#include "config.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "device_convnd_fwd_avx2_nhwc_yxck_nhwk.hpp"
#include "element_wise_operation_cpu.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
namespace device_conv2d_fwd_avx2_instance {
using InType = float;
using WeiType = float;
using OutType = float;
using AccType = float;
static constexpr bool NonTemporalStore = false;
using PT = ck::tensor_operation::cpu::element_wise::PassThrough;
using Relu = ck::tensor_operation::cpu::element_wise::Relu;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Default;
static constexpr auto ConvFwd1x1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;
static constexpr auto DefaultGemmKLoop =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop;
static constexpr auto GemmKLoopOverC =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC;
static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK;
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
// clang-format off
#define DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf) \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, false, c_local_buf>, \
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, false, c_local_buf>
// clang-format on
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, false)>;
// clang-format on
// use this in single thread, but gemm_n is not multiple of 8
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true)>;
// clang-format on
// use this in multi thread environment (need local C buffer to avoid cache coherence, although some
// time no local c is better...)
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 24, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 32, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 40, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 48, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 48, 48, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 56, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 72, 16, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 72, 16, 256, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 72, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 72, 32, 256, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 96, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 96, 64, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 120, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 120, 64, 128, 6, 16, false),
// DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 128, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 512, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true)>;
// clang-format on
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_relu_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, false)>;
// clang-format on
// use this in single thread, but gemm_n is not multiple of 8
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_relu_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)>;
// clang-format on
// use this in multi thread environment (need local C buffer to avoid cache coherence, although some
// time no local c is better...)
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_relu_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 24, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 32, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 40, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 24, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 32, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 40, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 48, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 48, 48, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 56, 24, 256, 4, 24, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 72, 16, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 72, 16, 256, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 72, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 72, 32, 256, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 96, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 96, 64, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 120, 32, 128, 6, 16, false),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 120, 64, 128, 6, 16, false),
// DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 128, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 512, 240, 128, 4, 24, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 512, 256, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 768, 320, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 896, 352, 128, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)>;
// clang-format on
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk(std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c(
std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt(
std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_relu_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_relu_instances{});
}
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt_relu(
std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_relu_instances{});
}
} // namespace device_conv2d_fwd_avx2_instance
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
set(DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE set(DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE
device_conv2d_bias_activation_add_avx2_nhwc_kyxc_nhwk_instance.cpp device_conv2d_bias_activation_add_avx2_nhwc_kyxc_nhwk_instance.cpp
device_conv2d_bias_activation_add_avx2_nhwc_kyxck8_nhwk_instance.cpp device_conv2d_bias_activation_add_avx2_nhwc_kyxck8_nhwk_instance.cpp
device_conv2d_bias_activation_add_avx2_nhwc_yxck_nhwk_instance.cpp
) )
add_library(device_conv2d_fwd_bias_activation_add_cpu_instance SHARED ${DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE}) add_library(device_conv2d_fwd_bias_activation_add_cpu_instance SHARED ${DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE})
target_compile_features(device_conv2d_fwd_bias_activation_add_cpu_instance PUBLIC) target_compile_features(device_conv2d_fwd_bias_activation_add_cpu_instance PUBLIC)
......
#include <stdlib.h>
#include "config.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "device_convnd_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk.hpp"
#include "element_wise_operation_cpu.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
namespace device_conv2d_fwd_bias_activation_add_avx2_instance {
using InType = float;
using WeiType = float;
using OutType = float;
using AccType = float;
static constexpr bool NonTemporalStore = false;
using PT = ck::tensor_operation::cpu::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::cpu::element_wise::AddReluAdd;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Default;
static constexpr auto ConvFwd1x1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;
static constexpr auto DefaultGemmKLoop =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop;
static constexpr auto GemmKLoopOverC =
ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC;
static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK;
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
// clang-format off
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , false, c_local_buf, bias_along_m>, \
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , false, c_local_buf, bias_along_m>
// clang-format on
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, false, false)>;
// clang-format on
// use this in single thread, but gemm_n is not multiple of 8
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_local_c_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)>;
// clang-format on
// use this in multi thread environment (need local C buffer to avoid cache coherence, although some
// time no local c is better...)
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_mt_instances = std::tuple<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 24, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 32, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 40, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 48, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 48, 48, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 56, 24, 256, 4, 24, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 256, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 256, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 96, 32, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 96, 64, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 120, 32, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 120, 64, 128, 6, 16, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)>;
// clang-format on
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_instances{});
}
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_local_c_instances{});
}
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances, device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_mt_instances{});
}
} // namespace device_conv2d_fwd_bias_activation_add_avx2_instance
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
...@@ -233,68 +233,30 @@ void test_ukernel(ukenrel_t uk, ...@@ -233,68 +233,30 @@ void test_ukernel(ukenrel_t uk,
int max_threads = omp_get_max_threads(); int max_threads = omp_get_max_threads();
auto invoke_uk = [&](ck::cpu::ThreadwiseGemmParam& param, float* current_mat_c) { auto invoke_uk = [&](ck::cpu::ThreadwiseGemmParam& param, float* current_mat_c) {
if constexpr(std::is_same<Row, ALayout>::value && std::is_same<Row, BLayout>::value) assert(m % uk.ThreadMr == 0 && n % uk.ThreadNr == 0);
for(uint32_t i_m = 0; i_m < m; i_m += uk.ThreadMr)
{ {
assert(m % uk.ThreadMr == 0 && n == uk.ThreadNr); if constexpr(std::is_same<Row, ALayout>::value)
FloatA* p_a = mat_a;
float* p_c = current_mat_c;
param.p_a = p_a;
param.p_c = p_c;
for(uint32_t i_m = 0; i_m < m; i_m += uk.ThreadMr)
{ {
uk.Run(&param); param.p_a = mat_a + i_m * k;
p_a += uk.ThreadMr * k;
p_c += uk.ThreadMr * n;
param.p_a = p_a;
param.p_c = p_c;
} }
} else
else if constexpr(std::is_same<Row, ALayout>::value && std::is_same<Col, BLayout>::value)
{
assert(m % uk.ThreadMr == 0 && n % uk.ThreadNr == 0);
FloatA* p_a = mat_a;
float* p_c = current_mat_c;
param.p_a = p_a;
param.p_b = mat_b;
param.p_c = p_c;
for(uint32_t i_m = 0; i_m < m; i_m += uk.ThreadMr)
{ {
float* p_c_n = p_c; param.p_a = mat_a + i_m;
FloatB* p_b_n = mat_b;
for(uint32_t i_n = 0; i_n < n; i_n += uk.ThreadNr)
{
uk.Run(&param);
p_b_n += uk.ThreadNr * k; // ThreadNr/8*k*8
p_c_n += uk.ThreadNr;
param.p_b = p_b_n;
param.p_c = p_c_n;
}
p_a += uk.ThreadMr * k;
p_c += uk.ThreadMr * n;
param.p_a = p_a;
param.p_b = mat_b;
param.p_c = p_c;
} }
}
else if constexpr(std::is_same<Col, ALayout>::value && std::is_same<Row, BLayout>::value)
{
assert(m == uk.ThreadMr && n == uk.ThreadNr);
uk.Run(&param);
}
else
{
assert(m % uk.ThreadMr == 0 && n % uk.ThreadNr == 0);
FloatB* p_b = mat_b;
float* p_c = current_mat_c;
param.p_b = p_b;
param.p_c = p_c;
for(uint32_t i_n = 0; i_n < n; i_n += uk.ThreadNr) for(uint32_t i_n = 0; i_n < n; i_n += uk.ThreadNr)
{ {
if constexpr(std::is_same<Row, BLayout>::value)
{
param.p_b = mat_b + i_n;
}
else
{
param.p_b = mat_b + i_n * k;
}
param.p_c = current_mat_c + i_m * n + i_n;
uk.Run(&param); uk.Run(&param);
p_b += uk.ThreadNr * k; // ThreadNr/8*k*8
p_c += uk.ThreadNr;
param.p_b = p_b;
param.p_c = p_c;
} }
} }
}; };
...@@ -358,7 +320,11 @@ void test_ukernel(ukenrel_t uk, ...@@ -358,7 +320,11 @@ void test_ukernel(ukenrel_t uk,
} }
// implement small ukernel on L1 // implement small ukernel on L1
template <typename FloatA, typename FloatB, typename ALayout, typename BLayout> template <typename FloatA,
typename FloatB,
typename ALayout,
typename BLayout,
typename thread_gemm_instance>
void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k) void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k)
{ {
int max_threads = omp_get_max_threads(); int max_threads = omp_get_max_threads();
...@@ -382,17 +348,18 @@ void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k) ...@@ -382,17 +348,18 @@ void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k)
k); k);
// using thread_gemm_instance = thread_gemm_avx2_mxn_6x16_instances<ALayout, BLayout>; // using thread_gemm_instance = thread_gemm_avx2_mxn_6x16_instances<ALayout, BLayout>;
using thread_gemm_instance = thread_gemm_avx2_mxn_4x24_instances<ALayout, BLayout>; // using thread_gemm_instance = thread_gemm_avx2_mxn_4x24_instances<ALayout, BLayout>;
bool found = false; bool found = false;
ck::static_for<0, std::tuple_size_v<thread_gemm_instance>, 1>{}([&](auto i) { ck::static_for<0, std::tuple_size_v<thread_gemm_instance>, 1>{}([&](auto i) {
using uk_type = std::tuple_element_t<i, thread_gemm_instance>; using uk_type = std::tuple_element_t<i, thread_gemm_instance>;
if(m % uk_type::ThreadMr != 0 || n % uk_type::ThreadNr != 0) if(m % uk_type::ThreadMr != 0 || n % uk_type::ThreadNr != 0)
return; return;
if((m != uk_type::ThreadMr && std::is_same<typename uk_type::MatrixALayout, Col>::value) || // if((m != uk_type::ThreadMr && std::is_same<typename uk_type::MatrixALayout, Col>::value)
(n != uk_type::ThreadNr && std::is_same<typename uk_type::MatrixBLayout, Row>::value)) // ||
// only k is the fast changing dim of A/B can we do muldiplt m, n // (n != uk_type::ThreadNr && std::is_same<typename uk_type::MatrixBLayout, Row>::value))
return; // // only k is the fast changing dim of A/B can we do muldiplt m, n
// return;
if(found) if(found)
return; return;
...@@ -435,8 +402,21 @@ int main(int argc, char** argv) ...@@ -435,8 +402,21 @@ int main(int argc, char** argv)
omp_set_num_threads(1); omp_set_num_threads(1);
printf("max threads:%d\n", omp_get_max_threads()); printf("max threads:%d\n", omp_get_max_threads());
test_cpu_ukernel<AType, BType, Row, Row>(alpha, m, n, k); test_cpu_ukernel<AType, BType, Row, Row, thread_gemm_avx2_mxn_4x24_instances<Row, Row>>(
test_cpu_ukernel<AType, BType, Row, Col>(alpha, m, n, k); alpha, m, n, k);
test_cpu_ukernel<AType, BType, Col, Row>(alpha, m, n, k); test_cpu_ukernel<AType, BType, Row, Col, thread_gemm_avx2_mxn_4x24_instances<Row, Col>>(
test_cpu_ukernel<AType, BType, Col, Col>(alpha, m, n, k); alpha, m, n, k);
test_cpu_ukernel<AType, BType, Col, Row, thread_gemm_avx2_mxn_4x24_instances<Col, Row>>(
alpha, m, n, k);
test_cpu_ukernel<AType, BType, Col, Col, thread_gemm_avx2_mxn_4x24_instances<Col, Col>>(
alpha, m, n, k);
test_cpu_ukernel<AType, BType, Row, Row, thread_gemm_avx2_mxn_6x16_instances<Row, Row>>(
alpha, m, n, k);
test_cpu_ukernel<AType, BType, Row, Col, thread_gemm_avx2_mxn_6x16_instances<Row, Col>>(
alpha, m, n, k);
test_cpu_ukernel<AType, BType, Col, Row, thread_gemm_avx2_mxn_6x16_instances<Col, Row>>(
alpha, m, n, k);
test_cpu_ukernel<AType, BType, Col, Col, thread_gemm_avx2_mxn_6x16_instances<Col, Col>>(
alpha, m, n, k);
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment