"docs/vscode:/vscode.git/clone" did not exist on "2c915218e85c81558b66cff23ef92e646e6442f7"
Commit f8b551da authored by carlushuang's avatar carlushuang
Browse files

add bias_relu, bias fusion

parent bfa4c686
...@@ -8,12 +8,18 @@ ...@@ -8,12 +8,18 @@
#include "device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk.hpp" #include "device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation_cpu.hpp" #include "element_wise_operation_cpu.hpp"
#include "reference_conv_fwd_bias_activation_add.hpp" #include "reference_conv_fwd_bias_activation_add.hpp"
#include "reference_conv_fwd_bias_activation.hpp"
#include "element_wise_operation_cpu.hpp" #include "element_wise_operation_cpu.hpp"
#include "dynamic_buffer_cpu.hpp" #include "dynamic_buffer_cpu.hpp"
#include <omp.h> #include <omp.h>
#define AVX2_DATA_ALIGNMENT 32 #define AVX2_DATA_ALIGNMENT 32
#define TEST_FUSION_BIAS_RELU_ADD 0
#define TEST_FUSION_BIAS_RELU 1
#define TEST_FUSION_BIAS 2
#define TEST_FUSION TEST_FUSION_BIAS
#define TEST_LAYOUT_NHWC_KYXC_NHWK 0 #define TEST_LAYOUT_NHWC_KYXC_NHWK 0
#define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1 #define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1
#define TEST_LAYOUT_NHWC_YXCK_NHWK 2 #define TEST_LAYOUT_NHWC_YXCK_NHWK 2
...@@ -30,46 +36,102 @@ namespace device_conv2d_fwd_bias_activation_add_avx2_instance { ...@@ -30,46 +36,102 @@ namespace device_conv2d_fwd_bias_activation_add_avx2_instance {
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough; using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::cpu::element_wise::AddReluAdd; using AddReluAdd = ck::tensor_operation::cpu::element_wise::AddReluAdd;
using AddRelu = ck::tensor_operation::cpu::element_wise::AddRelu;
using Add = ck::tensor_operation::cpu::element_wise::Add;
// ------------------ nhwc-kyxc-nhwk // ------------------ nhwc-kyxc-nhwk
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk( void add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_kyxc_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>& std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
instances); instances);
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_local_c( void add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_kyxc_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>& std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
instances); instances);
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_mt( void add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_kyxc_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>& std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
instances); instances);
void add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxc_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddRelu>>& instances);
void add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxc_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddRelu>>& instances);
void add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxc_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddRelu>>& instances);
void add_device_conv2d_fwd_bias_avx2_nhwc_kyxc_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, Add>>& instances);
void add_device_conv2d_fwd_bias_avx2_nhwc_kyxc_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, Add>>& instances);
void add_device_conv2d_fwd_bias_avx2_nhwc_kyxc_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, Add>>& instances);
// ------------------ nhwc-kcyxk8-nhwk // ------------------ nhwc-kcyxk8-nhwk
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk( void add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_kyxck8_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>& std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
instances); instances);
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_local_c( void add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_kyxck8_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>& std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
instances); instances);
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_mt( void add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_kyxck8_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>& std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
instances); instances);
void add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxck8_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddRelu>>& instances);
void add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxck8_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddRelu>>& instances);
void add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxck8_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddRelu>>& instances);
void add_device_conv2d_fwd_bias_avx2_nhwc_kyxck8_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, Add>>& instances);
void add_device_conv2d_fwd_bias_avx2_nhwc_kyxck8_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, Add>>& instances);
void add_device_conv2d_fwd_bias_avx2_nhwc_kyxck8_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, Add>>& instances);
// ------------------ nhwc-yxck-nhwk // ------------------ nhwc-yxck-nhwk
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk( void add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_yxck_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>& std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
instances); instances);
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c( void add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_yxck_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>& std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
instances); instances);
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_mt( void add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_yxck_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>& std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
instances); instances);
void add_device_conv2d_fwd_bias_relu_avx2_nhwc_yxck_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddRelu>>& instances);
void add_device_conv2d_fwd_bias_relu_avx2_nhwc_yxck_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddRelu>>& instances);
void add_device_conv2d_fwd_bias_relu_avx2_nhwc_yxck_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddRelu>>& instances);
void add_device_conv2d_fwd_bias_avx2_nhwc_yxck_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, Add>>& instances);
void add_device_conv2d_fwd_bias_avx2_nhwc_yxck_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, Add>>& instances);
void add_device_conv2d_fwd_bias_avx2_nhwc_yxck_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, Add>>& instances);
} // namespace device_conv2d_fwd_bias_activation_add_avx2_instance } // namespace device_conv2d_fwd_bias_activation_add_avx2_instance
} // namespace device } // namespace device
} // namespace cpu } // namespace cpu
...@@ -78,7 +140,13 @@ void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_mt( ...@@ -78,7 +140,13 @@ void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_mt(
using InElementOp = ck::tensor_operation::cpu::element_wise::PassThrough; using InElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::cpu::element_wise::PassThrough; using WeiElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
#if TEST_FUSION == TEST_FUSION_BIAS_RELU_ADD
using OutElementOp = ck::tensor_operation::cpu::element_wise::AddReluAdd; using OutElementOp = ck::tensor_operation::cpu::element_wise::AddReluAdd;
#elif TEST_FUSION == TEST_FUSION_BIAS_RELU
using OutElementOp = ck::tensor_operation::cpu::element_wise::AddRelu;
#elif TEST_FUSION == TEST_FUSION_BIAS
using OutElementOp = ck::tensor_operation::cpu::element_wise::Add;
#endif
template <typename T> template <typename T>
static bool static bool
...@@ -249,6 +317,7 @@ int main(int argc, char* argv[]) ...@@ -249,6 +317,7 @@ int main(int argc, char* argv[])
using WeiDataType = decltype(wei_type); using WeiDataType = decltype(wei_type);
using OutDataType = decltype(out_type); using OutDataType = decltype(out_type);
#if TEST_FUSION == TEST_FUSION_BIAS_RELU_ADD
using ReferenceConvFwdInstance = using ReferenceConvFwdInstance =
ck::tensor_operation::host::ReferenceConvFwd_Bias_Activation_Add<InDataType, ck::tensor_operation::host::ReferenceConvFwd_Bias_Activation_Add<InDataType,
WeiDataType, WeiDataType,
...@@ -256,6 +325,15 @@ int main(int argc, char* argv[]) ...@@ -256,6 +325,15 @@ int main(int argc, char* argv[])
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp>; OutElementOp>;
#else
using ReferenceConvFwdInstance =
ck::tensor_operation::host::ReferenceConvFwd_Bias_Activation<InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>;
#endif
const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1;
const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; const ck::index_t XEff = (X - 1) * conv_dilation_w + 1;
...@@ -381,7 +459,9 @@ int main(int argc, char* argv[]) ...@@ -381,7 +459,9 @@ int main(int argc, char* argv[])
wei_k_c_y_x, wei_k_c_y_x,
out_n_k_ho_wo_host_result, out_n_k_ho_wo_host_result,
bias, bias,
#if TEST_FUSION == TEST_FUSION_BIAS_RELU_ADD
residual, residual,
#endif
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
...@@ -394,9 +474,19 @@ int main(int argc, char* argv[]) ...@@ -394,9 +474,19 @@ int main(int argc, char* argv[])
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough; using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::cpu::element_wise::AddReluAdd; using AddReluAdd = ck::tensor_operation::cpu::element_wise::AddReluAdd;
using AddRelu = ck::tensor_operation::cpu::element_wise::AddRelu;
using Add = ck::tensor_operation::cpu::element_wise::Add;
#if TEST_FUSION == TEST_FUSION_BIAS_RELU_ADD
using DeviceConvFwdNoOpPtr = ck::tensor_operation::cpu::device:: using DeviceConvFwdNoOpPtr = ck::tensor_operation::cpu::device::
DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>; DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>;
#elif TEST_FUSION == TEST_FUSION_BIAS_RELU
using DeviceConvFwdNoOpPtr = ck::tensor_operation::cpu::device::
DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddRelu>;
#elif TEST_FUSION == TEST_FUSION_BIAS
using DeviceConvFwdNoOpPtr = ck::tensor_operation::cpu::device::
DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, Add>;
#endif
// add device Conv instances // add device Conv instances
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs; std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
...@@ -405,27 +495,27 @@ int main(int argc, char* argv[]) ...@@ -405,27 +495,27 @@ int main(int argc, char* argv[])
ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> && ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, float>) ck::is_same_v<ck::remove_cv_t<OutDataType>, float>)
{ {
#if TEST_FUSION == TEST_FUSION_BIAS_RELU_ADD
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXC_NHWK #if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXC_NHWK
if(omp_get_max_threads() > 1) if(omp_get_max_threads() > 1)
{ {
ck::tensor_operation::cpu::device:: ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance:: device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_mt(conv_ptrs); add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_kyxc_nhwk_mt(conv_ptrs);
ck::tensor_operation::cpu::device:: ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance:: device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk(conv_ptrs); add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_kyxc_nhwk(conv_ptrs);
} }
else else
{ {
if(K % 8 == 0) if(K % 8 == 0)
ck::tensor_operation::cpu::device:: ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance:: device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk( add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_kyxc_nhwk(conv_ptrs);
conv_ptrs);
else else
ck::tensor_operation::cpu::device:: ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance:: device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_local_c( add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_kyxc_nhwk_local_c(
conv_ptrs); conv_ptrs);
} }
#endif #endif
...@@ -434,23 +524,21 @@ int main(int argc, char* argv[]) ...@@ -434,23 +524,21 @@ int main(int argc, char* argv[])
{ {
ck::tensor_operation::cpu::device:: ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance:: device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_mt( add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_kyxck8_nhwk_mt(conv_ptrs);
conv_ptrs);
ck::tensor_operation::cpu::device:: ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance:: device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk(conv_ptrs); add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_kyxck8_nhwk(conv_ptrs);
} }
else else
{ {
if(K % 8 == 0) if(K % 8 == 0)
ck::tensor_operation::cpu::device:: ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance:: device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk( add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_kyxck8_nhwk(conv_ptrs);
conv_ptrs);
else else
ck::tensor_operation::cpu::device:: ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance:: device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_local_c( add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_kyxck8_nhwk_local_c(
conv_ptrs); conv_ptrs);
} }
#endif #endif
...@@ -459,24 +547,159 @@ int main(int argc, char* argv[]) ...@@ -459,24 +547,159 @@ int main(int argc, char* argv[])
{ {
ck::tensor_operation::cpu::device:: ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance:: device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_mt(conv_ptrs); add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_yxck_nhwk_mt(conv_ptrs);
ck::tensor_operation::cpu::device:: ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance:: device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk(conv_ptrs); add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_yxck_nhwk(conv_ptrs);
} }
else else
{ {
if(K % 8 == 0) if(K % 8 == 0)
ck::tensor_operation::cpu::device:: ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance:: device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk( add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_yxck_nhwk(conv_ptrs);
else
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_yxck_nhwk_local_c(
conv_ptrs); conv_ptrs);
}
#endif
#elif TEST_FUSION == TEST_FUSION_BIAS_RELU
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXC_NHWK
if(omp_get_max_threads() > 1)
{
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxc_nhwk_mt(conv_ptrs);
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxc_nhwk(conv_ptrs);
}
else
{
if(K % 8 == 0)
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxc_nhwk(conv_ptrs);
else
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxc_nhwk_local_c(conv_ptrs);
}
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
if(omp_get_max_threads() > 1)
{
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxck8_nhwk_mt(conv_ptrs);
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxck8_nhwk(conv_ptrs);
}
else else
{
if(K % 8 == 0)
ck::tensor_operation::cpu::device:: ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance:: device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c( add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxck8_nhwk(conv_ptrs);
else
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxck8_nhwk_local_c(
conv_ptrs); conv_ptrs);
} }
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK
if(omp_get_max_threads() > 1)
{
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_relu_avx2_nhwc_yxck_nhwk_mt(conv_ptrs);
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_relu_avx2_nhwc_yxck_nhwk(conv_ptrs);
}
else
{
if(K % 8 == 0)
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_relu_avx2_nhwc_yxck_nhwk(conv_ptrs);
else
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_relu_avx2_nhwc_yxck_nhwk_local_c(conv_ptrs);
}
#endif
#elif TEST_FUSION == TEST_FUSION_BIAS
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXC_NHWK
if(omp_get_max_threads() > 1)
{
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_avx2_nhwc_kyxc_nhwk_mt(conv_ptrs);
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_avx2_nhwc_kyxc_nhwk(conv_ptrs);
}
else
{
if(K % 8 == 0)
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_avx2_nhwc_kyxc_nhwk(conv_ptrs);
else
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_avx2_nhwc_kyxc_nhwk_local_c(conv_ptrs);
}
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
if(omp_get_max_threads() > 1)
{
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_avx2_nhwc_kyxck8_nhwk_mt(conv_ptrs);
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_avx2_nhwc_kyxck8_nhwk(conv_ptrs);
}
else
{
if(K % 8 == 0)
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_avx2_nhwc_kyxck8_nhwk(conv_ptrs);
else
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_avx2_nhwc_kyxck8_nhwk_local_c(conv_ptrs);
}
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK
if(omp_get_max_threads() > 1)
{
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_avx2_nhwc_yxck_nhwk_mt(conv_ptrs);
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_avx2_nhwc_yxck_nhwk(conv_ptrs);
}
else
{
if(K % 8 == 0)
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_avx2_nhwc_yxck_nhwk(conv_ptrs);
else
ck::tensor_operation::cpu::device::
device_conv2d_fwd_bias_activation_add_avx2_instance::
add_device_conv2d_fwd_bias_avx2_nhwc_yxck_nhwk_local_c(conv_ptrs);
}
#endif
#endif #endif
} }
......
...@@ -37,6 +37,8 @@ template <typename InDataType, ...@@ -37,6 +37,8 @@ template <typename InDataType,
bool UseALocalBuffer, bool UseALocalBuffer,
bool UseBLocalBuffer, bool UseBLocalBuffer,
bool UseCLocalBuffer, bool UseCLocalBuffer,
bool FuseBias,
bool FuseAdd,
bool BiasAlongGemmM> bool BiasAlongGemmM>
struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
: public DeviceConvFwdBiasActivationAdd<InElementwiseOperation, : public DeviceConvFwdBiasActivationAdd<InElementwiseOperation,
...@@ -607,8 +609,13 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu ...@@ -607,8 +609,13 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
!UseBLocalBuffer, !UseBLocalBuffer,
ConvForwardSpecialization>; ConvForwardSpecialization>;
using CThreadwiseCopy = static constexpr auto GetCThreadwiseCopy()
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN< {
constexpr ck::index_t C_nDim = CGridDesc::GetNumOfDimension();
if constexpr(FuseBias && FuseAdd)
{
return ck::cpu::
ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN<
CDataType, CDataType,
C0DataType, C0DataType,
C1DataType, C1DataType,
...@@ -619,7 +626,34 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu ...@@ -619,7 +626,34 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
decltype(GetOutputBlockDescriptor()), decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation, OutElementwiseOperation,
!UseCLocalBuffer, !UseCLocalBuffer,
BiasAlongGemmM>; BiasAlongGemmM>(CGridDesc{},
ck::make_zero_multi_index<C_nDim>(),
GetOutputBlockDescriptor(),
ck::make_zero_multi_index<C_nDim>(),
OutElementwiseOperation{});
}
else if constexpr(FuseBias && !FuseAdd)
{
return ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_MxN<
CDataType,
C0DataType,
C1DataType,
CDataType,
CGridDesc,
C0GridDesc,
C1GridDesc,
decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation,
!UseCLocalBuffer,
BiasAlongGemmM>(CGridDesc{},
ck::make_zero_multi_index<C_nDim>(),
GetOutputBlockDescriptor(),
ck::make_zero_multi_index<C_nDim>(),
OutElementwiseOperation{});
}
}
using CThreadwiseCopy = decltype(GetCThreadwiseCopy());
using GridwiseGemm = ck::cpu::GridwiseGemmBiasActivationAddAvx2_MxN< using GridwiseGemm = ck::cpu::GridwiseGemmBiasActivationAddAvx2_MxN<
ADataType, // InDataType, ADataType, // InDataType,
......
...@@ -37,6 +37,8 @@ template <typename InDataType, ...@@ -37,6 +37,8 @@ template <typename InDataType,
bool UseALocalBuffer, bool UseALocalBuffer,
bool UseBLocalBuffer, bool UseBLocalBuffer,
bool UseCLocalBuffer, bool UseCLocalBuffer,
bool FuseBias,
bool FuseAdd,
bool BiasAlongGemmM> bool BiasAlongGemmM>
struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
: public DeviceConvFwdBiasActivationAdd<InElementwiseOperation, : public DeviceConvFwdBiasActivationAdd<InElementwiseOperation,
...@@ -584,8 +586,13 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou ...@@ -584,8 +586,13 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
!UseBLocalBuffer, !UseBLocalBuffer,
ConvForwardSpecialization>; ConvForwardSpecialization>;
using CThreadwiseCopy = static constexpr auto GetCThreadwiseCopy()
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN< {
constexpr ck::index_t C_nDim = CGridDesc::GetNumOfDimension();
if constexpr(FuseBias && FuseAdd)
{
return ck::cpu::
ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN<
CDataType, CDataType,
C0DataType, C0DataType,
C1DataType, C1DataType,
...@@ -596,7 +603,34 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou ...@@ -596,7 +603,34 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
decltype(GetOutputBlockDescriptor()), decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation, OutElementwiseOperation,
!UseCLocalBuffer, !UseCLocalBuffer,
BiasAlongGemmM>; BiasAlongGemmM>(CGridDesc{},
ck::make_zero_multi_index<C_nDim>(),
GetOutputBlockDescriptor(),
ck::make_zero_multi_index<C_nDim>(),
OutElementwiseOperation{});
}
else if constexpr(FuseBias && !FuseAdd)
{
return ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_MxN<
CDataType,
C0DataType,
C1DataType,
CDataType,
CGridDesc,
C0GridDesc,
C1GridDesc,
decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation,
!UseCLocalBuffer,
BiasAlongGemmM>(CGridDesc{},
ck::make_zero_multi_index<C_nDim>(),
GetOutputBlockDescriptor(),
ck::make_zero_multi_index<C_nDim>(),
OutElementwiseOperation{});
}
}
using CThreadwiseCopy = decltype(GetCThreadwiseCopy());
using GridwiseGemm = ck::cpu::GridwiseGemmBiasActivationAddAvx2_MxN< using GridwiseGemm = ck::cpu::GridwiseGemmBiasActivationAddAvx2_MxN<
ADataType, // InDataType, ADataType, // InDataType,
......
...@@ -36,6 +36,8 @@ template <typename InDataType, ...@@ -36,6 +36,8 @@ template <typename InDataType,
bool UseALocalBuffer, bool UseALocalBuffer,
bool UseBLocalBuffer, bool UseBLocalBuffer,
bool UseCLocalBuffer, bool UseCLocalBuffer,
bool FuseBias,
bool FuseAdd,
bool BiasAlongGemmM> bool BiasAlongGemmM>
struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
: public DeviceConvFwdBiasActivationAdd<InElementwiseOperation, : public DeviceConvFwdBiasActivationAdd<InElementwiseOperation,
...@@ -580,8 +582,13 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu ...@@ -580,8 +582,13 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
!UseBLocalBuffer, !UseBLocalBuffer,
ConvForwardSpecialization>; ConvForwardSpecialization>;
using CThreadwiseCopy = static constexpr auto GetCThreadwiseCopy()
ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN< {
constexpr ck::index_t C_nDim = CGridDesc::GetNumOfDimension();
if constexpr(FuseBias && FuseAdd)
{
return ck::cpu::
ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN<
CDataType, CDataType,
C0DataType, C0DataType,
C1DataType, C1DataType,
...@@ -592,7 +599,34 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu ...@@ -592,7 +599,34 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
decltype(GetOutputBlockDescriptor()), decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation, OutElementwiseOperation,
!UseCLocalBuffer, !UseCLocalBuffer,
BiasAlongGemmM>; BiasAlongGemmM>(CGridDesc{},
ck::make_zero_multi_index<C_nDim>(),
GetOutputBlockDescriptor(),
ck::make_zero_multi_index<C_nDim>(),
OutElementwiseOperation{});
}
else if constexpr(FuseBias && !FuseAdd)
{
return ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_MxN<
CDataType,
C0DataType,
C1DataType,
CDataType,
CGridDesc,
C0GridDesc,
C1GridDesc,
decltype(GetOutputBlockDescriptor()),
OutElementwiseOperation,
!UseCLocalBuffer,
BiasAlongGemmM>(CGridDesc{},
ck::make_zero_multi_index<C_nDim>(),
GetOutputBlockDescriptor(),
ck::make_zero_multi_index<C_nDim>(),
OutElementwiseOperation{});
}
}
using CThreadwiseCopy = decltype(GetCThreadwiseCopy());
using GridwiseGemm = ck::cpu::GridwiseGemmBiasActivationAddAvx2_MxN< using GridwiseGemm = ck::cpu::GridwiseGemmBiasActivationAddAvx2_MxN<
ADataType, // InDataType, ADataType, // InDataType,
......
...@@ -273,6 +273,53 @@ void memcpy32_avx2_with_extra_1src(void* dst, ...@@ -273,6 +273,53 @@ void memcpy32_avx2_with_extra_1src(void* dst,
} }
} }
template <typename ElementwiseOp>
void memcpy32_avx2_with_extra_1src(void* dst,
const void* src,
const float v_src_aux,
const ck::index_t n,
const ElementwiseOp& element_op)
{
// 16-8-4-2-1 pattern
ck::index_t i_n = n;
float* p_dst = reinterpret_cast<float*>(dst);
const float* p_src = reinterpret_cast<const float*>(src);
__m256 ymm_src_aux = _mm256_set1_ps(*reinterpret_cast<const float*>(&v_src_aux));
__m128 xmm_src_aux = _mm_set1_ps(*reinterpret_cast<const float*>(&v_src_aux));
while(i_n >= 16)
{
_mm256_storeu_ps(p_dst + 0, element_op.Apply(_mm256_loadu_ps(p_src + 0), ymm_src_aux));
_mm256_storeu_ps(p_dst + 8, element_op.Apply(_mm256_loadu_ps(p_src + 8), ymm_src_aux));
p_dst += 16;
p_src += 16;
i_n -= 16;
}
if(i_n & 8)
{
_mm256_storeu_ps(p_dst, element_op.Apply(_mm256_loadu_ps(p_src), ymm_src_aux));
p_dst += 8;
p_src += 8;
}
if(i_n & 4)
{
_mm_storeu_ps(p_dst, element_op.Apply(_mm_loadu_ps(p_src), xmm_src_aux));
p_dst += 4;
p_src += 4;
}
if(i_n & 2)
{
_mm_storeu_si64(p_dst, element_op.Apply(_mm_loadu_si64(p_src), xmm_src_aux));
p_dst += 2;
p_src += 2;
}
if(i_n & 1)
{
*p_dst = element_op.Apply(*p_src, v_src_aux);
}
}
inline void memset32_avx2(void* dst, const int32_t value, const ck::index_t n) inline void memset32_avx2(void* dst, const int32_t value, const ck::index_t n)
{ {
// 16-8-4-2-1 pattern // 16-8-4-2-1 pattern
...@@ -2369,6 +2416,582 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_ ...@@ -2369,6 +2416,582 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_
intptr_t dst_offset; intptr_t dst_offset;
}; };
template <typename SrcData,
typename Src1Data, // for Bias, per dimension
typename Src2Data, // for Residual, per pixel
typename DstData,
typename SrcDesc,
typename Src1Desc,
typename Src2Desc,
typename DstDesc,
typename ElementwiseOperation,
bool BypassTransfer,
bool Src1AlongDim0> // if true, src1 has dim along M, false, src1 has dim along N
struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_MxN
{
static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
using Index = MultiIndex<nDim>;
constexpr ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_MxN(
const SrcDesc& src_desc,
const Index&,
const DstDesc& dst_desc,
const Index&,
const ElementwiseOperation& element_op)
: element_op_(element_op)
{
DstGemmM = dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
DstGemmN = dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
src_offset = 0;
src1_offset = 0;
dst_offset = 0;
}
void SetSrcSliceOrigin(const SrcDesc&, const Index& src_slice_origin_idx)
{
if constexpr(BypassTransfer)
{
auto i_src_gemm_m = src_slice_origin_idx[Number<0>{}];
auto i_src_gemm_n = src_slice_origin_idx[Number<1>{}];
src_offset = i_src_gemm_m * DstGemmN + i_src_gemm_n;
}
}
void SetSrc1SliceOrigin(const SrcDesc&, const Index& src_slice_origin_idx)
{
if constexpr(Src1AlongDim0)
{
auto i_src_gemm_m = src_slice_origin_idx[Number<0>{}];
// auto i_src_gemm_n = src_slice_origin_idx[Number<1>{}];
src1_offset = i_src_gemm_m;
}
else
{
auto i_src_gemm_n = src_slice_origin_idx[Number<1>{}];
src1_offset = i_src_gemm_n;
}
}
void SetSrc2SliceOrigin(const SrcDesc&, const Index&) {}
void SetDstSliceOrigin(const DstDesc&, const Index& dst_slice_origin_idx)
{
i_dst_gemm_m = dst_slice_origin_idx[Number<0>{}];
i_dst_gemm_n = dst_slice_origin_idx[Number<1>{}];
dst_offset = i_dst_gemm_m * DstGemmN + i_dst_gemm_n;
}
template <typename SrcBuffer,
typename Src1Buffer,
typename Src2Buffer,
typename DstBuffer,
typename SliceLengths>
void RunRead(const SrcDesc&,
SrcBuffer& src_buf,
const Src1Desc&,
Src1Buffer&,
const Src2Desc&,
Src2Buffer&,
const DstDesc&,
DstBuffer& dst_buf,
const SliceLengths&)
{
if constexpr(BypassTransfer)
{
dst_buf.p_data_ = reinterpret_cast<float*>(src_buf.p_data_) + src_offset;
}
}
template <typename SrcBuffer,
typename Src1Buffer,
typename Src2Buffer,
typename DstBuffer,
typename SliceLengths>
void RunWrite(const SrcDesc& src_desc,
SrcBuffer& src_buf,
const Src1Desc& src1_desc,
Src1Buffer& src1_buf,
const Src2Desc&,
Src2Buffer&,
const DstDesc& dst_desc,
DstBuffer& dst_buf,
const SliceLengths& slice_length)
{
if constexpr(BypassTransfer)
{
if constexpr(!std::is_same<ElementwiseOperation,
ck::tensor_operation::cpu::element_wise::PassThrough>::value)
{
const ck::index_t m_per_block = slice_length[Number<0>{}];
const ck::index_t n_per_block = slice_length[Number<1>{}];
const ck::index_t current_n = ck::math::min(DstGemmN - i_dst_gemm_n, n_per_block);
float* p_dst = reinterpret_cast<float*>(dst_buf.p_data_) + dst_offset;
const float* p_src1 =
reinterpret_cast<const float*>(src1_buf.p_data_) + src1_offset;
ck::index_t i_m_itr = m_per_block;
// standard 8-4-2-1 pattern
if constexpr(Src1AlongDim0)
{
while(i_m_itr >= 8)
{
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 0 * DstGemmN,
p_dst + 0 * DstGemmN,
*(p_src1 + 0),
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 1 * DstGemmN,
p_dst + 1 * DstGemmN,
*(p_src1 + 1),
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 2 * DstGemmN,
p_dst + 2 * DstGemmN,
*(p_src1 + 2),
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 3 * DstGemmN,
p_dst + 3 * DstGemmN,
*(p_src1 + 3),
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 4 * DstGemmN,
p_dst + 4 * DstGemmN,
*(p_src1 + 4),
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 5 * DstGemmN,
p_dst + 5 * DstGemmN,
*(p_src1 + 5),
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 6 * DstGemmN,
p_dst + 6 * DstGemmN,
*(p_src1 + 6),
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 7 * DstGemmN,
p_dst + 7 * DstGemmN,
*(p_src1 + 7),
current_n,
element_op_);
i_m_itr -= 8;
p_dst += 8 * DstGemmN;
p_src1 += 8;
}
if(i_m_itr & 4)
{
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 0 * DstGemmN,
p_dst + 0 * DstGemmN,
*(p_src1 + 0),
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 1 * DstGemmN,
p_dst + 1 * DstGemmN,
*(p_src1 + 1),
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 2 * DstGemmN,
p_dst + 2 * DstGemmN,
*(p_src1 + 2),
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 3 * DstGemmN,
p_dst + 3 * DstGemmN,
*(p_src1 + 3),
current_n,
element_op_);
p_dst += 4 * DstGemmN;
p_src1 += 4;
}
if(i_m_itr & 2)
{
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 0 * DstGemmN,
p_dst + 0 * DstGemmN,
*(p_src1 + 0),
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 1 * DstGemmN,
p_dst + 1 * DstGemmN,
*(p_src1 + 1),
current_n,
element_op_);
p_dst += 2 * DstGemmN;
p_src1 += 2;
}
if(i_m_itr & 1)
{
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 0 * DstGemmN,
p_dst + 0 * DstGemmN,
*(p_src1 + 0),
current_n,
element_op_);
}
}
else
{
while(i_m_itr >= 8)
{
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 0 * DstGemmN,
p_dst + 0 * DstGemmN,
p_src1,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 1 * DstGemmN,
p_dst + 1 * DstGemmN,
p_src1,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 2 * DstGemmN,
p_dst + 2 * DstGemmN,
p_src1,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 3 * DstGemmN,
p_dst + 3 * DstGemmN,
p_src1,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 4 * DstGemmN,
p_dst + 4 * DstGemmN,
p_src1,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 5 * DstGemmN,
p_dst + 5 * DstGemmN,
p_src1,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 6 * DstGemmN,
p_dst + 6 * DstGemmN,
p_src1,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 7 * DstGemmN,
p_dst + 7 * DstGemmN,
p_src1,
current_n,
element_op_);
i_m_itr -= 8;
p_dst += 8 * DstGemmN;
}
if(i_m_itr & 4)
{
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 0 * DstGemmN,
p_dst + 0 * DstGemmN,
p_src1,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 1 * DstGemmN,
p_dst + 1 * DstGemmN,
p_src1,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 2 * DstGemmN,
p_dst + 2 * DstGemmN,
p_src1,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 3 * DstGemmN,
p_dst + 3 * DstGemmN,
p_src1,
current_n,
element_op_);
p_dst += 4 * DstGemmN;
}
if(i_m_itr & 2)
{
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 0 * DstGemmN,
p_dst + 0 * DstGemmN,
p_src1,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 1 * DstGemmN,
p_dst + 1 * DstGemmN,
p_src1,
current_n,
element_op_);
p_dst += 2 * DstGemmN;
}
if(i_m_itr & 1)
{
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 0 * DstGemmN,
p_dst + 0 * DstGemmN,
p_src1,
current_n,
element_op_);
}
}
}
}
else
{
const ck::index_t m_per_block = slice_length[Number<0>{}];
const ck::index_t n_per_block = slice_length[Number<1>{}];
const ck::index_t current_n = ck::math::min(DstGemmN - i_dst_gemm_n, n_per_block);
const float* p_src = reinterpret_cast<const float*>(src_buf.p_data_) + src_offset;
float* p_dst = reinterpret_cast<float*>(dst_buf.p_data_) + dst_offset;
const float* p_src1 = reinterpret_cast<const float*>(src1_buf.p_data_) + src1_offset;
ck::index_t i_m_itr = m_per_block;
// printf("xxxx %d, current_n:%d, DstGemmN:%d, n_per_block:%d\n",__LINE__, current_n,
// DstGemmN, n_per_block);fflush(stdout);
// standard 8-4-2-1 pattern
if constexpr(Src1AlongDim0)
{
while(i_m_itr >= 8)
{
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 0 * DstGemmN,
p_src + 0 * n_per_block,
*(p_src1 + 0),
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 1 * DstGemmN,
p_src + 1 * n_per_block,
*(p_src1 + 1),
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 2 * DstGemmN,
p_src + 2 * n_per_block,
*(p_src1 + 2),
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 3 * DstGemmN,
p_src + 3 * n_per_block,
*(p_src1 + 3),
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 4 * DstGemmN,
p_src + 4 * n_per_block,
*(p_src1 + 4),
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 5 * DstGemmN,
p_src + 5 * n_per_block,
*(p_src1 + 5),
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 6 * DstGemmN,
p_src + 6 * n_per_block,
*(p_src1 + 6),
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 7 * DstGemmN,
p_src + 7 * n_per_block,
*(p_src1 + 7),
current_n,
element_op_);
i_m_itr -= 8;
p_dst += 8 * DstGemmN;
p_src += 8 * n_per_block;
p_src1 += 8;
}
if(i_m_itr & 4)
{
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 0 * DstGemmN,
p_src + 0 * n_per_block,
*(p_src1 + 0),
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 1 * DstGemmN,
p_src + 1 * n_per_block,
*(p_src1 + 1),
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 2 * DstGemmN,
p_src + 2 * n_per_block,
*(p_src1 + 2),
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 3 * DstGemmN,
p_src + 3 * n_per_block,
*(p_src1 + 3),
current_n,
element_op_);
p_dst += 4 * DstGemmN;
p_src += 4 * n_per_block;
p_src1 += 4;
}
if(i_m_itr & 2)
{
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 0 * DstGemmN,
p_src + 0 * n_per_block,
*(p_src1 + 0),
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 1 * DstGemmN,
p_src + 1 * n_per_block,
*(p_src1 + 1),
current_n,
element_op_);
p_dst += 2 * DstGemmN;
p_src += 2 * n_per_block;
p_src1 += 2;
}
if(i_m_itr & 1)
{
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 0 * DstGemmN,
p_src + 0 * n_per_block,
*(p_src1 + 0),
current_n,
element_op_);
}
}
else
{
while(i_m_itr >= 8)
{
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 0 * DstGemmN,
p_src + 0 * n_per_block,
p_src1,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 1 * DstGemmN,
p_src + 1 * n_per_block,
p_src1,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 2 * DstGemmN,
p_src + 2 * n_per_block,
p_src1,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 3 * DstGemmN,
p_src + 3 * n_per_block,
p_src1,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 4 * DstGemmN,
p_src + 4 * n_per_block,
p_src1,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 5 * DstGemmN,
p_src + 5 * n_per_block,
p_src1,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 6 * DstGemmN,
p_src + 6 * n_per_block,
p_src1,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 7 * DstGemmN,
p_src + 7 * n_per_block,
p_src1,
current_n,
element_op_);
i_m_itr -= 8;
p_dst += 8 * DstGemmN;
p_src += 8 * n_per_block;
}
if(i_m_itr & 4)
{
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 0 * DstGemmN,
p_src + 0 * n_per_block,
p_src1,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 1 * DstGemmN,
p_src + 1 * n_per_block,
p_src1,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 2 * DstGemmN,
p_src + 2 * n_per_block,
p_src1,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 3 * DstGemmN,
p_src + 3 * n_per_block,
p_src1,
current_n,
element_op_);
p_dst += 4 * DstGemmN;
p_src += 4 * n_per_block;
}
if(i_m_itr & 2)
{
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 0 * DstGemmN,
p_src + 0 * n_per_block,
p_src1,
current_n,
element_op_);
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 1 * DstGemmN,
p_src + 1 * n_per_block,
p_src1,
current_n,
element_op_);
p_dst += 2 * DstGemmN;
p_src += 2 * n_per_block;
}
if(i_m_itr & 1)
{
avx2_util::memcpy32_avx2_with_extra_1src(p_dst + 0 * DstGemmN,
p_src + 0 * n_per_block,
p_src1,
current_n,
element_op_);
}
}
}
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
void MoveSrcSliceWindow(const SrcDesc&, const Index&) {}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
void MoveDstSliceWindow(const DstDesc&, const Index&) {}
private:
const ElementwiseOperation element_op_;
ck::index_t i_dst_gemm_m;
ck::index_t i_dst_gemm_n;
ck::index_t DstGemmM;
ck::index_t DstGemmN;
intptr_t src_offset;
intptr_t src1_offset;
intptr_t dst_offset;
};
} // namespace cpu } // namespace cpu
} // namespace ck } // namespace ck
......
...@@ -22,6 +22,8 @@ static constexpr bool NonTemporalStore = false; ...@@ -22,6 +22,8 @@ static constexpr bool NonTemporalStore = false;
using PT = ck::tensor_operation::cpu::element_wise::PassThrough; using PT = ck::tensor_operation::cpu::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::cpu::element_wise::AddReluAdd; using AddReluAdd = ck::tensor_operation::cpu::element_wise::AddReluAdd;
using AddRelu = ck::tensor_operation::cpu::element_wise::AddRelu;
using Add = ck::tensor_operation::cpu::element_wise::Add;
static constexpr auto ConvFwdDefault = static constexpr auto ConvFwdDefault =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Default; ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Default;
...@@ -41,95 +43,257 @@ static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver ...@@ -41,95 +43,257 @@ static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN; static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
// clang-format off // clang-format off
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \ #define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, fuse_bias, fuse_add, bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, fuse_bias, fuse_add, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf, fuse_bias, fuse_add, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, fuse_bias, fuse_add, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf, fuse_bias, fuse_add, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf, fuse_bias, fuse_add, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
\ \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, fuse_bias, fuse_add, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf, fuse_bias, fuse_add, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, fuse_bias, fuse_add, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf, fuse_bias, fuse_add, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN}) DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf, fuse_bias, fuse_add, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN})
// clang-format on // clang-format on
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk( void add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_kyxc_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances) std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, instances,
std::make_tuple( std::make_tuple(
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, false, false) DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, false, true, true, false)
// clang-format on // clang-format on
)); ));
} }
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_local_c( void add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_kyxc_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances) std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, instances,
std::make_tuple( std::make_tuple(
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false) DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, true, true, false)
// clang-format on // clang-format on
)); ));
} }
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_mt( void add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_kyxc_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances) std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, instances,
std::make_tuple( std::make_tuple(
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 24, 24, 256, 4, 24, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 24, 24, 256, 4, 24, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 32, 24, 256, 4, 24, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 32, 24, 256, 4, 24, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 40, 24, 256, 4, 24, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 40, 24, 256, 4, 24, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 48, 24, 256, 4, 24, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 48, 24, 256, 4, 24, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 48, 48, 256, 4, 24, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 48, 48, 256, 4, 24, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 56, 24, 256, 4, 24, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 56, 24, 256, 4, 24, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 256, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 256, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 256, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 256, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 96, 32, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 96, 32, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 96, 64, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 96, 64, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 120, 32, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 120, 32, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 120, 64, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 120, 64, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false) DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, true, true, false)
// clang-format on
));
}
/****************************************************************************************************/
void add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxc_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddRelu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 256, 128, 64, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 256, 128, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 128, 256, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 512, 240, 128, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 512, 256, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 768, 320, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 896, 352, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 1024, 416, 128, 6, 16, false, true, false, false)
// clang-format on
));
}
void add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxc_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddRelu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 256, 128, 64, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 256, 128, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 128, 256, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 512, 240, 128, 4, 24, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 512, 256, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 768, 320, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 896, 352, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 1024, 416, 128, 6, 16, true, true, false, false)
// clang-format on
));
}
void add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxc_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddRelu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 24, 24, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 32, 24, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 40, 24, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 48, 24, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 48, 48, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 56, 24, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 72, 16, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 72, 16, 256, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 72, 32, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 72, 32, 256, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 96, 32, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 96, 64, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 120, 32, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 120, 64, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 256, 128, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 128, 256, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 512, 240, 128, 4, 24, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 512, 256, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 768, 320, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 896, 352, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddRelu, 1024, 416, 128, 6, 16, true, true, false, false)
// clang-format on
));
}
/****************************************************************************************************/
void add_device_conv2d_fwd_bias_avx2_nhwc_kyxc_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, Add>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 256, 128, 64, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 256, 128, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 128, 256, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 512, 240, 128, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 512, 256, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 768, 320, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 896, 352, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 1024, 416, 128, 6, 16, false, true, false, false)
// clang-format on
));
}
void add_device_conv2d_fwd_bias_avx2_nhwc_kyxc_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, Add>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 256, 128, 64, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 256, 128, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 128, 256, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 512, 240, 128, 4, 24, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 512, 256, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 768, 320, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 896, 352, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 1024, 416, 128, 6, 16, true, true, false, false)
// clang-format on
));
}
void add_device_conv2d_fwd_bias_avx2_nhwc_kyxc_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, Add>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 24, 24, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 32, 24, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 40, 24, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 48, 24, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 48, 48, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 56, 24, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 72, 16, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 72, 16, 256, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 72, 32, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 72, 32, 256, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 96, 32, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 96, 64, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 120, 32, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 120, 64, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 256, 128, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 128, 256, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 512, 240, 128, 4, 24, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 512, 256, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 768, 320, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 896, 352, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Add, 1024, 416, 128, 6, 16, true, true, false, false)
// clang-format on // clang-format on
)); ));
} }
......
...@@ -22,6 +22,8 @@ static constexpr bool NonTemporalStore = false; ...@@ -22,6 +22,8 @@ static constexpr bool NonTemporalStore = false;
using PT = ck::tensor_operation::cpu::element_wise::PassThrough; using PT = ck::tensor_operation::cpu::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::cpu::element_wise::AddReluAdd; using AddReluAdd = ck::tensor_operation::cpu::element_wise::AddReluAdd;
using AddRelu = ck::tensor_operation::cpu::element_wise::AddRelu;
using Add = ck::tensor_operation::cpu::element_wise::Add;
static constexpr auto ConvFwdDefault = static constexpr auto ConvFwdDefault =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Default; ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Default;
...@@ -41,95 +43,257 @@ static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver ...@@ -41,95 +43,257 @@ static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN; static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
// clang-format off // clang-format off
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \ #define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, fuse_bias, fuse_add, bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, fuse_bias, fuse_add, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf, fuse_bias, fuse_add, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, fuse_bias, fuse_add, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf, fuse_bias, fuse_add, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf, fuse_bias, fuse_add, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
\ \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, fuse_bias, fuse_add, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf, fuse_bias, fuse_add, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, fuse_bias, fuse_add, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf, fuse_bias, fuse_add, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN}) DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf, fuse_bias, fuse_add, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN})
// clang-format on // clang-format on
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk( void add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_kyxck8_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances) std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, instances,
std::make_tuple( std::make_tuple(
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, false, false) DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, false, true, true, false)
// clang-format on // clang-format on
)); ));
} }
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_local_c( void add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_kyxck8_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances) std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, instances,
std::make_tuple( std::make_tuple(
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false) DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, true, true, false)
// clang-format on // clang-format on
)); ));
} }
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_mt( void add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_kyxck8_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances) std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, instances,
std::make_tuple( std::make_tuple(
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 24, 24, 256, 4, 24, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 24, 24, 256, 4, 24, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 32, 24, 256, 4, 24, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 32, 24, 256, 4, 24, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 40, 24, 256, 4, 24, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 40, 24, 256, 4, 24, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 48, 24, 256, 4, 24, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 48, 24, 256, 4, 24, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 48, 48, 256, 4, 24, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 48, 48, 256, 4, 24, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 56, 24, 256, 4, 24, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 56, 24, 256, 4, 24, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 256, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 256, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 256, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 256, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 96, 32, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 96, 32, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 96, 64, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 96, 64, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 120, 32, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 120, 32, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 120, 64, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 120, 64, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false) DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, true, true, false)
// clang-format on
));
}
/****************************************************************************************************/
void add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxck8_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddRelu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 256, 128, 64, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 256, 128, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 128, 256, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 512, 240, 128, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 512, 256, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 768, 320, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 896, 352, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 1024, 416, 128, 6, 16, false, true, false, false)
// clang-format on
));
}
void add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxck8_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddRelu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 256, 128, 64, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 256, 128, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 128, 256, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 512, 240, 128, 4, 24, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 512, 256, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 768, 320, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 896, 352, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 1024, 416, 128, 6, 16, true, true, false, false)
// clang-format on
));
}
void add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxck8_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddRelu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 24, 24, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 32, 24, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 40, 24, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 48, 24, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 48, 48, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 56, 24, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 72, 16, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 72, 16, 256, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 72, 32, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 72, 32, 256, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 96, 32, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 96, 64, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 120, 32, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 120, 64, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 256, 128, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 128, 256, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 512, 240, 128, 4, 24, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 512, 256, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 768, 320, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 896, 352, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddRelu, 1024, 416, 128, 6, 16, true, true, false, false)
// clang-format on
));
}
/****************************************************************************************************/
void add_device_conv2d_fwd_bias_avx2_nhwc_kyxck8_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, Add>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 256, 128, 64, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 256, 128, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 128, 256, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 512, 240, 128, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 512, 256, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 768, 320, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 896, 352, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 1024, 416, 128, 6, 16, false, true, false, false)
// clang-format on
));
}
void add_device_conv2d_fwd_bias_avx2_nhwc_kyxck8_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, Add>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 256, 128, 64, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 256, 128, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 128, 256, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 512, 240, 128, 4, 24, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 512, 256, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 768, 320, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 896, 352, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 1024, 416, 128, 6, 16, true, true, false, false)
// clang-format on
));
}
void add_device_conv2d_fwd_bias_avx2_nhwc_kyxck8_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, Add>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 24, 24, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 32, 24, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 40, 24, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 48, 24, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 48, 48, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 56, 24, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 72, 16, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 72, 16, 256, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 72, 32, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 72, 32, 256, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 96, 32, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 96, 64, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 120, 32, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 120, 64, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 256, 128, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 128, 256, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 512, 240, 128, 4, 24, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 512, 256, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 768, 320, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 896, 352, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Add, 1024, 416, 128, 6, 16, true, true, false, false)
// clang-format on // clang-format on
)); ));
} }
......
...@@ -21,6 +21,8 @@ static constexpr bool NonTemporalStore = false; ...@@ -21,6 +21,8 @@ static constexpr bool NonTemporalStore = false;
using PT = ck::tensor_operation::cpu::element_wise::PassThrough; using PT = ck::tensor_operation::cpu::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::cpu::element_wise::AddReluAdd; using AddReluAdd = ck::tensor_operation::cpu::element_wise::AddReluAdd;
using AddRelu = ck::tensor_operation::cpu::element_wise::AddRelu;
using Add = ck::tensor_operation::cpu::element_wise::Add;
static constexpr auto ConvFwdDefault = static constexpr auto ConvFwdDefault =
ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Default; ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Default;
...@@ -40,95 +42,257 @@ static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver ...@@ -40,95 +42,257 @@ static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN; static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
// clang-format off // clang-format off
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \ #define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, fuse_bias, fuse_add, bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, fuse_bias, fuse_add, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf, fuse_bias, fuse_add, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, fuse_bias, fuse_add, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf, fuse_bias, fuse_add, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf, fuse_bias, fuse_add, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
\ \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, fuse_bias, fuse_add, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf, fuse_bias, fuse_add, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, fuse_bias, fuse_add, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \ DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf, fuse_bias, fuse_add, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN}) DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf, fuse_bias, fuse_add, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN})
// clang-format on // clang-format on
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk( void add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_yxck_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances) std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{ {
ck::tensor_operation::device::add_device_operation_instances( ck::tensor_operation::device::add_device_operation_instances(
instances, instances,
std::make_tuple( std::make_tuple(
// clang-format off // clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, false, false), DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, false, false) DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, false, true, true, false)
// clang-format on // clang-format on
)); ));
} }
void add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_yxck_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, true, true, false)
// clang-format on
));
}
void add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_yxck_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 24, 24, 256, 4, 24, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 32, 24, 256, 4, 24, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 40, 24, 256, 4, 24, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 48, 24, 256, 4, 24, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 48, 48, 256, 4, 24, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 56, 24, 256, 4, 24, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 16, 256, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 72, 32, 256, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 96, 32, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 96, 64, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 120, 32, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 120, 64, 128, 6, 16, false, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 128, 256, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 240, 128, 4, 24, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 512, 256, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 768, 320, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 896, 352, 128, 6, 16, true, true, true, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, true, true, false)
// clang-format on
));
}
/****************************************************************************************************/
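// bias + relu fusion (AddRelu). A note on the trailing boolean knobs, inferred
// only by comparing the instance lists in this file (unverified assumption):
// the first flag is set exactly for the local-C-buffer variants (*_local_c,
// plus the large tiles reused in the *_mt lists), and the third flag is set
// only for the AddReluAdd instances, so it appears to gate the extra
// residual-add input.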
void add_device_conv2d_fwd_bias_relu_avx2_nhwc_yxck_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddRelu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 256, 128, 64, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 256, 128, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 128, 256, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 512, 240, 128, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 512, 256, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 768, 320, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 896, 352, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 1024, 416, 128, 6, 16, false, true, false, false)
// clang-format on
));
}
void add_device_conv2d_fwd_bias_relu_avx2_nhwc_yxck_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddRelu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 256, 128, 64, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 256, 128, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 128, 256, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 512, 240, 128, 4, 24, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 512, 256, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 768, 320, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 896, 352, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 1024, 416, 128, 6, 16, true, true, false, false)
// clang-format on
));
}
void add_device_conv2d_fwd_bias_relu_avx2_nhwc_yxck_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddRelu>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 24, 24, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 32, 24, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 40, 24, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 48, 24, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 48, 48, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 56, 24, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 72, 16, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 72, 16, 256, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 72, 32, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 72, 32, 256, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 96, 32, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 96, 64, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 120, 32, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 120, 64, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 256, 128, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 128, 256, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 512, 240, 128, 4, 24, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 512, 256, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 768, 320, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 896, 352, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddRelu, 1024, 416, 128, 6, 16, true, true, false, false)
// clang-format on
));
}
/****************************************************************************************************/
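// bias-only fusion (Add): same layout, tile shapes, and flag selections as the
// AddRelu group above; only the element-wise output operation changes.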
void add_device_conv2d_fwd_bias_avx2_nhwc_yxck_nhwk(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, Add>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 256, 128, 64, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 256, 128, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 128, 256, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 512, 240, 128, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 512, 256, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 768, 320, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 896, 352, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 1024, 416, 128, 6, 16, false, true, false, false)
// clang-format on
));
}
void add_device_conv2d_fwd_bias_avx2_nhwc_yxck_nhwk_local_c(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, Add>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 256, 128, 64, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 256, 128, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 128, 256, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 512, 240, 128, 4, 24, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 512, 256, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 768, 320, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 896, 352, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 1024, 416, 128, 6, 16, true, true, false, false)
// clang-format on
));
}
void add_device_conv2d_fwd_bias_avx2_nhwc_yxck_nhwk_mt(
std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, Add>>& instances)
{
ck::tensor_operation::device::add_device_operation_instances(
instances,
std::make_tuple(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 24, 24, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 32, 24, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 40, 24, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 48, 24, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 48, 48, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 56, 24, 256, 4, 24, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 72, 16, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 72, 16, 256, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 72, 32, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 72, 32, 256, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 96, 32, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 96, 64, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 120, 32, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 120, 64, 128, 6, 16, false, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 256, 128, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 128, 256, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 512, 240, 128, 4, 24, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 512, 256, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 768, 320, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 896, 352, 128, 6, 16, true, true, false, false),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Add, 1024, 416, 128, 6, 16, true, true, false, false)
// clang-format on
));
}
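// ------------------ consumption sketch (assumed caller-side code, not part of
// this file): each add_* function above appends type-erased device-op pointers
// to the caller's vector, so all variants of one fusion for this layout can be
// gathered and profiled together, e.g.:
//
//   std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, Add>> conv_ptrs;
//   add_device_conv2d_fwd_bias_avx2_nhwc_yxck_nhwk(conv_ptrs);
//   add_device_conv2d_fwd_bias_avx2_nhwc_yxck_nhwk_local_c(conv_ptrs);
//   add_device_conv2d_fwd_bias_avx2_nhwc_yxck_nhwk_mt(conv_ptrs);
//   // conv_ptrs now holds every bias-only instance for this layout; a caller
//   // would benchmark each one and keep the fastest that supports its shape.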
...