Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f8b551da
"docs/vscode:/vscode.git/clone" did not exist on "2c915218e85c81558b66cff23ef92e646e6442f7"
Commit
f8b551da
authored
Jun 14, 2022
by
carlushuang
Browse files
add bias_relu, bias fusion
parent
bfa4c686
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
1672 additions
and
232 deletions
+1672
-232
example/cpu_02_conv2d_fwd_bias_relu_add/cpu_conv2d_fwd_bias_relu_add.cpp
...conv2d_fwd_bias_relu_add/cpu_conv2d_fwd_bias_relu_add.cpp
+247
-24
include/ck/tensor_operation/cpu/device/device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk.hpp
...ce_convnd_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk.hpp
+47
-13
include/ck/tensor_operation/cpu/device/device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk.hpp
..._convnd_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk.hpp
+48
-14
include/ck/tensor_operation/cpu/device/device_convnd_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk.hpp
...ce_convnd_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk.hpp
+47
-13
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp
.../threadwise_tensor_slice_transfer_avx2_specialization.hpp
+623
-0
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_bias_activation_add_avx2_nhwc_kyxc_nhwk_instance.cpp
...nv2d_bias_activation_add_avx2_nhwc_kyxc_nhwk_instance.cpp
+220
-56
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_bias_activation_add_avx2_nhwc_kyxck8_nhwk_instance.cpp
...2d_bias_activation_add_avx2_nhwc_kyxck8_nhwk_instance.cpp
+220
-56
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_bias_activation_add_avx2_nhwc_yxck_nhwk_instance.cpp
...nv2d_bias_activation_add_avx2_nhwc_yxck_nhwk_instance.cpp
+220
-56
No files found.
example/cpu_02_conv2d_fwd_bias_relu_add/cpu_conv2d_fwd_bias_relu_add.cpp
View file @
f8b551da
...
@@ -8,12 +8,18 @@
...
@@ -8,12 +8,18 @@
#include "device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk.hpp"
#include "device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation_cpu.hpp"
#include "element_wise_operation_cpu.hpp"
#include "reference_conv_fwd_bias_activation_add.hpp"
#include "reference_conv_fwd_bias_activation_add.hpp"
#include "reference_conv_fwd_bias_activation.hpp"
#include "element_wise_operation_cpu.hpp"
#include "element_wise_operation_cpu.hpp"
#include "dynamic_buffer_cpu.hpp"
#include "dynamic_buffer_cpu.hpp"
#include <omp.h>
#include <omp.h>
#define AVX2_DATA_ALIGNMENT 32
#define AVX2_DATA_ALIGNMENT 32
#define TEST_FUSION_BIAS_RELU_ADD 0
#define TEST_FUSION_BIAS_RELU 1
#define TEST_FUSION_BIAS 2
#define TEST_FUSION TEST_FUSION_BIAS
#define TEST_LAYOUT_NHWC_KYXC_NHWK 0
#define TEST_LAYOUT_NHWC_KYXC_NHWK 0
#define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1
#define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1
#define TEST_LAYOUT_NHWC_YXCK_NHWK 2
#define TEST_LAYOUT_NHWC_YXCK_NHWK 2
...
@@ -30,46 +36,102 @@ namespace device_conv2d_fwd_bias_activation_add_avx2_instance {
...
@@ -30,46 +36,102 @@ namespace device_conv2d_fwd_bias_activation_add_avx2_instance {
using
PassThrough
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
using
AddReluAdd
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
AddReluAdd
;
using
AddReluAdd
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
AddReluAdd
;
using
AddRelu
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
AddRelu
;
using
Add
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
Add
;
// ------------------ nhwc-kyxc-nhwk
// ------------------ nhwc-kyxc-nhwk
void
add_device_conv2d_fwd_bias_
activation
_add_avx2_nhwc_kyxc_nhwk
(
void
add_device_conv2d_fwd_bias_
relu
_add_avx2_nhwc_kyxc_nhwk
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReluAdd
>>&
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReluAdd
>>&
instances
);
instances
);
void
add_device_conv2d_fwd_bias_
activation
_add_avx2_nhwc_kyxc_nhwk_local_c
(
void
add_device_conv2d_fwd_bias_
relu
_add_avx2_nhwc_kyxc_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReluAdd
>>&
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReluAdd
>>&
instances
);
instances
);
void
add_device_conv2d_fwd_bias_
activation
_add_avx2_nhwc_kyxc_nhwk_mt
(
void
add_device_conv2d_fwd_bias_
relu
_add_avx2_nhwc_kyxc_nhwk_mt
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReluAdd
>>&
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReluAdd
>>&
instances
);
instances
);
void
add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxc_nhwk
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddRelu
>>&
instances
);
void
add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxc_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddRelu
>>&
instances
);
void
add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxc_nhwk_mt
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddRelu
>>&
instances
);
void
add_device_conv2d_fwd_bias_avx2_nhwc_kyxc_nhwk
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
Add
>>&
instances
);
void
add_device_conv2d_fwd_bias_avx2_nhwc_kyxc_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
Add
>>&
instances
);
void
add_device_conv2d_fwd_bias_avx2_nhwc_kyxc_nhwk_mt
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
Add
>>&
instances
);
// ------------------ nhwc-kcyxk8-nhwk
// ------------------ nhwc-kcyxk8-nhwk
void
add_device_conv2d_fwd_bias_
activation
_add_avx2_nhwc_kyxck8_nhwk
(
void
add_device_conv2d_fwd_bias_
relu
_add_avx2_nhwc_kyxck8_nhwk
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReluAdd
>>&
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReluAdd
>>&
instances
);
instances
);
void
add_device_conv2d_fwd_bias_
activation
_add_avx2_nhwc_kyxck8_nhwk_local_c
(
void
add_device_conv2d_fwd_bias_
relu
_add_avx2_nhwc_kyxck8_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReluAdd
>>&
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReluAdd
>>&
instances
);
instances
);
void
add_device_conv2d_fwd_bias_
activation
_add_avx2_nhwc_kyxck8_nhwk_mt
(
void
add_device_conv2d_fwd_bias_
relu
_add_avx2_nhwc_kyxck8_nhwk_mt
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReluAdd
>>&
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReluAdd
>>&
instances
);
instances
);
void
add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxck8_nhwk
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddRelu
>>&
instances
);
void
add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxck8_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddRelu
>>&
instances
);
void
add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxck8_nhwk_mt
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddRelu
>>&
instances
);
void
add_device_conv2d_fwd_bias_avx2_nhwc_kyxck8_nhwk
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
Add
>>&
instances
);
void
add_device_conv2d_fwd_bias_avx2_nhwc_kyxck8_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
Add
>>&
instances
);
void
add_device_conv2d_fwd_bias_avx2_nhwc_kyxck8_nhwk_mt
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
Add
>>&
instances
);
// ------------------ nhwc-yxck-nhwk
// ------------------ nhwc-yxck-nhwk
void
add_device_conv2d_fwd_bias_
activation
_add_avx2_nhwc_yxck_nhwk
(
void
add_device_conv2d_fwd_bias_
relu
_add_avx2_nhwc_yxck_nhwk
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReluAdd
>>&
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReluAdd
>>&
instances
);
instances
);
void
add_device_conv2d_fwd_bias_
activation
_add_avx2_nhwc_yxck_nhwk_local_c
(
void
add_device_conv2d_fwd_bias_
relu
_add_avx2_nhwc_yxck_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReluAdd
>>&
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReluAdd
>>&
instances
);
instances
);
void
add_device_conv2d_fwd_bias_
activation
_add_avx2_nhwc_yxck_nhwk_mt
(
void
add_device_conv2d_fwd_bias_
relu
_add_avx2_nhwc_yxck_nhwk_mt
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReluAdd
>>&
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReluAdd
>>&
instances
);
instances
);
void
add_device_conv2d_fwd_bias_relu_avx2_nhwc_yxck_nhwk
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddRelu
>>&
instances
);
void
add_device_conv2d_fwd_bias_relu_avx2_nhwc_yxck_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddRelu
>>&
instances
);
void
add_device_conv2d_fwd_bias_relu_avx2_nhwc_yxck_nhwk_mt
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddRelu
>>&
instances
);
void
add_device_conv2d_fwd_bias_avx2_nhwc_yxck_nhwk
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
Add
>>&
instances
);
void
add_device_conv2d_fwd_bias_avx2_nhwc_yxck_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
Add
>>&
instances
);
void
add_device_conv2d_fwd_bias_avx2_nhwc_yxck_nhwk_mt
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
Add
>>&
instances
);
}
// namespace device_conv2d_fwd_bias_activation_add_avx2_instance
}
// namespace device_conv2d_fwd_bias_activation_add_avx2_instance
}
// namespace device
}
// namespace device
}
// namespace cpu
}
// namespace cpu
...
@@ -78,7 +140,13 @@ void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_mt(
...
@@ -78,7 +140,13 @@ void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_mt(
using
InElementOp
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
using
InElementOp
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
using
WeiElementOp
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
using
WeiElementOp
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
#if TEST_FUSION == TEST_FUSION_BIAS_RELU_ADD
using
OutElementOp
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
AddReluAdd
;
using
OutElementOp
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
AddReluAdd
;
#elif TEST_FUSION == TEST_FUSION_BIAS_RELU
using
OutElementOp
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
AddRelu
;
#elif TEST_FUSION == TEST_FUSION_BIAS
using
OutElementOp
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
Add
;
#endif
template
<
typename
T
>
template
<
typename
T
>
static
bool
static
bool
...
@@ -249,6 +317,7 @@ int main(int argc, char* argv[])
...
@@ -249,6 +317,7 @@ int main(int argc, char* argv[])
using
WeiDataType
=
decltype
(
wei_type
);
using
WeiDataType
=
decltype
(
wei_type
);
using
OutDataType
=
decltype
(
out_type
);
using
OutDataType
=
decltype
(
out_type
);
#if TEST_FUSION == TEST_FUSION_BIAS_RELU_ADD
using
ReferenceConvFwdInstance
=
using
ReferenceConvFwdInstance
=
ck
::
tensor_operation
::
host
::
ReferenceConvFwd_Bias_Activation_Add
<
InDataType
,
ck
::
tensor_operation
::
host
::
ReferenceConvFwd_Bias_Activation_Add
<
InDataType
,
WeiDataType
,
WeiDataType
,
...
@@ -256,6 +325,15 @@ int main(int argc, char* argv[])
...
@@ -256,6 +325,15 @@ int main(int argc, char* argv[])
InElementOp
,
InElementOp
,
WeiElementOp
,
WeiElementOp
,
OutElementOp
>
;
OutElementOp
>
;
#else
using
ReferenceConvFwdInstance
=
ck
::
tensor_operation
::
host
::
ReferenceConvFwd_Bias_Activation
<
InDataType
,
WeiDataType
,
OutDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
>
;
#endif
const
ck
::
index_t
YEff
=
(
Y
-
1
)
*
conv_dilation_h
+
1
;
const
ck
::
index_t
YEff
=
(
Y
-
1
)
*
conv_dilation_h
+
1
;
const
ck
::
index_t
XEff
=
(
X
-
1
)
*
conv_dilation_w
+
1
;
const
ck
::
index_t
XEff
=
(
X
-
1
)
*
conv_dilation_w
+
1
;
...
@@ -381,7 +459,9 @@ int main(int argc, char* argv[])
...
@@ -381,7 +459,9 @@ int main(int argc, char* argv[])
wei_k_c_y_x
,
wei_k_c_y_x
,
out_n_k_ho_wo_host_result
,
out_n_k_ho_wo_host_result
,
bias
,
bias
,
#if TEST_FUSION == TEST_FUSION_BIAS_RELU_ADD
residual
,
residual
,
#endif
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
...
@@ -394,9 +474,19 @@ int main(int argc, char* argv[])
...
@@ -394,9 +474,19 @@ int main(int argc, char* argv[])
using
PassThrough
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
using
AddReluAdd
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
AddReluAdd
;
using
AddReluAdd
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
AddReluAdd
;
using
AddRelu
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
AddRelu
;
using
Add
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
Add
;
#if TEST_FUSION == TEST_FUSION_BIAS_RELU_ADD
using
DeviceConvFwdNoOpPtr
=
ck
::
tensor_operation
::
cpu
::
device
::
using
DeviceConvFwdNoOpPtr
=
ck
::
tensor_operation
::
cpu
::
device
::
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReluAdd
>
;
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReluAdd
>
;
#elif TEST_FUSION == TEST_FUSION_BIAS_RELU
using
DeviceConvFwdNoOpPtr
=
ck
::
tensor_operation
::
cpu
::
device
::
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddRelu
>
;
#elif TEST_FUSION == TEST_FUSION_BIAS
using
DeviceConvFwdNoOpPtr
=
ck
::
tensor_operation
::
cpu
::
device
::
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
Add
>
;
#endif
// add device Conv instances
// add device Conv instances
std
::
vector
<
DeviceConvFwdNoOpPtr
>
conv_ptrs
;
std
::
vector
<
DeviceConvFwdNoOpPtr
>
conv_ptrs
;
...
@@ -405,27 +495,27 @@ int main(int argc, char* argv[])
...
@@ -405,27 +495,27 @@ int main(int argc, char* argv[])
ck
::
is_same_v
<
ck
::
remove_cv_t
<
WeiDataType
>
,
float
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
WeiDataType
>
,
float
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
OutDataType
>
,
float
>
)
ck
::
is_same_v
<
ck
::
remove_cv_t
<
OutDataType
>
,
float
>
)
{
{
#if TEST_FUSION == TEST_FUSION_BIAS_RELU_ADD
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXC_NHWK
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXC_NHWK
if
(
omp_get_max_threads
()
>
1
)
if
(
omp_get_max_threads
()
>
1
)
{
{
ck
::
tensor_operation
::
cpu
::
device
::
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_
activation
_add_avx2_nhwc_kyxc_nhwk_mt
(
conv_ptrs
);
add_device_conv2d_fwd_bias_
relu
_add_avx2_nhwc_kyxc_nhwk_mt
(
conv_ptrs
);
ck
::
tensor_operation
::
cpu
::
device
::
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_
activation
_add_avx2_nhwc_kyxc_nhwk
(
conv_ptrs
);
add_device_conv2d_fwd_bias_
relu
_add_avx2_nhwc_kyxc_nhwk
(
conv_ptrs
);
}
}
else
else
{
{
if
(
K
%
8
==
0
)
if
(
K
%
8
==
0
)
ck
::
tensor_operation
::
cpu
::
device
::
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk
(
add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_kyxc_nhwk
(
conv_ptrs
);
conv_ptrs
);
else
else
ck
::
tensor_operation
::
cpu
::
device
::
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_
activation
_add_avx2_nhwc_kyxc_nhwk_local_c
(
add_device_conv2d_fwd_bias_
relu
_add_avx2_nhwc_kyxc_nhwk_local_c
(
conv_ptrs
);
conv_ptrs
);
}
}
#endif
#endif
...
@@ -434,23 +524,21 @@ int main(int argc, char* argv[])
...
@@ -434,23 +524,21 @@ int main(int argc, char* argv[])
{
{
ck
::
tensor_operation
::
cpu
::
device
::
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_mt
(
add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_kyxck8_nhwk_mt
(
conv_ptrs
);
conv_ptrs
);
ck
::
tensor_operation
::
cpu
::
device
::
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_
activation
_add_avx2_nhwc_kyxck8_nhwk
(
conv_ptrs
);
add_device_conv2d_fwd_bias_
relu
_add_avx2_nhwc_kyxck8_nhwk
(
conv_ptrs
);
}
}
else
else
{
{
if
(
K
%
8
==
0
)
if
(
K
%
8
==
0
)
ck
::
tensor_operation
::
cpu
::
device
::
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk
(
add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_kyxck8_nhwk
(
conv_ptrs
);
conv_ptrs
);
else
else
ck
::
tensor_operation
::
cpu
::
device
::
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_
activation
_add_avx2_nhwc_kyxck8_nhwk_local_c
(
add_device_conv2d_fwd_bias_
relu
_add_avx2_nhwc_kyxck8_nhwk_local_c
(
conv_ptrs
);
conv_ptrs
);
}
}
#endif
#endif
...
@@ -459,24 +547,159 @@ int main(int argc, char* argv[])
...
@@ -459,24 +547,159 @@ int main(int argc, char* argv[])
{
{
ck
::
tensor_operation
::
cpu
::
device
::
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_
activation
_add_avx2_nhwc_yxck_nhwk_mt
(
conv_ptrs
);
add_device_conv2d_fwd_bias_
relu
_add_avx2_nhwc_yxck_nhwk_mt
(
conv_ptrs
);
ck
::
tensor_operation
::
cpu
::
device
::
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_
activation
_add_avx2_nhwc_yxck_nhwk
(
conv_ptrs
);
add_device_conv2d_fwd_bias_
relu
_add_avx2_nhwc_yxck_nhwk
(
conv_ptrs
);
}
}
else
else
{
{
if
(
K
%
8
==
0
)
if
(
K
%
8
==
0
)
ck
::
tensor_operation
::
cpu
::
device
::
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk
(
add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_yxck_nhwk
(
conv_ptrs
);
else
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_yxck_nhwk_local_c
(
conv_ptrs
);
conv_ptrs
);
}
#endif
#elif TEST_FUSION == TEST_FUSION_BIAS_RELU
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXC_NHWK
if
(
omp_get_max_threads
()
>
1
)
{
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxc_nhwk_mt
(
conv_ptrs
);
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxc_nhwk
(
conv_ptrs
);
}
else
{
if
(
K
%
8
==
0
)
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxc_nhwk
(
conv_ptrs
);
else
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxc_nhwk_local_c
(
conv_ptrs
);
}
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
if
(
omp_get_max_threads
()
>
1
)
{
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxck8_nhwk_mt
(
conv_ptrs
);
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxck8_nhwk
(
conv_ptrs
);
}
else
else
{
if
(
K
%
8
==
0
)
ck
::
tensor_operation
::
cpu
::
device
::
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c
(
add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxck8_nhwk
(
conv_ptrs
);
else
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxck8_nhwk_local_c
(
conv_ptrs
);
conv_ptrs
);
}
}
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK
if
(
omp_get_max_threads
()
>
1
)
{
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_relu_avx2_nhwc_yxck_nhwk_mt
(
conv_ptrs
);
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_relu_avx2_nhwc_yxck_nhwk
(
conv_ptrs
);
}
else
{
if
(
K
%
8
==
0
)
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_relu_avx2_nhwc_yxck_nhwk
(
conv_ptrs
);
else
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_relu_avx2_nhwc_yxck_nhwk_local_c
(
conv_ptrs
);
}
#endif
#elif TEST_FUSION == TEST_FUSION_BIAS
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXC_NHWK
if
(
omp_get_max_threads
()
>
1
)
{
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_avx2_nhwc_kyxc_nhwk_mt
(
conv_ptrs
);
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_avx2_nhwc_kyxc_nhwk
(
conv_ptrs
);
}
else
{
if
(
K
%
8
==
0
)
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_avx2_nhwc_kyxc_nhwk
(
conv_ptrs
);
else
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_avx2_nhwc_kyxc_nhwk_local_c
(
conv_ptrs
);
}
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
if
(
omp_get_max_threads
()
>
1
)
{
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_avx2_nhwc_kyxck8_nhwk_mt
(
conv_ptrs
);
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_avx2_nhwc_kyxck8_nhwk
(
conv_ptrs
);
}
else
{
if
(
K
%
8
==
0
)
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_avx2_nhwc_kyxck8_nhwk
(
conv_ptrs
);
else
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_avx2_nhwc_kyxck8_nhwk_local_c
(
conv_ptrs
);
}
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK
if
(
omp_get_max_threads
()
>
1
)
{
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_avx2_nhwc_yxck_nhwk_mt
(
conv_ptrs
);
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_avx2_nhwc_yxck_nhwk
(
conv_ptrs
);
}
else
{
if
(
K
%
8
==
0
)
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_avx2_nhwc_yxck_nhwk
(
conv_ptrs
);
else
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_avx2_nhwc_yxck_nhwk_local_c
(
conv_ptrs
);
}
#endif
#endif
#endif
}
}
...
...
include/ck/tensor_operation/cpu/device/device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk.hpp
View file @
f8b551da
...
@@ -37,6 +37,8 @@ template <typename InDataType,
...
@@ -37,6 +37,8 @@ template <typename InDataType,
bool
UseALocalBuffer
,
bool
UseALocalBuffer
,
bool
UseBLocalBuffer
,
bool
UseBLocalBuffer
,
bool
UseCLocalBuffer
,
bool
UseCLocalBuffer
,
bool
FuseBias
,
bool
FuseAdd
,
bool
BiasAlongGemmM
>
bool
BiasAlongGemmM
>
struct
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
struct
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
:
public
DeviceConvFwdBiasActivationAdd
<
InElementwiseOperation
,
:
public
DeviceConvFwdBiasActivationAdd
<
InElementwiseOperation
,
...
@@ -607,8 +609,13 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
...
@@ -607,8 +609,13 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
!
UseBLocalBuffer
,
!
UseBLocalBuffer
,
ConvForwardSpecialization
>
;
ConvForwardSpecialization
>
;
using
CThreadwiseCopy
=
static
constexpr
auto
GetCThreadwiseCopy
()
ck
::
cpu
::
ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN
<
{
constexpr
ck
::
index_t
C_nDim
=
CGridDesc
::
GetNumOfDimension
();
if
constexpr
(
FuseBias
&&
FuseAdd
)
{
return
ck
::
cpu
::
ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN
<
CDataType
,
CDataType
,
C0DataType
,
C0DataType
,
C1DataType
,
C1DataType
,
...
@@ -619,7 +626,34 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
...
@@ -619,7 +626,34 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
decltype
(
GetOutputBlockDescriptor
()),
decltype
(
GetOutputBlockDescriptor
()),
OutElementwiseOperation
,
OutElementwiseOperation
,
!
UseCLocalBuffer
,
!
UseCLocalBuffer
,
BiasAlongGemmM
>
;
BiasAlongGemmM
>
(
CGridDesc
{},
ck
::
make_zero_multi_index
<
C_nDim
>
(),
GetOutputBlockDescriptor
(),
ck
::
make_zero_multi_index
<
C_nDim
>
(),
OutElementwiseOperation
{});
}
else
if
constexpr
(
FuseBias
&&
!
FuseAdd
)
{
return
ck
::
cpu
::
ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_MxN
<
CDataType
,
C0DataType
,
C1DataType
,
CDataType
,
CGridDesc
,
C0GridDesc
,
C1GridDesc
,
decltype
(
GetOutputBlockDescriptor
()),
OutElementwiseOperation
,
!
UseCLocalBuffer
,
BiasAlongGemmM
>
(
CGridDesc
{},
ck
::
make_zero_multi_index
<
C_nDim
>
(),
GetOutputBlockDescriptor
(),
ck
::
make_zero_multi_index
<
C_nDim
>
(),
OutElementwiseOperation
{});
}
}
using
CThreadwiseCopy
=
decltype
(
GetCThreadwiseCopy
());
using
GridwiseGemm
=
ck
::
cpu
::
GridwiseGemmBiasActivationAddAvx2_MxN
<
using
GridwiseGemm
=
ck
::
cpu
::
GridwiseGemmBiasActivationAddAvx2_MxN
<
ADataType
,
// InDataType,
ADataType
,
// InDataType,
...
...
include/ck/tensor_operation/cpu/device/device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk.hpp
View file @
f8b551da
...
@@ -37,6 +37,8 @@ template <typename InDataType,
...
@@ -37,6 +37,8 @@ template <typename InDataType,
bool
UseALocalBuffer
,
bool
UseALocalBuffer
,
bool
UseBLocalBuffer
,
bool
UseBLocalBuffer
,
bool
UseCLocalBuffer
,
bool
UseCLocalBuffer
,
bool
FuseBias
,
bool
FuseAdd
,
bool
BiasAlongGemmM
>
bool
BiasAlongGemmM
>
struct
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
struct
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
:
public
DeviceConvFwdBiasActivationAdd
<
InElementwiseOperation
,
:
public
DeviceConvFwdBiasActivationAdd
<
InElementwiseOperation
,
...
@@ -584,8 +586,13 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
...
@@ -584,8 +586,13 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
!
UseBLocalBuffer
,
!
UseBLocalBuffer
,
ConvForwardSpecialization
>
;
ConvForwardSpecialization
>
;
using
CThreadwiseCopy
=
static
constexpr
auto
GetCThreadwiseCopy
()
ck
::
cpu
::
ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN
<
{
constexpr
ck
::
index_t
C_nDim
=
CGridDesc
::
GetNumOfDimension
();
if
constexpr
(
FuseBias
&&
FuseAdd
)
{
return
ck
::
cpu
::
ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN
<
CDataType
,
CDataType
,
C0DataType
,
C0DataType
,
C1DataType
,
C1DataType
,
...
@@ -596,7 +603,34 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
...
@@ -596,7 +603,34 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
decltype
(
GetOutputBlockDescriptor
()),
decltype
(
GetOutputBlockDescriptor
()),
OutElementwiseOperation
,
OutElementwiseOperation
,
!
UseCLocalBuffer
,
!
UseCLocalBuffer
,
BiasAlongGemmM
>
;
BiasAlongGemmM
>
(
CGridDesc
{},
ck
::
make_zero_multi_index
<
C_nDim
>
(),
GetOutputBlockDescriptor
(),
ck
::
make_zero_multi_index
<
C_nDim
>
(),
OutElementwiseOperation
{});
}
else
if
constexpr
(
FuseBias
&&
!
FuseAdd
)
{
return
ck
::
cpu
::
ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_MxN
<
CDataType
,
C0DataType
,
C1DataType
,
CDataType
,
CGridDesc
,
C0GridDesc
,
C1GridDesc
,
decltype
(
GetOutputBlockDescriptor
()),
OutElementwiseOperation
,
!
UseCLocalBuffer
,
BiasAlongGemmM
>
(
CGridDesc
{},
ck
::
make_zero_multi_index
<
C_nDim
>
(),
GetOutputBlockDescriptor
(),
ck
::
make_zero_multi_index
<
C_nDim
>
(),
OutElementwiseOperation
{});
}
}
using
CThreadwiseCopy
=
decltype
(
GetCThreadwiseCopy
());
using
GridwiseGemm
=
ck
::
cpu
::
GridwiseGemmBiasActivationAddAvx2_MxN
<
using
GridwiseGemm
=
ck
::
cpu
::
GridwiseGemmBiasActivationAddAvx2_MxN
<
ADataType
,
// InDataType,
ADataType
,
// InDataType,
...
...
include/ck/tensor_operation/cpu/device/device_convnd_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk.hpp
View file @
f8b551da
...
@@ -36,6 +36,8 @@ template <typename InDataType,
...
@@ -36,6 +36,8 @@ template <typename InDataType,
bool
UseALocalBuffer
,
bool
UseALocalBuffer
,
bool
UseBLocalBuffer
,
bool
UseBLocalBuffer
,
bool
UseCLocalBuffer
,
bool
UseCLocalBuffer
,
bool
FuseBias
,
bool
FuseAdd
,
bool
BiasAlongGemmM
>
bool
BiasAlongGemmM
>
struct
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
struct
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
:
public
DeviceConvFwdBiasActivationAdd
<
InElementwiseOperation
,
:
public
DeviceConvFwdBiasActivationAdd
<
InElementwiseOperation
,
...
@@ -580,8 +582,13 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
...
@@ -580,8 +582,13 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
!
UseBLocalBuffer
,
!
UseBLocalBuffer
,
ConvForwardSpecialization
>
;
ConvForwardSpecialization
>
;
using
CThreadwiseCopy
=
static
constexpr
auto
GetCThreadwiseCopy
()
ck
::
cpu
::
ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN
<
{
constexpr
ck
::
index_t
C_nDim
=
CGridDesc
::
GetNumOfDimension
();
if
constexpr
(
FuseBias
&&
FuseAdd
)
{
return
ck
::
cpu
::
ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN
<
CDataType
,
CDataType
,
C0DataType
,
C0DataType
,
C1DataType
,
C1DataType
,
...
@@ -592,7 +599,34 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
...
@@ -592,7 +599,34 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
decltype
(
GetOutputBlockDescriptor
()),
decltype
(
GetOutputBlockDescriptor
()),
OutElementwiseOperation
,
OutElementwiseOperation
,
!
UseCLocalBuffer
,
!
UseCLocalBuffer
,
BiasAlongGemmM
>
;
BiasAlongGemmM
>
(
CGridDesc
{},
ck
::
make_zero_multi_index
<
C_nDim
>
(),
GetOutputBlockDescriptor
(),
ck
::
make_zero_multi_index
<
C_nDim
>
(),
OutElementwiseOperation
{});
}
else
if
constexpr
(
FuseBias
&&
!
FuseAdd
)
{
return
ck
::
cpu
::
ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_MxN
<
CDataType
,
C0DataType
,
C1DataType
,
CDataType
,
CGridDesc
,
C0GridDesc
,
C1GridDesc
,
decltype
(
GetOutputBlockDescriptor
()),
OutElementwiseOperation
,
!
UseCLocalBuffer
,
BiasAlongGemmM
>
(
CGridDesc
{},
ck
::
make_zero_multi_index
<
C_nDim
>
(),
GetOutputBlockDescriptor
(),
ck
::
make_zero_multi_index
<
C_nDim
>
(),
OutElementwiseOperation
{});
}
}
using
CThreadwiseCopy
=
decltype
(
GetCThreadwiseCopy
());
using
GridwiseGemm
=
ck
::
cpu
::
GridwiseGemmBiasActivationAddAvx2_MxN
<
using
GridwiseGemm
=
ck
::
cpu
::
GridwiseGemmBiasActivationAddAvx2_MxN
<
ADataType
,
// InDataType,
ADataType
,
// InDataType,
...
...
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp
View file @
f8b551da
...
@@ -273,6 +273,53 @@ void memcpy32_avx2_with_extra_1src(void* dst,
...
@@ -273,6 +273,53 @@ void memcpy32_avx2_with_extra_1src(void* dst,
}
}
}
}
template
<
typename
ElementwiseOp
>
void
memcpy32_avx2_with_extra_1src
(
void
*
dst
,
const
void
*
src
,
const
float
v_src_aux
,
const
ck
::
index_t
n
,
const
ElementwiseOp
&
element_op
)
{
// 16-8-4-2-1 pattern
ck
::
index_t
i_n
=
n
;
float
*
p_dst
=
reinterpret_cast
<
float
*>
(
dst
);
const
float
*
p_src
=
reinterpret_cast
<
const
float
*>
(
src
);
__m256
ymm_src_aux
=
_mm256_set1_ps
(
*
reinterpret_cast
<
const
float
*>
(
&
v_src_aux
));
__m128
xmm_src_aux
=
_mm_set1_ps
(
*
reinterpret_cast
<
const
float
*>
(
&
v_src_aux
));
while
(
i_n
>=
16
)
{
_mm256_storeu_ps
(
p_dst
+
0
,
element_op
.
Apply
(
_mm256_loadu_ps
(
p_src
+
0
),
ymm_src_aux
));
_mm256_storeu_ps
(
p_dst
+
8
,
element_op
.
Apply
(
_mm256_loadu_ps
(
p_src
+
8
),
ymm_src_aux
));
p_dst
+=
16
;
p_src
+=
16
;
i_n
-=
16
;
}
if
(
i_n
&
8
)
{
_mm256_storeu_ps
(
p_dst
,
element_op
.
Apply
(
_mm256_loadu_ps
(
p_src
),
ymm_src_aux
));
p_dst
+=
8
;
p_src
+=
8
;
}
if
(
i_n
&
4
)
{
_mm_storeu_ps
(
p_dst
,
element_op
.
Apply
(
_mm_loadu_ps
(
p_src
),
xmm_src_aux
));
p_dst
+=
4
;
p_src
+=
4
;
}
if
(
i_n
&
2
)
{
_mm_storeu_si64
(
p_dst
,
element_op
.
Apply
(
_mm_loadu_si64
(
p_src
),
xmm_src_aux
));
p_dst
+=
2
;
p_src
+=
2
;
}
if
(
i_n
&
1
)
{
*
p_dst
=
element_op
.
Apply
(
*
p_src
,
v_src_aux
);
}
}
inline
void
memset32_avx2
(
void
*
dst
,
const
int32_t
value
,
const
ck
::
index_t
n
)
inline
void
memset32_avx2
(
void
*
dst
,
const
int32_t
value
,
const
ck
::
index_t
n
)
{
{
// 16-8-4-2-1 pattern
// 16-8-4-2-1 pattern
...
@@ -2369,6 +2416,582 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_
...
@@ -2369,6 +2416,582 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_
intptr_t
dst_offset
;
intptr_t
dst_offset
;
};
};
template
<
typename
SrcData
,
typename
Src1Data
,
// for Bias, per dimension
typename
Src2Data
,
// for Residual, per pixel
typename
DstData
,
typename
SrcDesc
,
typename
Src1Desc
,
typename
Src2Desc
,
typename
DstDesc
,
typename
ElementwiseOperation
,
bool
BypassTransfer
,
bool
Src1AlongDim0
>
// if true, src1 has dim along M, false, src1 has dim along N
struct
ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_MxN
{
static
constexpr
ck
::
index_t
nDim
=
SrcDesc
::
GetNumOfDimension
();
using
Index
=
MultiIndex
<
nDim
>
;
constexpr
ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_MxN
(
const
SrcDesc
&
src_desc
,
const
Index
&
,
const
DstDesc
&
dst_desc
,
const
Index
&
,
const
ElementwiseOperation
&
element_op
)
:
element_op_
(
element_op
)
{
DstGemmM
=
dst_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
0
>
{}];
DstGemmN
=
dst_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}];
src_offset
=
0
;
src1_offset
=
0
;
dst_offset
=
0
;
}
void
SetSrcSliceOrigin
(
const
SrcDesc
&
,
const
Index
&
src_slice_origin_idx
)
{
if
constexpr
(
BypassTransfer
)
{
auto
i_src_gemm_m
=
src_slice_origin_idx
[
Number
<
0
>
{}];
auto
i_src_gemm_n
=
src_slice_origin_idx
[
Number
<
1
>
{}];
src_offset
=
i_src_gemm_m
*
DstGemmN
+
i_src_gemm_n
;
}
}
void
SetSrc1SliceOrigin
(
const
SrcDesc
&
,
const
Index
&
src_slice_origin_idx
)
{
if
constexpr
(
Src1AlongDim0
)
{
auto
i_src_gemm_m
=
src_slice_origin_idx
[
Number
<
0
>
{}];
// auto i_src_gemm_n = src_slice_origin_idx[Number<1>{}];
src1_offset
=
i_src_gemm_m
;
}
else
{
auto
i_src_gemm_n
=
src_slice_origin_idx
[
Number
<
1
>
{}];
src1_offset
=
i_src_gemm_n
;
}
}
void
SetSrc2SliceOrigin
(
const
SrcDesc
&
,
const
Index
&
)
{}
void
SetDstSliceOrigin
(
const
DstDesc
&
,
const
Index
&
dst_slice_origin_idx
)
{
i_dst_gemm_m
=
dst_slice_origin_idx
[
Number
<
0
>
{}];
i_dst_gemm_n
=
dst_slice_origin_idx
[
Number
<
1
>
{}];
dst_offset
=
i_dst_gemm_m
*
DstGemmN
+
i_dst_gemm_n
;
}
template
<
typename
SrcBuffer
,
typename
Src1Buffer
,
typename
Src2Buffer
,
typename
DstBuffer
,
typename
SliceLengths
>
void
RunRead
(
const
SrcDesc
&
,
SrcBuffer
&
src_buf
,
const
Src1Desc
&
,
Src1Buffer
&
,
const
Src2Desc
&
,
Src2Buffer
&
,
const
DstDesc
&
,
DstBuffer
&
dst_buf
,
const
SliceLengths
&
)
{
if
constexpr
(
BypassTransfer
)
{
dst_buf
.
p_data_
=
reinterpret_cast
<
float
*>
(
src_buf
.
p_data_
)
+
src_offset
;
}
}
template
<
typename
SrcBuffer
,
typename
Src1Buffer
,
typename
Src2Buffer
,
typename
DstBuffer
,
typename
SliceLengths
>
void
RunWrite
(
const
SrcDesc
&
src_desc
,
SrcBuffer
&
src_buf
,
const
Src1Desc
&
src1_desc
,
Src1Buffer
&
src1_buf
,
const
Src2Desc
&
,
Src2Buffer
&
,
const
DstDesc
&
dst_desc
,
DstBuffer
&
dst_buf
,
const
SliceLengths
&
slice_length
)
{
if
constexpr
(
BypassTransfer
)
{
if
constexpr
(
!
std
::
is_same
<
ElementwiseOperation
,
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
>::
value
)
{
const
ck
::
index_t
m_per_block
=
slice_length
[
Number
<
0
>
{}];
const
ck
::
index_t
n_per_block
=
slice_length
[
Number
<
1
>
{}];
const
ck
::
index_t
current_n
=
ck
::
math
::
min
(
DstGemmN
-
i_dst_gemm_n
,
n_per_block
);
float
*
p_dst
=
reinterpret_cast
<
float
*>
(
dst_buf
.
p_data_
)
+
dst_offset
;
const
float
*
p_src1
=
reinterpret_cast
<
const
float
*>
(
src1_buf
.
p_data_
)
+
src1_offset
;
ck
::
index_t
i_m_itr
=
m_per_block
;
// standard 8-4-2-1 pattern
if
constexpr
(
Src1AlongDim0
)
{
while
(
i_m_itr
>=
8
)
{
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
0
*
DstGemmN
,
p_dst
+
0
*
DstGemmN
,
*
(
p_src1
+
0
),
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
1
*
DstGemmN
,
p_dst
+
1
*
DstGemmN
,
*
(
p_src1
+
1
),
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
2
*
DstGemmN
,
p_dst
+
2
*
DstGemmN
,
*
(
p_src1
+
2
),
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
3
*
DstGemmN
,
p_dst
+
3
*
DstGemmN
,
*
(
p_src1
+
3
),
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
4
*
DstGemmN
,
p_dst
+
4
*
DstGemmN
,
*
(
p_src1
+
4
),
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
5
*
DstGemmN
,
p_dst
+
5
*
DstGemmN
,
*
(
p_src1
+
5
),
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
6
*
DstGemmN
,
p_dst
+
6
*
DstGemmN
,
*
(
p_src1
+
6
),
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
7
*
DstGemmN
,
p_dst
+
7
*
DstGemmN
,
*
(
p_src1
+
7
),
current_n
,
element_op_
);
i_m_itr
-=
8
;
p_dst
+=
8
*
DstGemmN
;
p_src1
+=
8
;
}
if
(
i_m_itr
&
4
)
{
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
0
*
DstGemmN
,
p_dst
+
0
*
DstGemmN
,
*
(
p_src1
+
0
),
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
1
*
DstGemmN
,
p_dst
+
1
*
DstGemmN
,
*
(
p_src1
+
1
),
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
2
*
DstGemmN
,
p_dst
+
2
*
DstGemmN
,
*
(
p_src1
+
2
),
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
3
*
DstGemmN
,
p_dst
+
3
*
DstGemmN
,
*
(
p_src1
+
3
),
current_n
,
element_op_
);
p_dst
+=
4
*
DstGemmN
;
p_src1
+=
4
;
}
if
(
i_m_itr
&
2
)
{
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
0
*
DstGemmN
,
p_dst
+
0
*
DstGemmN
,
*
(
p_src1
+
0
),
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
1
*
DstGemmN
,
p_dst
+
1
*
DstGemmN
,
*
(
p_src1
+
1
),
current_n
,
element_op_
);
p_dst
+=
2
*
DstGemmN
;
p_src1
+=
2
;
}
if
(
i_m_itr
&
1
)
{
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
0
*
DstGemmN
,
p_dst
+
0
*
DstGemmN
,
*
(
p_src1
+
0
),
current_n
,
element_op_
);
}
}
else
{
while
(
i_m_itr
>=
8
)
{
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
0
*
DstGemmN
,
p_dst
+
0
*
DstGemmN
,
p_src1
,
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
1
*
DstGemmN
,
p_dst
+
1
*
DstGemmN
,
p_src1
,
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
2
*
DstGemmN
,
p_dst
+
2
*
DstGemmN
,
p_src1
,
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
3
*
DstGemmN
,
p_dst
+
3
*
DstGemmN
,
p_src1
,
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
4
*
DstGemmN
,
p_dst
+
4
*
DstGemmN
,
p_src1
,
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
5
*
DstGemmN
,
p_dst
+
5
*
DstGemmN
,
p_src1
,
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
6
*
DstGemmN
,
p_dst
+
6
*
DstGemmN
,
p_src1
,
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
7
*
DstGemmN
,
p_dst
+
7
*
DstGemmN
,
p_src1
,
current_n
,
element_op_
);
i_m_itr
-=
8
;
p_dst
+=
8
*
DstGemmN
;
}
if
(
i_m_itr
&
4
)
{
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
0
*
DstGemmN
,
p_dst
+
0
*
DstGemmN
,
p_src1
,
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
1
*
DstGemmN
,
p_dst
+
1
*
DstGemmN
,
p_src1
,
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
2
*
DstGemmN
,
p_dst
+
2
*
DstGemmN
,
p_src1
,
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
3
*
DstGemmN
,
p_dst
+
3
*
DstGemmN
,
p_src1
,
current_n
,
element_op_
);
p_dst
+=
4
*
DstGemmN
;
}
if
(
i_m_itr
&
2
)
{
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
0
*
DstGemmN
,
p_dst
+
0
*
DstGemmN
,
p_src1
,
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
1
*
DstGemmN
,
p_dst
+
1
*
DstGemmN
,
p_src1
,
current_n
,
element_op_
);
p_dst
+=
2
*
DstGemmN
;
}
if
(
i_m_itr
&
1
)
{
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
0
*
DstGemmN
,
p_dst
+
0
*
DstGemmN
,
p_src1
,
current_n
,
element_op_
);
}
}
}
}
else
{
const
ck
::
index_t
m_per_block
=
slice_length
[
Number
<
0
>
{}];
const
ck
::
index_t
n_per_block
=
slice_length
[
Number
<
1
>
{}];
const
ck
::
index_t
current_n
=
ck
::
math
::
min
(
DstGemmN
-
i_dst_gemm_n
,
n_per_block
);
const
float
*
p_src
=
reinterpret_cast
<
const
float
*>
(
src_buf
.
p_data_
)
+
src_offset
;
float
*
p_dst
=
reinterpret_cast
<
float
*>
(
dst_buf
.
p_data_
)
+
dst_offset
;
const
float
*
p_src1
=
reinterpret_cast
<
const
float
*>
(
src1_buf
.
p_data_
)
+
src1_offset
;
ck
::
index_t
i_m_itr
=
m_per_block
;
// printf("xxxx %d, current_n:%d, DstGemmN:%d, n_per_block:%d\n",__LINE__, current_n,
// DstGemmN, n_per_block);fflush(stdout);
// standard 8-4-2-1 pattern
if
constexpr
(
Src1AlongDim0
)
{
while
(
i_m_itr
>=
8
)
{
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
0
*
DstGemmN
,
p_src
+
0
*
n_per_block
,
*
(
p_src1
+
0
),
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
1
*
DstGemmN
,
p_src
+
1
*
n_per_block
,
*
(
p_src1
+
1
),
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
2
*
DstGemmN
,
p_src
+
2
*
n_per_block
,
*
(
p_src1
+
2
),
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
3
*
DstGemmN
,
p_src
+
3
*
n_per_block
,
*
(
p_src1
+
3
),
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
4
*
DstGemmN
,
p_src
+
4
*
n_per_block
,
*
(
p_src1
+
4
),
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
5
*
DstGemmN
,
p_src
+
5
*
n_per_block
,
*
(
p_src1
+
5
),
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
6
*
DstGemmN
,
p_src
+
6
*
n_per_block
,
*
(
p_src1
+
6
),
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
7
*
DstGemmN
,
p_src
+
7
*
n_per_block
,
*
(
p_src1
+
7
),
current_n
,
element_op_
);
i_m_itr
-=
8
;
p_dst
+=
8
*
DstGemmN
;
p_src
+=
8
*
n_per_block
;
p_src1
+=
8
;
}
if
(
i_m_itr
&
4
)
{
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
0
*
DstGemmN
,
p_src
+
0
*
n_per_block
,
*
(
p_src1
+
0
),
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
1
*
DstGemmN
,
p_src
+
1
*
n_per_block
,
*
(
p_src1
+
1
),
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
2
*
DstGemmN
,
p_src
+
2
*
n_per_block
,
*
(
p_src1
+
2
),
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
3
*
DstGemmN
,
p_src
+
3
*
n_per_block
,
*
(
p_src1
+
3
),
current_n
,
element_op_
);
p_dst
+=
4
*
DstGemmN
;
p_src
+=
4
*
n_per_block
;
p_src1
+=
4
;
}
if
(
i_m_itr
&
2
)
{
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
0
*
DstGemmN
,
p_src
+
0
*
n_per_block
,
*
(
p_src1
+
0
),
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
1
*
DstGemmN
,
p_src
+
1
*
n_per_block
,
*
(
p_src1
+
1
),
current_n
,
element_op_
);
p_dst
+=
2
*
DstGemmN
;
p_src
+=
2
*
n_per_block
;
p_src1
+=
2
;
}
if
(
i_m_itr
&
1
)
{
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
0
*
DstGemmN
,
p_src
+
0
*
n_per_block
,
*
(
p_src1
+
0
),
current_n
,
element_op_
);
}
}
else
{
while
(
i_m_itr
>=
8
)
{
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
0
*
DstGemmN
,
p_src
+
0
*
n_per_block
,
p_src1
,
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
1
*
DstGemmN
,
p_src
+
1
*
n_per_block
,
p_src1
,
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
2
*
DstGemmN
,
p_src
+
2
*
n_per_block
,
p_src1
,
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
3
*
DstGemmN
,
p_src
+
3
*
n_per_block
,
p_src1
,
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
4
*
DstGemmN
,
p_src
+
4
*
n_per_block
,
p_src1
,
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
5
*
DstGemmN
,
p_src
+
5
*
n_per_block
,
p_src1
,
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
6
*
DstGemmN
,
p_src
+
6
*
n_per_block
,
p_src1
,
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
7
*
DstGemmN
,
p_src
+
7
*
n_per_block
,
p_src1
,
current_n
,
element_op_
);
i_m_itr
-=
8
;
p_dst
+=
8
*
DstGemmN
;
p_src
+=
8
*
n_per_block
;
}
if
(
i_m_itr
&
4
)
{
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
0
*
DstGemmN
,
p_src
+
0
*
n_per_block
,
p_src1
,
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
1
*
DstGemmN
,
p_src
+
1
*
n_per_block
,
p_src1
,
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
2
*
DstGemmN
,
p_src
+
2
*
n_per_block
,
p_src1
,
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
3
*
DstGemmN
,
p_src
+
3
*
n_per_block
,
p_src1
,
current_n
,
element_op_
);
p_dst
+=
4
*
DstGemmN
;
p_src
+=
4
*
n_per_block
;
}
if
(
i_m_itr
&
2
)
{
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
0
*
DstGemmN
,
p_src
+
0
*
n_per_block
,
p_src1
,
current_n
,
element_op_
);
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
1
*
DstGemmN
,
p_src
+
1
*
n_per_block
,
p_src1
,
current_n
,
element_op_
);
p_dst
+=
2
*
DstGemmN
;
p_src
+=
2
*
n_per_block
;
}
if
(
i_m_itr
&
1
)
{
avx2_util
::
memcpy32_avx2_with_extra_1src
(
p_dst
+
0
*
DstGemmN
,
p_src
+
0
*
n_per_block
,
p_src1
,
current_n
,
element_op_
);
}
}
}
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
void
MoveSrcSliceWindow
(
const
SrcDesc
&
,
const
Index
&
)
{}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
void
MoveDstSliceWindow
(
const
DstDesc
&
,
const
Index
&
)
{}
private:
const
ElementwiseOperation
element_op_
;
ck
::
index_t
i_dst_gemm_m
;
ck
::
index_t
i_dst_gemm_n
;
ck
::
index_t
DstGemmM
;
ck
::
index_t
DstGemmN
;
intptr_t
src_offset
;
intptr_t
src1_offset
;
intptr_t
dst_offset
;
};
}
// namespace cpu
}
// namespace cpu
}
// namespace ck
}
// namespace ck
...
...
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_bias_activation_add_avx2_nhwc_kyxc_nhwk_instance.cpp
View file @
f8b551da
...
@@ -22,6 +22,8 @@ static constexpr bool NonTemporalStore = false;
...
@@ -22,6 +22,8 @@ static constexpr bool NonTemporalStore = false;
using
PT
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
using
PT
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
using
AddReluAdd
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
AddReluAdd
;
using
AddReluAdd
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
AddReluAdd
;
using
AddRelu
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
AddRelu
;
using
Add
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
Add
;
static
constexpr
auto
ConvFwdDefault
=
static
constexpr
auto
ConvFwdDefault
=
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardSpecialization_t
::
Default
;
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardSpecialization_t
::
Default
;
...
@@ -41,95 +43,257 @@ static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver
...
@@ -41,95 +43,257 @@ static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver
static
constexpr
auto
LoopOver_MKN
=
ck
::
tensor_operation
::
cpu
::
device
::
LoopOver_MKN
;
static
constexpr
auto
LoopOver_MKN
=
ck
::
tensor_operation
::
cpu
::
device
::
LoopOver_MKN
;
// clang-format off
// clang-format off
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf,
fuse_bias, fuse_add,
bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf,
fuse_bias, fuse_add,
bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf,
fuse_bias, fuse_add,
bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf,
fuse_bias, fuse_add,
bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf,
fuse_bias, fuse_add,
bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf,
fuse_bias, fuse_add,
bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
\
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf,
fuse_bias, fuse_add,
bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf,
fuse_bias, fuse_add,
bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf,
fuse_bias, fuse_add,
bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf,
fuse_bias, fuse_add,
bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN})
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf,
fuse_bias, fuse_add,
bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN})
// clang-format on
// clang-format on
void
add_device_conv2d_fwd_bias_
activation
_add_avx2_nhwc_kyxc_nhwk
(
void
add_device_conv2d_fwd_bias_
relu
_add_avx2_nhwc_kyxc_nhwk
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
{
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
instances
,
std
::
make_tuple
(
std
::
make_tuple
(
// clang-format off
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
false
,
false
)
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
false
,
true
,
true
,
false
)
// clang-format on
// clang-format on
));
));
}
}
void
add_device_conv2d_fwd_bias_
activation
_add_avx2_nhwc_kyxc_nhwk_local_c
(
void
add_device_conv2d_fwd_bias_
relu
_add_avx2_nhwc_kyxc_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
{
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
instances
,
std
::
make_tuple
(
std
::
make_tuple
(
// clang-format off
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
true
,
false
)
// clang-format on
// clang-format on
));
));
}
}
void
add_device_conv2d_fwd_bias_
activation
_add_avx2_nhwc_kyxc_nhwk_mt
(
void
add_device_conv2d_fwd_bias_
relu
_add_avx2_nhwc_kyxc_nhwk_mt
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
{
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
instances
,
std
::
make_tuple
(
std
::
make_tuple
(
// clang-format off
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
24
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
24
,
24
,
256
,
4
,
24
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
32
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
32
,
24
,
256
,
4
,
24
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
40
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
40
,
24
,
256
,
4
,
24
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
48
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
48
,
24
,
256
,
4
,
24
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
48
,
48
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
48
,
48
,
256
,
4
,
24
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
56
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
56
,
24
,
256
,
4
,
24
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
16
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
16
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
16
,
256
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
16
,
256
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
32
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
32
,
256
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
32
,
256
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
32
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
64
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
64
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
32
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
64
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
64
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
true
,
false
)
// clang-format on
));
}
/****************************************************************************************************/
void
add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxc_nhwk
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddRelu
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
256
,
128
,
64
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
256
,
128
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
128
,
256
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
512
,
240
,
128
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
512
,
256
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
768
,
320
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
896
,
352
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
1024
,
416
,
128
,
6
,
16
,
false
,
true
,
false
,
false
)
// clang-format on
));
}
void
add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxc_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddRelu
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
256
,
128
,
64
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
false
,
false
)
// clang-format on
));
}
void
add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxc_nhwk_mt
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddRelu
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
24
,
24
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
32
,
24
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
40
,
24
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
48
,
24
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
48
,
48
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
56
,
24
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
72
,
16
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
72
,
16
,
256
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
72
,
32
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
72
,
32
,
256
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
96
,
32
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
96
,
64
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
120
,
32
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
120
,
64
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddRelu
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
false
,
false
)
// clang-format on
));
}
/****************************************************************************************************/
void
add_device_conv2d_fwd_bias_avx2_nhwc_kyxc_nhwk
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
Add
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
256
,
128
,
64
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
256
,
128
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
128
,
256
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
512
,
240
,
128
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
512
,
256
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
768
,
320
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
896
,
352
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
1024
,
416
,
128
,
6
,
16
,
false
,
true
,
false
,
false
)
// clang-format on
));
}
void
add_device_conv2d_fwd_bias_avx2_nhwc_kyxc_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
Add
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
256
,
128
,
64
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
false
,
false
)
// clang-format on
));
}
void
add_device_conv2d_fwd_bias_avx2_nhwc_kyxc_nhwk_mt
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
Add
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
24
,
24
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
32
,
24
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
40
,
24
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
48
,
24
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
48
,
48
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
56
,
24
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
72
,
16
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
72
,
16
,
256
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
72
,
32
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
72
,
32
,
256
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
96
,
32
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
96
,
64
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
120
,
32
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
120
,
64
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Add
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
false
,
false
)
// clang-format on
// clang-format on
));
));
}
}
...
...
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_bias_activation_add_avx2_nhwc_kyxck8_nhwk_instance.cpp
View file @
f8b551da
...
@@ -22,6 +22,8 @@ static constexpr bool NonTemporalStore = false;
...
@@ -22,6 +22,8 @@ static constexpr bool NonTemporalStore = false;
using
PT
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
using
PT
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
using
AddReluAdd
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
AddReluAdd
;
using
AddReluAdd
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
AddReluAdd
;
using
AddRelu
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
AddRelu
;
using
Add
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
Add
;
static
constexpr
auto
ConvFwdDefault
=
static
constexpr
auto
ConvFwdDefault
=
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardSpecialization_t
::
Default
;
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardSpecialization_t
::
Default
;
...
@@ -41,95 +43,257 @@ static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver
...
@@ -41,95 +43,257 @@ static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver
static
constexpr
auto
LoopOver_MKN
=
ck
::
tensor_operation
::
cpu
::
device
::
LoopOver_MKN
;
static
constexpr
auto
LoopOver_MKN
=
ck
::
tensor_operation
::
cpu
::
device
::
LoopOver_MKN
;
// clang-format off
// clang-format off
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf,
fuse_bias, fuse_add,
bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf,
fuse_bias, fuse_add,
bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf,
fuse_bias, fuse_add,
bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf,
fuse_bias, fuse_add,
bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf,
fuse_bias, fuse_add,
bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf,
fuse_bias, fuse_add,
bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
\
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf,
fuse_bias, fuse_add,
bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf,
fuse_bias, fuse_add,
bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf,
fuse_bias, fuse_add,
bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf,
fuse_bias, fuse_add,
bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN})
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf,
fuse_bias, fuse_add,
bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN})
// clang-format on
// clang-format on
void
add_device_conv2d_fwd_bias_
activation
_add_avx2_nhwc_kyxck8_nhwk
(
void
add_device_conv2d_fwd_bias_
relu
_add_avx2_nhwc_kyxck8_nhwk
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
{
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
instances
,
std
::
make_tuple
(
std
::
make_tuple
(
// clang-format off
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
false
,
false
)
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
false
,
true
,
true
,
false
)
// clang-format on
// clang-format on
));
));
}
}
void
add_device_conv2d_fwd_bias_
activation
_add_avx2_nhwc_kyxck8_nhwk_local_c
(
void
add_device_conv2d_fwd_bias_
relu
_add_avx2_nhwc_kyxck8_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
{
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
instances
,
std
::
make_tuple
(
std
::
make_tuple
(
// clang-format off
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
true
,
false
)
// clang-format on
// clang-format on
));
));
}
}
void
add_device_conv2d_fwd_bias_
activation
_add_avx2_nhwc_kyxck8_nhwk_mt
(
void
add_device_conv2d_fwd_bias_
relu
_add_avx2_nhwc_kyxck8_nhwk_mt
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
{
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
instances
,
std
::
make_tuple
(
std
::
make_tuple
(
// clang-format off
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
24
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
24
,
24
,
256
,
4
,
24
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
32
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
32
,
24
,
256
,
4
,
24
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
40
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
40
,
24
,
256
,
4
,
24
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
48
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
48
,
24
,
256
,
4
,
24
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
48
,
48
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
48
,
48
,
256
,
4
,
24
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
56
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
56
,
24
,
256
,
4
,
24
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
16
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
16
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
16
,
256
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
16
,
256
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
32
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
32
,
256
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
32
,
256
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
32
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
64
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
64
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
32
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
64
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
64
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
true
,
false
)
// clang-format on
));
}
/****************************************************************************************************/
void
add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxck8_nhwk
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddRelu
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
256
,
128
,
64
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
256
,
128
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
128
,
256
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
512
,
240
,
128
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
512
,
256
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
768
,
320
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
896
,
352
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
1024
,
416
,
128
,
6
,
16
,
false
,
true
,
false
,
false
)
// clang-format on
));
}
void
add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxck8_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddRelu
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
256
,
128
,
64
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
false
,
false
)
// clang-format on
));
}
void
add_device_conv2d_fwd_bias_relu_avx2_nhwc_kyxck8_nhwk_mt
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddRelu
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
24
,
24
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
32
,
24
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
40
,
24
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
48
,
24
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
48
,
48
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
56
,
24
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
72
,
16
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
72
,
16
,
256
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
72
,
32
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
72
,
32
,
256
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
96
,
32
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
96
,
64
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
120
,
32
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
120
,
64
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddRelu
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
false
,
false
)
// clang-format on
));
}
/****************************************************************************************************/
void
add_device_conv2d_fwd_bias_avx2_nhwc_kyxck8_nhwk
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
Add
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
256
,
128
,
64
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
256
,
128
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
128
,
256
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
512
,
240
,
128
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
512
,
256
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
768
,
320
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
896
,
352
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
1024
,
416
,
128
,
6
,
16
,
false
,
true
,
false
,
false
)
// clang-format on
));
}
void
add_device_conv2d_fwd_bias_avx2_nhwc_kyxck8_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
Add
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
256
,
128
,
64
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
false
,
false
)
// clang-format on
));
}
void
add_device_conv2d_fwd_bias_avx2_nhwc_kyxck8_nhwk_mt
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
Add
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
24
,
24
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
32
,
24
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
40
,
24
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
48
,
24
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
48
,
48
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
56
,
24
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
72
,
16
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
72
,
16
,
256
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
72
,
32
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
72
,
32
,
256
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
96
,
32
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
96
,
64
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
120
,
32
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
120
,
64
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Add
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
false
,
false
)
// clang-format on
// clang-format on
));
));
}
}
...
...
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_bias_activation_add_avx2_nhwc_yxck_nhwk_instance.cpp
View file @
f8b551da
...
@@ -21,6 +21,8 @@ static constexpr bool NonTemporalStore = false;
...
@@ -21,6 +21,8 @@ static constexpr bool NonTemporalStore = false;
using
PT
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
using
PT
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
using
AddReluAdd
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
AddReluAdd
;
using
AddReluAdd
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
AddReluAdd
;
using
AddRelu
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
AddRelu
;
using
Add
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
Add
;
static
constexpr
auto
ConvFwdDefault
=
static
constexpr
auto
ConvFwdDefault
=
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardSpecialization_t
::
Default
;
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardSpecialization_t
::
Default
;
...
@@ -40,95 +42,257 @@ static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver
...
@@ -40,95 +42,257 @@ static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver
static
constexpr
auto
LoopOver_MKN
=
ck
::
tensor_operation
::
cpu
::
device
::
LoopOver_MKN
;
static
constexpr
auto
LoopOver_MKN
=
ck
::
tensor_operation
::
cpu
::
device
::
LoopOver_MKN
;
// clang-format off
// clang-format off
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf,
fuse_bias, fuse_add,
bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf,
fuse_bias, fuse_add,
bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf,
fuse_bias, fuse_add,
bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf,
fuse_bias, fuse_add,
bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf,
fuse_bias, fuse_add,
bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf,
fuse_bias, fuse_add,
bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MNK}), \
\
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf,
fuse_bias, fuse_add,
bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, true, true, c_local_buf,
fuse_bias, fuse_add,
bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, true, c_local_buf,
fuse_bias, fuse_add,
bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, 2, m_per_thread, n_per_thread, false, false, c_local_buf,
fuse_bias, fuse_add,
bias_along_m>({m_per_block, n_per_block, k_per_block, GemmKLoopOverC , LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN})
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, 2, m_per_thread, n_per_thread, true, false, c_local_buf,
fuse_bias, fuse_add,
bias_along_m>({m_per_block, n_per_block, k_per_block, DefaultGemmKLoop, LoopOver_MKN})
// clang-format on
// clang-format on
void
add_device_conv2d_fwd_bias_
activation
_add_avx2_nhwc_yxck_nhwk
(
void
add_device_conv2d_fwd_bias_
relu
_add_avx2_nhwc_yxck_nhwk
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
{
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
instances
,
std
::
make_tuple
(
std
::
make_tuple
(
// clang-format off
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
false
,
false
)
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
false
,
true
,
true
,
false
)
// clang-format on
// clang-format on
));
));
}
}
void
add_device_conv2d_fwd_bias_
activation
_add_avx2_nhwc_yxck_nhwk_local_c
(
void
add_device_conv2d_fwd_bias_
relu
_add_avx2_nhwc_yxck_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
{
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
instances
,
std
::
make_tuple
(
std
::
make_tuple
(
// clang-format off
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
true
,
false
)
// clang-format on
// clang-format on
));
));
}
}
void
add_device_conv2d_fwd_bias_
activation
_add_avx2_nhwc_yxck_nhwk_mt
(
void
add_device_conv2d_fwd_bias_
relu
_add_avx2_nhwc_yxck_nhwk_mt
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
{
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
instances
,
std
::
make_tuple
(
std
::
make_tuple
(
// clang-format off
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
24
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
24
,
24
,
256
,
4
,
24
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
32
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
32
,
24
,
256
,
4
,
24
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
40
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
40
,
24
,
256
,
4
,
24
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
48
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
48
,
24
,
256
,
4
,
24
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
48
,
48
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
48
,
48
,
256
,
4
,
24
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
56
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
56
,
24
,
256
,
4
,
24
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
16
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
16
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
16
,
256
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
16
,
256
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
32
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
32
,
256
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
32
,
256
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
32
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
64
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
64
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
32
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
64
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
64
,
128
,
6
,
16
,
false
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
true
,
false
)
// clang-format on
));
}
/****************************************************************************************************/
void
add_device_conv2d_fwd_bias_relu_avx2_nhwc_yxck_nhwk
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddRelu
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
256
,
128
,
64
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
256
,
128
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
128
,
256
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
512
,
240
,
128
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
512
,
256
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
768
,
320
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
896
,
352
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
1024
,
416
,
128
,
6
,
16
,
false
,
true
,
false
,
false
)
// clang-format on
));
}
void
add_device_conv2d_fwd_bias_relu_avx2_nhwc_yxck_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddRelu
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
256
,
128
,
64
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
false
,
false
)
// clang-format on
));
}
void
add_device_conv2d_fwd_bias_relu_avx2_nhwc_yxck_nhwk_mt
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddRelu
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
24
,
24
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
32
,
24
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
40
,
24
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
48
,
24
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
48
,
48
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
56
,
24
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
72
,
16
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
72
,
16
,
256
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
72
,
32
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
72
,
32
,
256
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
96
,
32
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
96
,
64
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
120
,
32
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
120
,
64
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddRelu
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
false
,
false
)
// clang-format on
));
}
/****************************************************************************************************/
void
add_device_conv2d_fwd_bias_avx2_nhwc_yxck_nhwk
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
Add
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
256
,
128
,
64
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
256
,
128
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
128
,
256
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
512
,
240
,
128
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
512
,
256
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
768
,
320
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
896
,
352
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
1024
,
416
,
128
,
6
,
16
,
false
,
true
,
false
,
false
)
// clang-format on
));
}
void
add_device_conv2d_fwd_bias_avx2_nhwc_yxck_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
Add
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
256
,
128
,
64
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
false
,
false
)
// clang-format on
));
}
void
add_device_conv2d_fwd_bias_avx2_nhwc_yxck_nhwk_mt
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
Add
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
24
,
24
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
32
,
24
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
40
,
24
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
48
,
24
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
48
,
48
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
56
,
24
,
256
,
4
,
24
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
72
,
16
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
72
,
16
,
256
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
72
,
32
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
72
,
32
,
256
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
96
,
32
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
96
,
64
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
120
,
32
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
120
,
64
,
128
,
6
,
16
,
false
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Add
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
false
,
false
)
// clang-format on
// clang-format on
));
));
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment