gaoqiong / composable_kernel_ROCM / Commits / f9cf57d4

Commit f9cf57d4, authored Jun 07, 2022 by carlushuang
parent 71254ddd

    support YXCK filter

Showing 12 changed files with 3299 additions and 677 deletions (+3299 −677)
example/cpu_01_conv2d_fwd/cpu_conv2d_fwd.cpp                                                           +658 −565
example/cpu_02_conv2d_fwd_bias_relu_add/cpu_conv2d_fwd_bias_relu_add.cpp                                +74 −1
include/ck/tensor_operation/cpu/block/blockwise_gemm_avx2.hpp                                           +10 −4
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_yxck_nhwk.hpp                       +927 −0
include/ck/tensor_operation/cpu/device/device_convnd_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk.hpp   +992 −0
include/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp                                         +86 −42
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp        +132 −0
library/src/tensor_operation_instance/cpu/conv2d_fwd/CMakeLists.txt                                      +1 −0
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_yxck_nhwk_instance.cpp +229 −0
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/CMakeLists.txt                  +1 −0
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_bias_activation_add_avx2_nhwc_yxck_nhwk_instance.cpp +144 −0
test/cpu_ukernel/cpu_gemm_uk.cpp                                                                        +45 −65
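Editor's note, for orientation: the commit adds forward-convolution kernels whose weight tensor is laid out YXCK (filter height, filter width, input channel, output channel), so the output-channel index K is innermost and contiguous, which matches what the row-major AVX2 GEMM microkernel wants from its B matrix. A minimal sketch of the linearized offset, assuming a fully packed tensor (helper name is hypothetical, not from this commit):

    #include <cstddef>

    // Hypothetical helper: linear offset of element (y, x, c, k) in a packed
    // YXCK weight tensor with extents Y, X, C, K. K is innermost, so
    // consecutive k values are adjacent in memory.
    inline std::size_t yxck_offset(std::size_t y, std::size_t x, std::size_t c, std::size_t k,
                                   std::size_t /*Y*/, std::size_t X, std::size_t C, std::size_t K)
    {
        return ((y * X + x) * C + c) * K + k;
    }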
example/cpu_01_conv2d_fwd/cpu_conv2d_fwd.cpp

@@ -20,7 +20,8 @@
 #define TEST_LAYOUT_NHWC_KYXC_NHWK 0
 #define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1
-#define TEST_LAYOUT TEST_LAYOUT_NHWC_KYXCK8_NHWK
+#define TEST_LAYOUT_NHWC_YXCK_NHWK 2
+#define TEST_LAYOUT TEST_LAYOUT_NHWC_YXCK_NHWK
 
 using F32 = float;
 using F16 = ck::half_t;
@@ -34,6 +35,7 @@ namespace device_conv2d_fwd_avx2_instance {
 using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
 using Relu        = ck::tensor_operation::cpu::element_wise::Relu;
 
+// ------------------ nhwc-kyxc-nhwk
 void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(
     std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
@@ -52,6 +54,7 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(
 void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(
     std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
 
+// ------------------ nhwc-kcyxk8-nhwk
 void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk(
     std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
@@ -70,6 +73,25 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c_relu(
 void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt_relu(
     std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
 
+// ------------------ nhwc-yxck-nhwk
+void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk(
+    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
+
+void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c(
+    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
+
+void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt(
+    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
+
+void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_relu(
+    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
+
+void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c_relu(
+    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
+
+void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt_relu(
+    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
 
 } // namespace device_conv2d_fwd_avx2_instance
 } // namespace device
 } // namespace cpu
@@ -162,6 +184,31 @@ void transpose_kyxc_2_kyxc8k(Tensor<T>& dst,
     }
 }
 
+template <typename T>
+void transpose_kyxc_2_yxck(Tensor<T>& dst,
+                           const Tensor<T>& src,
+                           ck::index_t K,
+                           ck::index_t Y,
+                           ck::index_t X,
+                           ck::index_t C)
+{
+    ck::index_t batch = 1;
+    ck::index_t row   = K;
+    ck::index_t col   = C * Y * X;
+
+    for(auto i_b = 0; i_b < batch; i_b++)
+    {
+        for(auto i_r = 0; i_r < row; i_r++)
+        {
+            for(auto i_c = 0; i_c < col; i_c++)
+            {
+                ck::index_t src_idx = i_b * row * col + i_r * col + i_c;
+                ck::index_t dst_idx = i_b * col * row + i_c * row + i_r;
+
+                dst.mData[dst_idx] = src.mData[src_idx];
+            }
+        }
+    }
+}
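Editor's note: the helper above is just a K × (Y·X·C) matrix transpose — the KYXC source is read row by row and scattered so that K becomes the fastest-varying index, producing the YXCK layout. A small usage sketch under assumed dimensions (values illustrative, not from the commit; Tensor<> and f_host_tensor_descriptor() are the helpers this example file already uses):

    // Illustrative only: transpose a K=16, Y=3, X=3, C=8 filter from KYXC to YXCK.
    Tensor<float> wei_k_c_y_x(f_host_tensor_descriptor(16, 8, 3, 3)); // holds KYXC data
    Tensor<float> wei_y_x_c_k(f_host_tensor_descriptor(16, 8, 3, 3)); // reused as plain storage
    transpose_kyxc_2_yxck(wei_y_x_c_k, wei_k_c_y_x, /*K=*/16, /*Y=*/3, /*X=*/3, /*C=*/8);
    // element (k, y, x, c) of the source now sits at linear index ((y*3 + x)*8 + c)*16 + k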
 int main(int argc, char* argv[])
 {
     int data_type = 0;

@@ -263,6 +310,10 @@ int main(int argc, char* argv[])
 #if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
     Tensor<WeiDataType> wei_k_c_y_x_k8(
         f_host_tensor_descriptor(K, C, Y, X)); // TODO: This is only to hold data
 #endif
+#if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK
+    Tensor<WeiDataType> wei_y_x_c_k(
+        f_host_tensor_descriptor(K, C, Y, X)); // TODO: This is only to hold data
+#endif
 
     Tensor<OutDataType> out_n_k_ho_wo_host_result(f_host_tensor_descriptor(N, K, Ho, Wo));
     Tensor<OutDataType> out_n_k_ho_wo_device_result(f_host_tensor_descriptor(N, K, Ho, Wo));
@@ -353,6 +404,10 @@ int main(int argc, char* argv[])
 #if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
     transpose_kyxc_2_kyxc8k(wei_k_c_y_x_k8, wei_k_c_y_x, K, Y, X, C);
     wei_device_buf.ToDevice(wei_k_c_y_x_k8.mData.data());
 #endif
+#if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK
+    transpose_kyxc_2_yxck(wei_y_x_c_k, wei_k_c_y_x, K, Y, X, C);
+    wei_device_buf.ToDevice(wei_y_x_c_k.mData.data());
+#endif
 
     // get host result
     {
@@ -465,6 +520,44 @@ int main(int argc, char* argv[])
             add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c_relu(conv_ptrs);
         }
 #endif
+#if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK
+#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
+        if(omp_get_max_threads() > 1)
+        {
+            ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
+                add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt(conv_ptrs);
+            ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
+                add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk(conv_ptrs);
+        }
+        else
+        {
+            if(K % 8 == 0)
+                ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
+                    add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk(conv_ptrs);
+            else
+                ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
+                    add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c(conv_ptrs);
+        }
+#endif
+#if TEST_FUSION == TEST_FUSION_RELU
+        if(omp_get_max_threads() > 1)
+        {
+            ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
+                add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt_relu(conv_ptrs);
+            ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
+                add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_relu(conv_ptrs);
+        }
+        else
+        {
+            if(K % 8 == 0)
+                ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
+                    add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_relu(conv_ptrs);
+            else
+                ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
+                    add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c_relu(conv_ptrs);
+        }
+#endif
+#endif
 #endif
 }
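Editor's note on the hunk above: the example selects instances with a simple heuristic — with more than one OpenMP thread it registers the multithreaded (_mt) variants alongside the plain one; single-threaded, it picks the plain instance when K is a multiple of 8 (the AVX2 f32 vector width, so the B matrix can be consumed in place) and otherwise falls back to the _local_c variant. A standalone re-statement of the same selection logic (names abbreviated, list type simplified; real code registers device-op instances, not strings):

    #include <omp.h>
    #include <string>
    #include <vector>

    // Hypothetical sketch of the instance-selection heuristic used in the example.
    std::vector<std::string> pick_yxck_instances(int K)
    {
        std::vector<std::string> picked;
        if(omp_get_max_threads() > 1)
        {
            picked.push_back("nhwc_yxck_nhwk_mt"); // multi-threaded variant first
            picked.push_back("nhwc_yxck_nhwk");
        }
        else if(K % 8 == 0) // K divisible by the 8-lane AVX2 f32 vector width
        {
            picked.push_back("nhwc_yxck_nhwk");
        }
        else
        {
            picked.push_back("nhwc_yxck_nhwk_local_c"); // stage C in a local buffer
        }
        return picked;
    }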
example/cpu_02_conv2d_fwd_bias_relu_add/cpu_conv2d_fwd_bias_relu_add.cpp

@@ -16,7 +16,8 @@
 #define TEST_LAYOUT_NHWC_KYXC_NHWK 0
 #define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1
-#define TEST_LAYOUT TEST_LAYOUT_NHWC_KYXCK8_NHWK
+#define TEST_LAYOUT_NHWC_YXCK_NHWK 1
+#define TEST_LAYOUT TEST_LAYOUT_NHWC_KYXC_NHWK
 
 using F32 = float;
 using F16 = ck::half_t;
@@ -30,6 +31,7 @@ namespace device_conv2d_fwd_bias_activation_add_avx2_instance {
 using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
 using AddReluAdd  = ck::tensor_operation::cpu::element_wise::AddReluAdd;
 
+// ------------------ nhwc-kyxc-nhwk
 void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk(
     std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
         instances);
@@ -42,6 +44,7 @@ void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_mt(
     std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
         instances);
 
+// ------------------ nhwc-kcyxk8-nhwk
 void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk(
     std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
         instances);
@@ -54,6 +57,19 @@ void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_mt(
     std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
         instances);
 
+// ------------------ nhwc-yxck-nhwk
+void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk(
+    std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
+        instances);
+
+void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c(
+    std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
+        instances);
+
+void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_mt(
+    std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
+        instances);
 
 } // namespace device_conv2d_fwd_bias_activation_add_avx2_instance
 } // namespace device
 } // namespace cpu
@@ -141,6 +157,31 @@ void transpose_kyxc_2_kyxc8k(Tensor<T>& dst,
     }
 }
 
+template <typename T>
+void transpose_kyxc_2_yxck(Tensor<T>& dst,
+                           const Tensor<T>& src,
+                           ck::index_t K,
+                           ck::index_t Y,
+                           ck::index_t X,
+                           ck::index_t C)
+{
+    ck::index_t batch = 1;
+    ck::index_t row   = K;
+    ck::index_t col   = C * Y * X;
+
+    for(auto i_b = 0; i_b < batch; i_b++)
+    {
+        for(auto i_r = 0; i_r < row; i_r++)
+        {
+            for(auto i_c = 0; i_c < col; i_c++)
+            {
+                ck::index_t src_idx = i_b * row * col + i_r * col + i_c;
+                ck::index_t dst_idx = i_b * col * row + i_c * row + i_r;
+
+                dst.mData[dst_idx] = src.mData[src_idx];
+            }
+        }
+    }
+}
 int main(int argc, char* argv[])
 {
     int data_type = 0;

@@ -243,6 +284,10 @@ int main(int argc, char* argv[])
 #if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
     Tensor<WeiDataType> wei_k_c_y_x_k8(
         f_host_tensor_descriptor(K, C, Y, X)); // TODO: This is only to hold data
 #endif
+#if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK
+    Tensor<WeiDataType> wei_y_x_c_k(
+        f_host_tensor_descriptor(K, C, Y, X)); // TODO: This is only to hold data
+#endif
 
     Tensor<OutDataType> out_n_k_ho_wo_host_result(f_host_tensor_descriptor(N, K, Ho, Wo));
     Tensor<OutDataType> out_n_k_ho_wo_device_result(f_host_tensor_descriptor(N, K, Ho, Wo));
@@ -319,6 +364,10 @@ int main(int argc, char* argv[])
 #if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
     transpose_kyxc_2_kyxc8k(wei_k_c_y_x_k8, wei_k_c_y_x, K, Y, X, C);
     wei_device_buf.ToDevice(wei_k_c_y_x_k8.mData.data());
 #endif
+#if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK
+    transpose_kyxc_2_yxck(wei_y_x_c_k, wei_k_c_y_x, K, Y, X, C);
+    wei_device_buf.ToDevice(wei_y_x_c_k.mData.data());
+#endif
 
     bias_device_buf.ToDevice(bias.mData.data());
     resi_device_buf.ToDevice(residual.mData.data());
@@ -404,6 +453,30 @@ int main(int argc, char* argv[])
             add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_local_c(conv_ptrs);
         }
 #endif
+#if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK
+        if(omp_get_max_threads() > 1)
+        {
+            ck::tensor_operation::cpu::device::device_conv2d_fwd_bias_activation_add_avx2_instance::
+                add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_mt(conv_ptrs);
+            ck::tensor_operation::cpu::device::device_conv2d_fwd_bias_activation_add_avx2_instance::
+                add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk(conv_ptrs);
+        }
+        else
+        {
+            if(K % 8 == 0)
+                ck::tensor_operation::cpu::device::device_conv2d_fwd_bias_activation_add_avx2_instance::
+                    add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk(conv_ptrs);
+            else
+                ck::tensor_operation::cpu::device::device_conv2d_fwd_bias_activation_add_avx2_instance::
+                    add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c(conv_ptrs);
+        }
+#endif
 #endif
 }
include/ck/tensor_operation/cpu/block/blockwise_gemm_avx2.hpp

@@ -199,8 +199,6 @@ struct BlockwiseGemmAvx2_MxN
         auto ldb = GetBLeadingElement(b_block_desc) * sizeof(FloatB);
         auto ldc = GetCLeadingElement(c_desc) * sizeof(FloatC);
 
-        // printf("lda:%d, ldb:%d, ldc:%d\n", lda, ldb, ldc);
-
         const auto k_per_block = a_slice_length[Number<1>{}];
         const auto m_per_block = c_slice_length[Number<0>{}];
         const auto n_per_block = c_slice_length[Number<1>{}];

@@ -215,8 +213,16 @@ struct BlockwiseGemmAvx2_MxN
         param.alpha       = 1.0f; // TODO
         param.accmulate_c = is_accumulate_c ? 1 : 0;
 
-        // printf("xxx lda:%u, ldb:%u, ldc:%u, mpb:%u, npb:%u, kpb:%u\n", lda, ldb, ldc,
-        // m_per_block, n_per_block, k_per_block);
+        // printf("xxx lda:%u, ldb:%u, ldc:%u, mpb:%u, npb:%u, kpb:%u, mpt:%u, npt:%u\n",
+        // lda,
+        // ldb,
+        // ldc,
+        // m_per_block,
+        // n_per_block,
+        // k_per_block,
+        // m_per_thread,
+        // n_per_thread);
+        // fflush(stdout);
 
         if constexpr(std::is_same<ThreadMNAccessOrder, ck::Sequence<0, 1>>::value)
         {
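Editor's note: the leading dimensions in this hunk are converted to bytes (element count × sizeof(Float)), which is the form the hand-written AVX2 microkernel consumes. A quick arithmetic check, assuming f32 and a row-major B block with 24 leading elements (values illustrative, not from the commit):

    #include <cstdio>

    int main()
    {
        // Assumed for illustration: leading element count 24, f32 elements,
        // so the byte stride handed to the microkernel is 24 * 4 = 96.
        unsigned leading_elements = 24;
        unsigned ldb_bytes        = leading_elements * sizeof(float);
        std::printf("ldb = %u bytes\n", ldb_bytes); // prints: ldb = 96 bytes
        return 0;
    }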
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_yxck_nhwk.hpp  (new file, mode 100644)

#ifndef DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_HPP
#define DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_HPP
#include <iostream>
#include <sstream>
#include <numeric>
#include "device.hpp"
#include "device_base_cpu.hpp"
#include "device_conv_fwd_cpu.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "common_header.hpp"
#include "../../gpu/device/tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_avx2.hpp"
#include "threadwise_gemm_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"
namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
template <typename InDataType,
          typename WeiDataType,
          typename OutDataType,
          typename InElementwiseOperation,
          typename WeiElementwiseOperation,
          typename OutElementwiseOperation,
          ConvolutionForwardSpecialization_t ConvForwardSpecialization,
          ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
          ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
          ck::index_t NumDimSpatial,
          ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
          ck::index_t NPerBlock,
          ck::index_t KPerBlock,
          ck::index_t MPerThread,
          ck::index_t NPerThread,
          bool UseALocalBuffer,
          bool UseBLocalBuffer,
          bool UseCLocalBuffer>
struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
    : public DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
{
    using DeviceOp = DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K;

    using ADataType = InDataType;
    using BDataType = WeiDataType;
    using CDataType = OutDataType;

    using AElementwiseOperation = InElementwiseOperation;
    using BElementwiseOperation = WeiElementwiseOperation;
    using CElementwiseOperation = OutElementwiseOperation;

    // TODO make A/B datatype different
    using ABDataType = InDataType;

    static constexpr index_t NDimSpatial = NumDimSpatial;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    static constexpr bool NonTemporalStore = false;

    static constexpr auto GetBlockMNKAccessOrder()
    {
        if constexpr(BlockLoopOverSpecialization == DefaultBlockLoopOver ||
                     BlockLoopOverSpecialization == LoopOver_MNK)
            return ck::Sequence<0, 1, 2>{};
        else if constexpr(BlockLoopOverSpecialization == LoopOver_MKN)
            return ck::Sequence<0, 2, 1>{};
    }

    using BlockMNKAccessOrder = decltype(GetBlockMNKAccessOrder());
    static constexpr auto GetThreadwiseGemm_Dispatch()
    {
        if constexpr(MPerThread == 4 && NPerThread == 24)
        {
            return ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<InDataType,
                                                                 WeiDataType,
                                                                 OutDataType,
                                                                 ck::tensor_layout::gemm::RowMajor,
                                                                 ck::tensor_layout::gemm::RowMajor,
                                                                 NonTemporalStore>{};
        }
        else if constexpr(MPerThread == 6 && NPerThread == 16)
        {
            return ck::cpu::ThreadwiseGemmAvx2_MxN_6x16_Dispatch<InDataType,
                                                                 WeiDataType,
                                                                 OutDataType,
                                                                 ck::tensor_layout::gemm::RowMajor,
                                                                 ck::tensor_layout::gemm::RowMajor,
                                                                 NonTemporalStore>{};
        }
        else
        {
            // static_assert(false, "invalid Mr/Nr");
        }
    }

    using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());

    static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
    {
        return make_naive_tensor_descriptor_packed(make_tuple(gemm_k, gemm_n));
    }

    static auto GetOutputTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n)
    {
        const auto out_gemm_m_n_grid_desc =
            make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_n));
        return out_gemm_m_n_grid_desc;
    }
    template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
    static auto GetInputTensorDescriptor(ck::index_t N,
                                         ck::index_t C,
                                         ck::index_t gemm_m,
                                         ck::index_t gemm_k,
                                         const std::vector<ck::index_t>& input_spatial_lengths,
                                         const std::vector<ck::index_t>& filter_spatial_lengths,
                                         const std::vector<ck::index_t>& output_spatial_lengths,
                                         const std::vector<ck::index_t>& conv_filter_strides,
                                         const std::vector<ck::index_t>& conv_filter_dilations,
                                         const std::vector<ck::index_t>& input_left_pads,
                                         const std::vector<ck::index_t>& input_right_pads)
    {
        const index_t Wi          = input_spatial_lengths[0];
        const index_t Wo          = output_spatial_lengths[0];
        const index_t ConvStrideW = conv_filter_strides[0];

        if constexpr(ConvForwardSpecialization ==
                     ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            const auto in_gemm_m_k_grid_desc =
                make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
            return in_gemm_m_k_grid_desc;
        }
        else if constexpr(ConvForwardSpecialization ==
                          ConvolutionForwardSpecialization_t::Filter1x1Pad0)
        {
            const auto in_n_wi_c_grid_desc =
                make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));

            const auto in_n_wo_c_grid_desc = transform_tensor_descriptor(
                in_n_wi_c_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
                in_n_wo_c_grid_desc,
                make_tuple(make_merge_transform(make_tuple(N, Wo)),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return in_gemm_m_k_grid_desc;
        }
        else
        {
            const index_t X             = filter_spatial_lengths[0];
            const index_t ConvDilationW = conv_filter_dilations[0];
            const index_t InLeftPadW    = input_left_pads[0];
            const index_t InRightPadW   = input_right_pads[0];

            const auto in_n_wi_c_grid_desc =
                make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));

            const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
                in_n_wi_c_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_pad_transform(Wi, InLeftPadW, InRightPadW),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
                in_n_wip_c_grid_desc,
                make_tuple(
                    make_pass_through_transform(N),
                    make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
                    make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

            const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
                in_n_x_wo_c_grid_desc,
                make_tuple(make_merge_transform(make_tuple(N, Wo)),
                           make_merge_transform(make_tuple(X, C))),
                make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return in_gemm_m_k_grid_desc;
        }
    }
    template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
    static auto GetInputTensorDescriptor(ck::index_t N,
                                         ck::index_t C,
                                         ck::index_t gemm_m,
                                         ck::index_t gemm_k,
                                         const std::vector<ck::index_t>& input_spatial_lengths,
                                         const std::vector<ck::index_t>& filter_spatial_lengths,
                                         const std::vector<ck::index_t>& output_spatial_lengths,
                                         const std::vector<ck::index_t>& conv_filter_strides,
                                         const std::vector<ck::index_t>& conv_filter_dilations,
                                         const std::vector<ck::index_t>& input_left_pads,
                                         const std::vector<ck::index_t>& input_right_pads)
    {
        const index_t Hi = input_spatial_lengths[0];
        const index_t Wi = input_spatial_lengths[1];

        const index_t Ho = output_spatial_lengths[0];
        const index_t Wo = output_spatial_lengths[1];

        const index_t ConvStrideH = conv_filter_strides[0];
        const index_t ConvStrideW = conv_filter_strides[1];

        if constexpr(ConvForwardSpecialization ==
                     ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            const auto in_gemm_m_k_grid_desc =
                make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
            return in_gemm_m_k_grid_desc;
        }
        else if constexpr(ConvForwardSpecialization ==
                          ConvolutionForwardSpecialization_t::Filter1x1Pad0)
        {
            const auto in_n_hi_wi_c_grid_desc =
                make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));

            const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
                in_n_hi_wi_c_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
                           make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

            const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
                in_n_ho_wo_c_grid_desc,
                make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return in_gemm_m_k_grid_desc;
        }
        else
        {
            const index_t Y = filter_spatial_lengths[0];
            const index_t X = filter_spatial_lengths[1];

            const index_t ConvDilationH = conv_filter_dilations[0];
            const index_t ConvDilationW = conv_filter_dilations[1];

            const index_t InLeftPadH = input_left_pads[0];
            const index_t InLeftPadW = input_left_pads[1];

            const index_t InRightPadH = input_right_pads[0];
            const index_t InRightPadW = input_right_pads[1];

            const auto in_n_hi_wi_c_grid_desc =
                make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));

            const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
                in_n_hi_wi_c_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_pad_transform(Hi, InLeftPadH, InRightPadH),
                           make_pad_transform(Wi, InLeftPadW, InRightPadW),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

            const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
                in_n_hip_wip_c_grid_desc,
                make_tuple(
                    make_pass_through_transform(N),
                    make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                    make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
                    make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

            const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
                in_n_y_ho_x_wo_c_grid_desc,
                make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
                           make_merge_transform(make_tuple(Y, X, C))),
                make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return in_gemm_m_k_grid_desc;
        }
    }
    template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
    static auto GetInputTensorDescriptor(ck::index_t N,
                                         ck::index_t C,
                                         ck::index_t gemm_m,
                                         ck::index_t gemm_k,
                                         ck::index_t gemm_m_pad,
                                         const std::vector<ck::index_t>& input_spatial_lengths,
                                         const std::vector<ck::index_t>& filter_spatial_lengths,
                                         const std::vector<ck::index_t>& output_spatial_lengths,
                                         const std::vector<ck::index_t>& conv_filter_strides,
                                         const std::vector<ck::index_t>& conv_filter_dilations,
                                         const std::vector<ck::index_t>& input_left_pads,
                                         const std::vector<ck::index_t>& input_right_pads)
    {
        const index_t Di = input_spatial_lengths[0];
        const index_t Hi = input_spatial_lengths[1];
        const index_t Wi = input_spatial_lengths[2];

        const index_t Do = output_spatial_lengths[0];
        const index_t Ho = output_spatial_lengths[1];
        const index_t Wo = output_spatial_lengths[2];

        const index_t ConvStrideD = conv_filter_strides[0];
        const index_t ConvStrideH = conv_filter_strides[1];
        const index_t ConvStrideW = conv_filter_strides[2];

        if constexpr(ConvForwardSpecialization ==
                     ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            const auto in_gemm_m_k_grid_desc =
                make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
            return in_gemm_m_k_grid_desc;
        }
        else if constexpr(ConvForwardSpecialization ==
                          ConvolutionForwardSpecialization_t::Filter1x1Pad0)
        {
            const auto in_n_di_hi_wi_c_grid_desc =
                make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));

            const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor(
                in_n_di_hi_wi_c_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
                           make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
                           make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
                           make_pass_through_transform(C)),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

            const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
                in_n_do_ho_wo_c_grid_desc,
                make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return in_gemm_m_k_grid_desc;
        }
        else
        {
            const index_t Z = filter_spatial_lengths[0];
            const index_t Y = filter_spatial_lengths[1];
            const index_t X = filter_spatial_lengths[2];

            const index_t ConvDilationD = conv_filter_dilations[0];
            const index_t ConvDilationH = conv_filter_dilations[1];
            const index_t ConvDilationW = conv_filter_dilations[2];

            const index_t InLeftPadD = input_left_pads[0];
            const index_t InLeftPadH = input_left_pads[1];
            const index_t InLeftPadW = input_left_pads[2];

            const index_t InRightPadD = input_right_pads[0];
            const index_t InRightPadH = input_right_pads[1];
            const index_t InRightPadW = input_right_pads[2];

            const auto in_n_di_hi_wi_c_grid_desc =
                make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));

            const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
                in_n_di_hi_wi_c_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_pad_transform(Di, InLeftPadD, InRightPadD),
                           make_pad_transform(Hi, InLeftPadH, InRightPadH),
                           make_pad_transform(Wi, InLeftPadW, InRightPadW),
                           make_pass_through_transform(C)),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

            const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
                in_n_hip_wip_c_grid_desc,
                make_tuple(
                    make_pass_through_transform(N),
                    make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
                    make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                    make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
                    make_pass_through_transform(C)),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(Sequence<0>{},
                           Sequence<1, 2>{},
                           Sequence<3, 4>{},
                           Sequence<5, 6>{},
                           Sequence<7>{}));

            const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
                in_n_z_do_y_ho_x_wo_c_grid_desc,
                make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
                           make_merge_transform(make_tuple(Z, Y, X, C))),
                make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return in_gemm_m_k_grid_desc;
        }
    }
    static index_t GetGemmM(ck::index_t N, const std::vector<ck::index_t>& output_spatial_lengths)
    {
        return N * std::accumulate(std::begin(output_spatial_lengths),
                                   std::end(output_spatial_lengths),
                                   1,
                                   std::multiplies<ck::index_t>());
    }

    static index_t GetGemmK(ck::index_t C, const std::vector<ck::index_t>& filter_spatial_lengths)
    {
        return C * std::accumulate(std::begin(filter_spatial_lengths),
                                   std::end(filter_spatial_lengths),
                                   1,
                                   std::multiplies<ck::index_t>());
    }

    static index_t GetGemmN(ck::index_t K)
    {
        // return ck::math::integer_least_multiple(K,
        // ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
        return K;
    }
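    // Editor's note (not in the commit): the three helpers above define the
    // implicit-GEMM view of the convolution:
    //   GemmM = N * prod(output_spatial_lengths)
    //   GemmK = C * prod(filter_spatial_lengths)
    //   GemmN = K
    // Worked example under assumed sizes N=2, C=64, K=128, 3x3 filter, 28x28 output:
    //   GemmM = 2 * 28 * 28 = 1568, GemmK = 64 * 3 * 3 = 576, GemmN = 128.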
    static auto MakeABCGridDescriptor(ck::index_t N,
                                      ck::index_t K,
                                      ck::index_t C,
                                      std::vector<ck::index_t> input_spatial_lengths,
                                      std::vector<ck::index_t> filter_spatial_lengths,
                                      std::vector<ck::index_t> output_spatial_lengths,
                                      std::vector<ck::index_t> conv_filter_strides,
                                      std::vector<ck::index_t> conv_filter_dilations,
                                      std::vector<ck::index_t> input_left_pads,
                                      std::vector<ck::index_t> input_right_pads)
    {
        using namespace ck;

        const index_t GemmM = GetGemmM(N, output_spatial_lengths);
        const index_t GemmN = GetGemmN(K);
        const index_t GemmK = GetGemmK(C, filter_spatial_lengths);

        // A:
        const auto in_gemm_m_k_grid_desc =
            GetInputTensorDescriptor<NumDimSpatial>(N,
                                                    C,
                                                    GemmM,
                                                    GemmK,
                                                    input_spatial_lengths,
                                                    filter_spatial_lengths,
                                                    output_spatial_lengths,
                                                    conv_filter_strides,
                                                    conv_filter_dilations,
                                                    input_left_pads,
                                                    input_right_pads);
        // B:
        const auto wei_gemm_k_n_grid_desc = GetWeightTensorDescriptor(GemmK, GemmN);
        // C:
        const auto out_gemm_m_n_grid_desc = GetOutputTensorDescriptor(GemmM, GemmN);

        return make_tuple(in_gemm_m_k_grid_desc, wei_gemm_k_n_grid_desc, out_gemm_m_n_grid_desc);
    }
    template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
    static auto GetABCGridDesc()
    {
        return MakeABCGridDescriptor(1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1});
    }

    template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
    static auto GetABCGridDesc()
    {
        return MakeABCGridDescriptor(
            1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
    }

    template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
    static auto GetABCGridDesc()
    {
        return MakeABCGridDescriptor(
            1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
    }

    using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>());

    using AGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
    using BGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
    using CGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
    static constexpr auto GetInputBlockDescriptor()
    {
        if constexpr(UseALocalBuffer)
        {
            return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
        }
        else
        {
            return AGridDesc{};
        }
    }

    static constexpr auto GetWeightBlockDescriptor()
    {
        if constexpr(UseBLocalBuffer)
        {
            return make_naive_tensor_descriptor_packed(make_tuple(KPerBlock, NPerBlock));
        }
        else
        {
            return BGridDesc{};
        }
    }

    static constexpr auto GetOutputBlockDescriptor()
    {
        if constexpr(UseCLocalBuffer)
        {
            return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
        }
        else
        {
            return CGridDesc{};
        }
    }

    // static constexpr bool UseCLocalBuffer = false;

    using AThreadwiseCopy =
        ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC<
            InDataType,
            InDataType,
            AGridDesc,
            decltype(GetInputBlockDescriptor()),
            InElementwiseOperation,
            !UseALocalBuffer,
            ConvForwardSpecialization,
            GemmKSpecialization>;

    using BThreadwiseCopy =
        ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_YXCK<
            WeiDataType,
            WeiDataType,
            BGridDesc,
            decltype(GetWeightBlockDescriptor()),
            WeiElementwiseOperation,
            !UseBLocalBuffer,
            ConvForwardSpecialization,
            GemmKSpecialization>;

    using CThreadwiseCopy =
        ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN<
            OutDataType,
            OutDataType,
            CGridDesc,
            decltype(GetOutputBlockDescriptor()),
            OutElementwiseOperation,
            !UseCLocalBuffer,
            ConvForwardSpecialization,
            GemmKSpecialization>;

    using GridwiseGemm = ck::cpu::GridwiseGemmAvx2_MxN<InDataType,
                                                       WeiDataType,
                                                       OutDataType,
                                                       AGridDesc,
                                                       BGridDesc,
                                                       CGridDesc,
                                                       AElementwiseOperation,
                                                       BElementwiseOperation,
                                                       CElementwiseOperation,
                                                       MPerBlock,
                                                       NPerBlock,
                                                       KPerBlock,
                                                       ThreadwiseGemm_Dispatch,
                                                       AThreadwiseCopy,
                                                       BThreadwiseCopy,
                                                       CThreadwiseCopy,
                                                       BlockMNKAccessOrder,
                                                       ck::Sequence<0, 1>, // ThreadMNAccessOrder
                                                       UseALocalBuffer,
                                                       UseBLocalBuffer,
                                                       UseCLocalBuffer>;
    // Argument
    struct Argument : public BaseArgument
    {
        Argument(const InDataType* p_in_grid,
                 const WeiDataType* p_wei_grid,
                 OutDataType* p_out_grid,
                 ck::index_t N,
                 ck::index_t K,
                 ck::index_t C,
                 std::vector<ck::index_t> input_spatial_lengths,
                 std::vector<ck::index_t> filter_spatial_lengths,
                 std::vector<ck::index_t> output_spatial_lengths,
                 std::vector<ck::index_t> conv_filter_strides,
                 std::vector<ck::index_t> conv_filter_dilations,
                 std::vector<ck::index_t> input_left_pads,
                 std::vector<ck::index_t> input_right_pads,
                 InElementwiseOperation in_element_op,
                 WeiElementwiseOperation wei_element_op,
                 OutElementwiseOperation out_element_op)
            : p_a_grid_{p_in_grid},
              p_b_grid_{p_wei_grid},
              p_c_grid_{p_out_grid},
              a_grid_desc_{},
              b_grid_desc_{},
              c_grid_desc_{},
              a_element_op_{in_element_op},
              b_element_op_{wei_element_op},
              c_element_op_{out_element_op},
              Conv_N_{N},
              Conv_K_{K},
              Conv_C_{C},
              filter_spatial_lengths_{filter_spatial_lengths},
              conv_filter_strides_{conv_filter_strides},
              input_left_pads_{input_left_pads},
              input_right_pads_{input_right_pads}
        {
            const auto descs = DeviceOp::MakeABCGridDescriptor(N,
                                                               K,
                                                               C,
                                                               input_spatial_lengths,
                                                               filter_spatial_lengths,
                                                               output_spatial_lengths,
                                                               conv_filter_strides,
                                                               conv_filter_dilations,
                                                               input_left_pads,
                                                               input_right_pads);
            a_grid_desc_ = descs[I0];
            b_grid_desc_ = descs[I1];
            c_grid_desc_ = descs[I2];
        }

        //  private:
        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        CDataType* p_c_grid_;
        AGridDesc a_grid_desc_;
        BGridDesc b_grid_desc_;
        CGridDesc c_grid_desc_;
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CElementwiseOperation c_element_op_;

        // for checking IsSupportedArgument()
        index_t Conv_N_;
        index_t Conv_K_;
        index_t Conv_C_;
        std::vector<index_t> filter_spatial_lengths_;
        std::vector<index_t> conv_filter_strides_;
        std::vector<index_t> input_left_pads_;
        std::vector<index_t> input_right_pads_;
    };
    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceOp::Argument;

        float Run(const Argument& arg,
                  const StreamConfig& stream_config = StreamConfig{},
                  int nrepeat                       = 1)
        {
            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
            {
                throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
            }

            memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());

            const auto kernel = ck::cpu::kernel_gemm_avx_mxn<GridwiseGemm,
                                                             InDataType,
                                                             WeiDataType,
                                                             OutDataType,
                                                             AGridDesc,
                                                             BGridDesc,
                                                             CGridDesc,
                                                             AElementwiseOperation,
                                                             BElementwiseOperation,
                                                             CElementwiseOperation>;

            float ave_time = 0;

            if(nrepeat != 1)
                ave_time = launch_and_time_cpu_kernel(kernel,
                                                      nrepeat,
                                                      arg.p_a_grid_,
                                                      arg.p_b_grid_,
                                                      arg.p_c_grid_,
                                                      arg.a_grid_desc_,
                                                      arg.b_grid_desc_,
                                                      arg.c_grid_desc_,
                                                      arg.a_element_op_,
                                                      arg.b_element_op_,
                                                      arg.c_element_op_);

            // TODO: this is for benchmark purpose, so last time we clear c buffer and calculate the
            // result
            memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
            launch_cpu_kernel(kernel,
                              arg.p_a_grid_,
                              arg.p_b_grid_,
                              arg.p_c_grid_,
                              arg.a_grid_desc_,
                              arg.b_grid_desc_,
                              arg.c_grid_desc_,
                              arg.a_element_op_,
                              arg.b_element_op_,
                              arg.c_element_op_);

            return ave_time;
        }

        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{},
                  int nrepeat                       = 1) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config, nrepeat);
        }
    };
    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if constexpr(ConvForwardSpecialization ==
                     ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            // check if it's 1x1, stride=1 conv
            if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
                 arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
                 arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
                 arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
            {
                return false;
            }
        }
        else if constexpr(ConvForwardSpecialization ==
                          ConvolutionForwardSpecialization_t::Filter1x1Pad0)
        {
            // check if it's 1x1 conv
            if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
                 arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
                 arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
            {
                return false;
            }
        }

        if constexpr(GemmKSpecialization ==
                         ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
                     ConvForwardSpecialization !=
                         ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            if(!(arg.Conv_C_ % KPerBlock == 0))
                return false;
        }

        if constexpr(!UseALocalBuffer &&
                     ConvForwardSpecialization !=
                         ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            // TODO: We can support this in the future, as long as figure out how to express tensor
            // transform
            return false;
        }

        if constexpr(!UseBLocalBuffer)
        {
            if(!(arg.Conv_K_ % 8 == 0))
                return false;
        }

        // Gridwise GEMM size
        return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
    }

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }
    static auto MakeArgument(const InDataType* p_in_grid,
                             const WeiDataType* p_wei_grid,
                             OutDataType* p_out_grid,
                             ck::index_t N,
                             ck::index_t K,
                             ck::index_t C,
                             std::vector<ck::index_t> input_spatial_lengths,
                             std::vector<ck::index_t> filter_spatial_lengths,
                             std::vector<ck::index_t> output_spatial_lengths,
                             std::vector<ck::index_t> conv_filter_strides,
                             std::vector<ck::index_t> conv_filter_dilations,
                             std::vector<ck::index_t> input_left_pads,
                             std::vector<ck::index_t> input_right_pads,
                             InElementwiseOperation in_element_op,
                             WeiElementwiseOperation wei_element_op,
                             OutElementwiseOperation out_element_op)
    {
        return Argument{p_in_grid,
                        p_wei_grid,
                        p_out_grid,
                        N,
                        K,
                        C,
                        input_spatial_lengths,
                        filter_spatial_lengths,
                        output_spatial_lengths,
                        conv_filter_strides,
                        conv_filter_dilations,
                        input_left_pads,
                        input_right_pads,
                        in_element_op,
                        wei_element_op,
                        out_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_in_grid,
                        const void* p_wei_grid,
                        void* p_out_grid,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t C,
                        std::vector<ck::index_t> input_spatial_lengths,
                        std::vector<ck::index_t> filter_spatial_lengths,
                        std::vector<ck::index_t> output_spatial_lengths,
                        std::vector<ck::index_t> conv_filter_strides,
                        std::vector<ck::index_t> conv_filter_dilations,
                        std::vector<ck::index_t> input_left_pads,
                        std::vector<ck::index_t> input_right_pads,
                        InElementwiseOperation in_element_op,
                        WeiElementwiseOperation wei_element_op,
                        OutElementwiseOperation out_element_op) override
    {
        return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
                                          static_cast<const WeiDataType*>(p_wei_grid),
                                          static_cast<OutDataType*>(p_out_grid),
                                          N,
                                          K,
                                          C,
                                          input_spatial_lengths,
                                          filter_spatial_lengths,
                                          output_spatial_lengths,
                                          conv_filter_strides,
                                          conv_filter_dilations,
                                          input_left_pads,
                                          input_right_pads,
                                          in_element_op,
                                          wei_element_op,
                                          out_element_op);
    }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        auto string_local_buffer = [](bool is_local_buffer) {
            if(is_local_buffer)
                return "L";
            else
                return "G";
        };

        // clang-format off
        str << "DeviceConv" << std::to_string(NumDimSpatial) << "DFwdAvx2_NHWC_YXCK"
            << "_FS" << static_cast<int>(ConvForwardSpecialization)
            << "_KS" << static_cast<int>(GemmKSpecialization)
            << "_BS" << static_cast<int>(BlockLoopOverSpecialization)
            << "_BT" << MPerBlock << "x" << NPerBlock << "x" << KPerBlock
            << "_TT" << MPerThread << "x" << NPerThread
            << "_A" << string_local_buffer(UseALocalBuffer)
            << "_B" << string_local_buffer(UseBLocalBuffer)
            << "_C" << string_local_buffer(UseCLocalBuffer);
        if constexpr(!std::is_same<OutElementwiseOperation,
                                   ck::tensor_operation::cpu::element_wise::PassThrough>::value)
        {
            str << "_" << OutElementwiseOperation::Name();
        }
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck

#endif
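Editor's note: the commit registers concrete instances of this device op in library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_yxck_nhwk_instance.cpp (not shown in this section). As orientation only, a hedged sketch of what an instantiation could look like, following the template-parameter list of the struct above; every numeric value and enum choice below is an assumption for illustration, not the commit's actual instance list:

    // Illustrative instance: f32 2D conv, 6x16 microkernel tile, A/B staged in
    // local buffers, C written directly to the output tensor. Values assumed.
    using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
    using namespace ck::tensor_operation::cpu::device;

    using DeviceConv2DFwdAvx2_YXCK_Example =
        DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<
            float, float, float,                   // In/Wei/Out data types
            PassThrough, PassThrough, PassThrough, // elementwise ops
            ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0,     // assumed
            ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC, // assumed
            LoopOver_MNK,                          // block loop order, assumed
            2,                                     // NumDimSpatial (2D conv)
            256, 128, 64,                          // M/N/KPerBlock, assumed
            6, 16,                                 // M/NPerThread (dispatches the 6x16 kernel)
            true, true, false>;                    // Use A/B/C local buffer, assumed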
include/ck/tensor_operation/cpu/device/device_convnd_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk.hpp  (new file, mode 100644)

#ifndef DEVICE_CONV2D_FWD_BIAS_ACTIVATION_ADD_AVX2_NHWC_YXCK_NHWK_HPP
#define DEVICE_CONV2D_FWD_BIAS_ACTIVATION_ADD_AVX2_NHWC_YXCK_NHWK_HPP
#include <iostream>
#include <sstream>
#include <numeric>
#include "device.hpp"
#include "device_base_cpu.hpp"
#include "device_conv_fwd_cpu.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "common_header.hpp"
#include "../../gpu/device/tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_bias_activation_add_avx2.hpp"
#include "threadwise_gemm_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"
namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
template <typename InDataType,
          typename WeiDataType,
          typename OutDataType,
          typename BiasDataType,
          typename AddDataType,
          typename InElementwiseOperation,
          typename WeiElementwiseOperation,
          typename OutElementwiseOperation,
          ConvolutionForwardSpecialization_t ConvForwardSpecialization,
          ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
          ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
          ck::index_t NumDimSpatial,
          ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
          ck::index_t NPerBlock,
          ck::index_t KPerBlock,
          ck::index_t MPerThread,
          ck::index_t NPerThread,
          bool UseALocalBuffer,
          bool UseBLocalBuffer,
          bool UseCLocalBuffer,
          bool BiasAlongGemmM>
struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
    : public DeviceConvFwdBiasActivationAdd<InElementwiseOperation,
                                            WeiElementwiseOperation,
                                            OutElementwiseOperation>
{
    using DeviceOp =
        DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K;

    using ADataType  = InDataType;
    using BDataType  = WeiDataType;
    using CDataType  = OutDataType;
    using C0DataType = BiasDataType;
    using C1DataType = AddDataType;

    using AElementwiseOperation = InElementwiseOperation;
    using BElementwiseOperation = WeiElementwiseOperation;
    using CElementwiseOperation = OutElementwiseOperation;

    // TODO make A/B datatype different
    using ABDataType = InDataType;

    static constexpr index_t NDimSpatial = NumDimSpatial;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    static constexpr bool NonTemporalStore = false;

    static constexpr auto GetBlockMNKAccessOrder()
    {
        if constexpr(BlockLoopOverSpecialization == DefaultBlockLoopOver ||
                     BlockLoopOverSpecialization == LoopOver_MNK)
            return ck::Sequence<0, 1, 2>{};
        else if constexpr(BlockLoopOverSpecialization == LoopOver_MKN)
            return ck::Sequence<0, 2, 1>{};
    }

    using BlockMNKAccessOrder = decltype(GetBlockMNKAccessOrder());

    static constexpr auto GetThreadwiseGemm_Dispatch()
    {
        if constexpr(MPerThread == 4 && NPerThread == 24)
        {
            return ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<InDataType,
                                                                 WeiDataType,
                                                                 OutDataType,
                                                                 ck::tensor_layout::gemm::RowMajor,
                                                                 ck::tensor_layout::gemm::RowMajor,
                                                                 NonTemporalStore>{};
        }
        else if constexpr(MPerThread == 6 && NPerThread == 16)
        {
            return ck::cpu::ThreadwiseGemmAvx2_MxN_6x16_Dispatch<InDataType,
                                                                 WeiDataType,
                                                                 OutDataType,
                                                                 ck::tensor_layout::gemm::RowMajor,
                                                                 ck::tensor_layout::gemm::RowMajor,
                                                                 NonTemporalStore>{};
        }
        else
        {
            // static_assert(false, "invalid Mr/Nr");
        }
    }
    using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());

    static constexpr auto GetInputBlockDescriptor()
    {
        if constexpr(UseALocalBuffer)
        {
            return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
        }
        else
        {
            return AGridDesc{};
        }
    }

    static constexpr auto GetWeightBlockDescriptor()
    {
        if constexpr(UseBLocalBuffer)
        {
            return make_naive_tensor_descriptor_packed(make_tuple(KPerBlock, NPerBlock));
        }
        else
        {
            return BGridDesc{};
        }
    }

    static constexpr auto GetOutputBlockDescriptor()
    {
        if constexpr(UseCLocalBuffer)
        {
            return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
        }
        else
        {
            return CGridDesc{};
        }
    }

    static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
    {
        return make_naive_tensor_descriptor_packed(make_tuple(gemm_k, gemm_n));
    }

    static auto GetOutputTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n)
    {
        const auto out_gemm_m_n_grid_desc =
            make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_n));
        return out_gemm_m_n_grid_desc;
    }

    static auto MakeBiasTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n)
    {
        if constexpr(BiasAlongGemmM)
        {
            return make_naive_tensor_descriptor_packed(make_tuple(gemm_m));
        }
        else
        {
            return make_naive_tensor_descriptor_packed(make_tuple(gemm_n));
        }
    }
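    // Editor's note (not in the commit): BiasAlongGemmM selects the broadcast axis
    // of the bias vector at compile time. For an assumed 2D problem with
    // N=1, Ho=Wo=28, K=128:
    //   BiasAlongGemmM == false -> extent gemm_n = K        = 128 (per output channel,
    //                              the usual convolution bias)
    //   BiasAlongGemmM == true  -> extent gemm_m = N*Ho*Wo  = 784 (per output pixel)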
    template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
    static auto GetInputTensorDescriptor(ck::index_t N,
                                         ck::index_t C,
                                         ck::index_t gemm_m,
                                         ck::index_t gemm_k,
                                         const std::vector<ck::index_t>& input_spatial_lengths,
                                         const std::vector<ck::index_t>& filter_spatial_lengths,
                                         const std::vector<ck::index_t>& output_spatial_lengths,
                                         const std::vector<ck::index_t>& conv_filter_strides,
                                         const std::vector<ck::index_t>& conv_filter_dilations,
                                         const std::vector<ck::index_t>& input_left_pads,
                                         const std::vector<ck::index_t>& input_right_pads)
    {
        const index_t Wi          = input_spatial_lengths[0];
        const index_t Wo          = output_spatial_lengths[0];
        const index_t ConvStrideW = conv_filter_strides[0];

        if constexpr(ConvForwardSpecialization ==
                     ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            const auto in_gemm_m_k_grid_desc =
                make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
            return in_gemm_m_k_grid_desc;
        }
        else if constexpr(ConvForwardSpecialization ==
                          ConvolutionForwardSpecialization_t::Filter1x1Pad0)
        {
            const auto in_n_wi_c_grid_desc =
                make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));

            const auto in_n_wo_c_grid_desc = transform_tensor_descriptor(
                in_n_wi_c_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
                in_n_wo_c_grid_desc,
                make_tuple(make_merge_transform(make_tuple(N, Wo)),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return in_gemm_m_k_grid_desc;
        }
        else
        {
            const index_t X             = filter_spatial_lengths[0];
            const index_t ConvDilationW = conv_filter_dilations[0];
            const index_t InLeftPadW    = input_left_pads[0];
            const index_t InRightPadW   = input_right_pads[0];

            const auto in_n_wi_c_grid_desc =
                make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));

            const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
                in_n_wi_c_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_pad_transform(Wi, InLeftPadW, InRightPadW),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
                in_n_wip_c_grid_desc,
                make_tuple(
                    make_pass_through_transform(N),
                    make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
                    make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

            const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
                in_n_x_wo_c_grid_desc,
                make_tuple(make_merge_transform(make_tuple(N, Wo)),
                           make_merge_transform(make_tuple(X, C))),
                make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return in_gemm_m_k_grid_desc;
        }
    }
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
static auto GetInputTensorDescriptor(ck::index_t N,
                                     ck::index_t C,
                                     ck::index_t gemm_m,
                                     ck::index_t gemm_k,
                                     const std::vector<ck::index_t>& input_spatial_lengths,
                                     const std::vector<ck::index_t>& filter_spatial_lengths,
                                     const std::vector<ck::index_t>& output_spatial_lengths,
                                     const std::vector<ck::index_t>& conv_filter_strides,
                                     const std::vector<ck::index_t>& conv_filter_dilations,
                                     const std::vector<ck::index_t>& input_left_pads,
                                     const std::vector<ck::index_t>& input_right_pads)
{
    const index_t Hi = input_spatial_lengths[0];
    const index_t Wi = input_spatial_lengths[1];

    const index_t Ho = output_spatial_lengths[0];
    const index_t Wo = output_spatial_lengths[1];

    const index_t ConvStrideH = conv_filter_strides[0];
    const index_t ConvStrideW = conv_filter_strides[1];

    if constexpr(ConvForwardSpecialization ==
                 ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
    {
        const auto in_gemm_m_k_grid_desc =
            make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
        return in_gemm_m_k_grid_desc;
    }
    else if constexpr(ConvForwardSpecialization ==
                      ConvolutionForwardSpecialization_t::Filter1x1Pad0)
    {
        const auto in_n_hi_wi_c_grid_desc =
            make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));

        const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
            in_n_hi_wi_c_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
                       make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
            in_n_ho_wo_c_grid_desc,
            make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return in_gemm_m_k_grid_desc;
    }
    else
    {
        const index_t Y = filter_spatial_lengths[0];
        const index_t X = filter_spatial_lengths[1];

        const index_t ConvDilationH = conv_filter_dilations[0];
        const index_t ConvDilationW = conv_filter_dilations[1];

        const index_t InLeftPadH = input_left_pads[0];
        const index_t InLeftPadW = input_left_pads[1];

        const index_t InRightPadH = input_right_pads[0];
        const index_t InRightPadW = input_right_pads[1];

        const auto in_n_hi_wi_c_grid_desc =
            make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));

        const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
            in_n_hi_wi_c_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_pad_transform(Hi, InLeftPadH, InRightPadH),
                       make_pad_transform(Wi, InLeftPadW, InRightPadW),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
            in_n_hip_wip_c_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_embed_transform(make_tuple(Y, Ho),
                                            make_tuple(ConvDilationH, ConvStrideH)),
                       make_embed_transform(make_tuple(X, Wo),
                                            make_tuple(ConvDilationW, ConvStrideW)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

        const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
            in_n_y_ho_x_wo_c_grid_desc,
            make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
                       make_merge_transform(make_tuple(Y, X, C))),
            make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return in_gemm_m_k_grid_desc;
    }
}
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
static auto GetInputTensorDescriptor(ck::index_t N,
                                     ck::index_t C,
                                     ck::index_t gemm_m,
                                     ck::index_t gemm_k,
                                     ck::index_t gemm_m_pad,
                                     const std::vector<ck::index_t>& input_spatial_lengths,
                                     const std::vector<ck::index_t>& filter_spatial_lengths,
                                     const std::vector<ck::index_t>& output_spatial_lengths,
                                     const std::vector<ck::index_t>& conv_filter_strides,
                                     const std::vector<ck::index_t>& conv_filter_dilations,
                                     const std::vector<ck::index_t>& input_left_pads,
                                     const std::vector<ck::index_t>& input_right_pads)
{
    const index_t Di = input_spatial_lengths[0];
    const index_t Hi = input_spatial_lengths[1];
    const index_t Wi = input_spatial_lengths[2];

    const index_t Do = output_spatial_lengths[0];
    const index_t Ho = output_spatial_lengths[1];
    const index_t Wo = output_spatial_lengths[2];

    const index_t ConvStrideD = conv_filter_strides[0];
    const index_t ConvStrideH = conv_filter_strides[1];
    const index_t ConvStrideW = conv_filter_strides[2];

    if constexpr(ConvForwardSpecialization ==
                 ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
    {
        const auto in_gemm_m_k_grid_desc =
            make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
        return in_gemm_m_k_grid_desc;
    }
    else if constexpr(ConvForwardSpecialization ==
                      ConvolutionForwardSpecialization_t::Filter1x1Pad0)
    {
        const auto in_n_di_hi_wi_c_grid_desc =
            make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));

        const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor(
            in_n_di_hi_wi_c_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
                       make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
                       make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
                       make_pass_through_transform(C)),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

        const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
            in_n_do_ho_wo_c_grid_desc,
            make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return in_gemm_m_k_grid_desc;
    }
    else
    {
        const index_t Z = filter_spatial_lengths[0];
        const index_t Y = filter_spatial_lengths[1];
        const index_t X = filter_spatial_lengths[2];

        const index_t ConvDilationD = conv_filter_dilations[0];
        const index_t ConvDilationH = conv_filter_dilations[1];
        const index_t ConvDilationW = conv_filter_dilations[2];

        const index_t InLeftPadD = input_left_pads[0];
        const index_t InLeftPadH = input_left_pads[1];
        const index_t InLeftPadW = input_left_pads[2];

        const index_t InRightPadD = input_right_pads[0];
        const index_t InRightPadH = input_right_pads[1];
        const index_t InRightPadW = input_right_pads[2];

        const auto in_n_di_hi_wi_c_grid_desc =
            make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));

        const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
            in_n_di_hi_wi_c_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_pad_transform(Di, InLeftPadD, InRightPadD),
                       make_pad_transform(Hi, InLeftPadH, InRightPadH),
                       make_pad_transform(Wi, InLeftPadW, InRightPadW),
                       make_pass_through_transform(C)),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

        const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
            in_n_hip_wip_c_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_embed_transform(make_tuple(Z, Do),
                                            make_tuple(ConvDilationD, ConvStrideD)),
                       make_embed_transform(make_tuple(Y, Ho),
                                            make_tuple(ConvDilationH, ConvStrideH)),
                       make_embed_transform(make_tuple(X, Wo),
                                            make_tuple(ConvDilationW, ConvStrideW)),
                       make_pass_through_transform(C)),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(
                Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5, 6>{}, Sequence<7>{}));

        const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
            in_n_z_do_y_ho_x_wo_c_grid_desc,
            make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
                       make_merge_transform(make_tuple(Z, Y, X, C))),
            make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return in_gemm_m_k_grid_desc;
    }
}
static index_t GetGemmM(ck::index_t N, const std::vector<ck::index_t>& output_spatial_lengths)
{
    return N * std::accumulate(std::begin(output_spatial_lengths),
                               std::end(output_spatial_lengths),
                               1,
                               std::multiplies<ck::index_t>());
}

static index_t GetGemmK(ck::index_t C, const std::vector<ck::index_t>& filter_spatial_lengths)
{
    return C * std::accumulate(std::begin(filter_spatial_lengths),
                               std::end(filter_spatial_lengths),
                               1,
                               std::multiplies<ck::index_t>());
}

static index_t GetGemmN(ck::index_t K)
{
    // return ck::math::integer_least_multiple(K,
    // ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
    return K;
}
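// Illustration of the GEMM sizes produced by the three helpers above (the numbers
// are an arbitrary example, not taken from this commit): for N = 2, Ho = Wo = 28,
// K = 64, C = 32 and a 3x3 filter,
//   GemmM = 2 * 28 * 28 = 1568, GemmK = 32 * 3 * 3 = 288, GemmN = K = 64.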
static auto MakeABCGridDescriptor(ck::index_t N,
                                  ck::index_t K,
                                  ck::index_t C,
                                  std::vector<ck::index_t> input_spatial_lengths,
                                  std::vector<ck::index_t> filter_spatial_lengths,
                                  std::vector<ck::index_t> output_spatial_lengths,
                                  std::vector<ck::index_t> conv_filter_strides,
                                  std::vector<ck::index_t> conv_filter_dilations,
                                  std::vector<ck::index_t> input_left_pads,
                                  std::vector<ck::index_t> input_right_pads)
{
    using namespace ck;

    const index_t GemmM = GetGemmM(N, output_spatial_lengths);
    const index_t GemmN = GetGemmN(K);
    const index_t GemmK = GetGemmK(C, filter_spatial_lengths);

    // A:
    const auto in_gemm_m_k_grid_desc =
        GetInputTensorDescriptor<NumDimSpatial>(N,
                                                C,
                                                GemmM,
                                                GemmK,
                                                input_spatial_lengths,
                                                filter_spatial_lengths,
                                                output_spatial_lengths,
                                                conv_filter_strides,
                                                conv_filter_dilations,
                                                input_left_pads,
                                                input_right_pads);
    // B:
    const auto wei_gemm_k_n_grid_desc = GetWeightTensorDescriptor(GemmK, GemmN);
    // C:
    const auto out_gemm_m_n_grid_desc = GetOutputTensorDescriptor(GemmM, GemmN);

    return make_tuple(in_gemm_m_k_grid_desc, wei_gemm_k_n_grid_desc, out_gemm_m_n_grid_desc);
}
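// Taken together, the three descriptors above express the convolution as an implicit GEMM
//   out[GemmM, GemmN] = in[GemmM, GemmK] * wei[GemmK, GemmN]
// where A is the (possibly im2col-transformed) NHWC input, B is the YXCK weight viewed as
// a packed row-major GemmK x GemmN matrix, and C is the NHWK output.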
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
static auto GetABCGridDesc()
{
    return MakeABCGridDescriptor(1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1});
}

template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
static auto GetABCGridDesc()
{
    return MakeABCGridDescriptor(
        1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
}

template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
static auto GetABCGridDesc()
{
    return MakeABCGridDescriptor(
        1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
}
using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>());

using AGridDesc  = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
using BGridDesc  = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc  = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
using C0GridDesc = remove_cvref_t<decltype(MakeBiasTensorDescriptor(1, 1))>;
using C1GridDesc = CGridDesc;

// static constexpr bool UseCLocalBuffer = false;

using AThreadwiseCopy =
    ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC<
        ADataType,
        ADataType,
        AGridDesc,
        decltype(GetInputBlockDescriptor()),
        InElementwiseOperation,
        !UseALocalBuffer,
        ConvForwardSpecialization,
        GemmKSpecialization>;

using BThreadwiseCopy =
    ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_YXCK<
        BDataType,
        BDataType,
        BGridDesc,
        decltype(GetWeightBlockDescriptor()),
        WeiElementwiseOperation,
        !UseBLocalBuffer,
        ConvForwardSpecialization,
        GemmKSpecialization>;

using CThreadwiseCopy =
    ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN<
        CDataType,
        C0DataType,
        C1DataType,
        CDataType,
        CGridDesc,
        C0GridDesc,
        C1GridDesc,
        decltype(GetOutputBlockDescriptor()),
        OutElementwiseOperation,
        !UseCLocalBuffer,
        BiasAlongGemmM>;

using GridwiseGemm =
    ck::cpu::GridwiseGemmBiasActivationAddAvx2_MxN<ADataType,               // InDataType,
                                                   BDataType,               // WeiDataType,
                                                   CDataType,               // OutDataType,
                                                   C0DataType,              // C0DataType
                                                   C1DataType,              // C1DataType
                                                   AGridDesc,               // AGridDesc,
                                                   BGridDesc,               // BGridDesc,
                                                   CGridDesc,               // CGridDesc,
                                                   C0GridDesc,              // C0GridDesc,
                                                   C1GridDesc,              // C1GridDesc,
                                                   AElementwiseOperation,   // AElementwiseOperation,
                                                   BElementwiseOperation,   // BElementwiseOperation,
                                                   CElementwiseOperation,   // CElementwiseOperation,
                                                   MPerBlock,               // MPerBlock,
                                                   NPerBlock,               // NPerBlock,
                                                   KPerBlock,               // KPerBlock,
                                                   ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
                                                   AThreadwiseCopy,         // AThreadwiseCopy
                                                   BThreadwiseCopy,         // BThreadwiseCopy
                                                   CThreadwiseCopy,         // CThreadwiseCopy
                                                   BlockMNKAccessOrder,     // BlockMNKAccessOrder,
                                                   ck::Sequence<0, 1>,      // ThreadMNAccessOrder
                                                   UseALocalBuffer,         // UseALocalBuffer
                                                   UseBLocalBuffer,         // UseBLocalBuffer
                                                   UseCLocalBuffer>;        // UseCLocalBuffer
// Argument
struct Argument : public BaseArgument
{
    Argument(const InDataType* p_in_grid,
             const WeiDataType* p_wei_grid,
             OutDataType* p_out_grid,
             const BiasDataType* p_bias_grid,
             const AddDataType* p_add_grid,
             ck::index_t N,
             ck::index_t K,
             ck::index_t C,
             std::vector<ck::index_t> input_spatial_lengths,
             std::vector<ck::index_t> filter_spatial_lengths,
             std::vector<ck::index_t> output_spatial_lengths,
             std::vector<ck::index_t> conv_filter_strides,
             std::vector<ck::index_t> conv_filter_dilations,
             std::vector<ck::index_t> input_left_pads,
             std::vector<ck::index_t> input_right_pads,
             InElementwiseOperation in_element_op,
             WeiElementwiseOperation wei_element_op,
             OutElementwiseOperation out_element_op)
        : p_a_grid_{p_in_grid},
          p_b_grid_{p_wei_grid},
          p_c_grid_{p_out_grid},
          p_c0_grid_{p_bias_grid},
          p_c1_grid_{p_add_grid},
          a_grid_desc_{},
          b_grid_desc_{},
          c_grid_desc_{},
          c0_grid_desc_{},
          c1_grid_desc_{},
          a_element_op_{in_element_op},
          b_element_op_{wei_element_op},
          c_element_op_{out_element_op},
          Conv_N_{N},
          Conv_K_{K},
          Conv_C_{C},
          filter_spatial_lengths_{filter_spatial_lengths},
          conv_filter_strides_{conv_filter_strides},
          input_left_pads_{input_left_pads},
          input_right_pads_{input_right_pads}
    {
        const auto descs = DeviceOp::MakeABCGridDescriptor(N,
                                                           K,
                                                           C,
                                                           input_spatial_lengths,
                                                           filter_spatial_lengths,
                                                           output_spatial_lengths,
                                                           conv_filter_strides,
                                                           conv_filter_dilations,
                                                           input_left_pads,
                                                           input_right_pads);

        a_grid_desc_ = descs[I0];
        b_grid_desc_ = descs[I1];
        c_grid_desc_ = descs[I2];

        c0_grid_desc_ =
            DeviceOp::MakeBiasTensorDescriptor(GetGemmM(N, output_spatial_lengths), GetGemmN(K));
        c1_grid_desc_ = descs[I2];
    }

    // private:
    const ADataType* p_a_grid_;
    const BDataType* p_b_grid_;
    CDataType* p_c_grid_;
    const C0DataType* p_c0_grid_;
    const C1DataType* p_c1_grid_;
    AGridDesc a_grid_desc_;
    BGridDesc b_grid_desc_;
    CGridDesc c_grid_desc_;
    C0GridDesc c0_grid_desc_;
    C1GridDesc c1_grid_desc_;
    AElementwiseOperation a_element_op_;
    BElementwiseOperation b_element_op_;
    CElementwiseOperation c_element_op_;
    // for checking IsSupportedArgument()
    index_t Conv_N_;
    index_t Conv_K_;
    index_t Conv_C_;
    std::vector<index_t> filter_spatial_lengths_;
    std::vector<index_t> conv_filter_strides_;
    std::vector<index_t> input_left_pads_;
    std::vector<index_t> input_right_pads_;
};
// Invoker
struct Invoker : public BaseInvoker
{
    using Argument = DeviceOp::Argument;

    float Run(const Argument& arg,
              const StreamConfig& stream_config = StreamConfig{},
              int nrepeat                       = 1)
    {
        if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
        {
            throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
        }

        // clear C in bytes, not elements
        memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize() * sizeof(CDataType));

        const auto kernel =
            ck::cpu::kernel_gemm_bias_activation_add_avx_mxn<GridwiseGemm,
                                                             ADataType,
                                                             BDataType,
                                                             CDataType,
                                                             C0DataType,
                                                             C1DataType,
                                                             AGridDesc,
                                                             BGridDesc,
                                                             CGridDesc,
                                                             C0GridDesc,
                                                             C1GridDesc,
                                                             AElementwiseOperation,
                                                             BElementwiseOperation,
                                                             CElementwiseOperation>;

        float ave_time = 0;

        if(nrepeat != 1)
            ave_time = launch_and_time_cpu_kernel(kernel,
                                                  nrepeat,
                                                  arg.p_a_grid_,
                                                  arg.p_b_grid_,
                                                  arg.p_c_grid_,
                                                  arg.p_c0_grid_,
                                                  arg.p_c1_grid_,
                                                  arg.a_grid_desc_,
                                                  arg.b_grid_desc_,
                                                  arg.c_grid_desc_,
                                                  arg.c0_grid_desc_,
                                                  arg.c1_grid_desc_,
                                                  arg.a_element_op_,
                                                  arg.b_element_op_,
                                                  arg.c_element_op_);

        // TODO: this is for benchmark purposes, so the last time around we clear the c buffer
        // and calculate the result
        memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize() * sizeof(CDataType));

        launch_cpu_kernel(kernel,
                          arg.p_a_grid_,
                          arg.p_b_grid_,
                          arg.p_c_grid_,
                          arg.p_c0_grid_,
                          arg.p_c1_grid_,
                          arg.a_grid_desc_,
                          arg.b_grid_desc_,
                          arg.c_grid_desc_,
                          arg.c0_grid_desc_,
                          arg.c1_grid_desc_,
                          arg.a_element_op_,
                          arg.b_element_op_,
                          arg.c_element_op_);

        return ave_time;
    }

    float Run(const BaseArgument* p_arg,
              const StreamConfig& stream_config = StreamConfig{},
              int nrepeat                       = 1) override
    {
        return Run(*dynamic_cast<const Argument*>(p_arg), stream_config, nrepeat);
    }
};
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    return true;
}

static bool IsSupportedArgument(const Argument& arg)
{
    if constexpr(ConvForwardSpecialization ==
                 ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
    {
        // check if it's a 1x1, stride=1 conv
        if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
             arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
             arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
             arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
        {
            return false;
        }
    }
    else if constexpr(ConvForwardSpecialization ==
                      ConvolutionForwardSpecialization_t::Filter1x1Pad0)
    {
        // check if it's a 1x1 conv
        if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
             arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
             arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
        {
            return false;
        }
    }

    if constexpr(GemmKSpecialization ==
                     ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
                 ConvForwardSpecialization !=
                     ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
    {
        if(!(arg.Conv_C_ % KPerBlock == 0))
            return false;
    }

    if constexpr(!UseALocalBuffer &&
                 ConvForwardSpecialization !=
                     ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
    {
        // TODO: We can support this in the future, as long as we figure out how to express
        // the tensor transform
        return false;
    }

    if constexpr(!UseBLocalBuffer)
    {
        if(!(arg.Conv_K_ % 8 == 0))
            return false;
    }

    // Gridwise GEMM size
    return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
}

bool IsSupportedArgument(const BaseArgument* p_arg) override
{
    return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const InDataType* p_in_grid,
                         const WeiDataType* p_wei_grid,
                         OutDataType* p_out_grid,
                         const BiasDataType* p_bias_grid,
                         const AddDataType* p_add_grid,
                         ck::index_t N,
                         ck::index_t K,
                         ck::index_t C,
                         std::vector<ck::index_t> input_spatial_lengths,
                         std::vector<ck::index_t> filter_spatial_lengths,
                         std::vector<ck::index_t> output_spatial_lengths,
                         std::vector<ck::index_t> conv_filter_strides,
                         std::vector<ck::index_t> conv_filter_dilations,
                         std::vector<ck::index_t> input_left_pads,
                         std::vector<ck::index_t> input_right_pads,
                         InElementwiseOperation in_element_op,
                         WeiElementwiseOperation wei_element_op,
                         OutElementwiseOperation out_element_op)
{
    return Argument{p_in_grid,
                    p_wei_grid,
                    p_out_grid,
                    p_bias_grid,
                    p_add_grid,
                    N,
                    K,
                    C,
                    input_spatial_lengths,
                    filter_spatial_lengths,
                    output_spatial_lengths,
                    conv_filter_strides,
                    conv_filter_dilations,
                    input_left_pads,
                    input_right_pads,
                    in_element_op,
                    wei_element_op,
                    out_element_op};
}

static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_grid,
                    const void* p_wei_grid,
                    void* p_out_grid,
                    const void* p_bias_grid,
                    const void* p_add_grid,
                    ck::index_t N,
                    ck::index_t K,
                    ck::index_t C,
                    std::vector<ck::index_t> input_spatial_lengths,
                    std::vector<ck::index_t> filter_spatial_lengths,
                    std::vector<ck::index_t> output_spatial_lengths,
                    std::vector<ck::index_t> conv_filter_strides,
                    std::vector<ck::index_t> conv_filter_dilations,
                    std::vector<ck::index_t> input_left_pads,
                    std::vector<ck::index_t> input_right_pads,
                    InElementwiseOperation in_element_op,
                    WeiElementwiseOperation wei_element_op,
                    OutElementwiseOperation out_element_op) override
{
    return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
                                      static_cast<const WeiDataType*>(p_wei_grid),
                                      static_cast<OutDataType*>(p_out_grid),
                                      static_cast<const BiasDataType*>(p_bias_grid),
                                      static_cast<const AddDataType*>(p_add_grid),
                                      N,
                                      K,
                                      C,
                                      input_spatial_lengths,
                                      filter_spatial_lengths,
                                      output_spatial_lengths,
                                      conv_filter_strides,
                                      conv_filter_dilations,
                                      input_left_pads,
                                      input_right_pads,
                                      in_element_op,
                                      wei_element_op,
                                      out_element_op);
}

std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
    return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
    auto str = std::stringstream();

    auto string_local_buffer = [](bool is_local_buffer) {
        if(is_local_buffer)
            return "L";
        else
            return "G";
    };

    // clang-format off
    str << "DeviceConv" << std::to_string(NumDimSpatial) << "DFwd_BAA_Avx2_NHWC_YXCK"
        << "_FS" << static_cast<int>(ConvForwardSpecialization)
        << "_KS" << static_cast<int>(GemmKSpecialization)
        << "_BS" << static_cast<int>(BlockLoopOverSpecialization)
        << "_BT" << MPerBlock << "x" << NPerBlock << "x" << KPerBlock
        << "_TT" << MPerThread << "x" << NPerThread
        << "_A" << string_local_buffer(UseALocalBuffer)
        << "_B" << string_local_buffer(UseBLocalBuffer)
        << "_C" << string_local_buffer(UseCLocalBuffer);

    if constexpr(!std::is_same<OutElementwiseOperation,
                               ck::tensor_operation::cpu::element_wise::PassThrough>::value)
    {
        str << "_" << OutElementwiseOperation::Name();
    }
    // clang-format on

    return str.str();
}
};

} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
#endif
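A minimal usage sketch of the device op defined above (illustrative only, not part of the commit: `MyDeviceConv` stands for any concrete instantiation of this template, and the pointer/shape variables are placeholders). The calls mirror the MakeArgumentPointer / IsSupportedArgument / MakeInvokerPointer API above:

// Sketch, assuming MyDeviceConv is a concrete instantiation of the device op.
auto conv    = MyDeviceConv{};
auto arg_ptr = conv.MakeArgumentPointer(
    p_in, p_wei, p_out, p_bias, p_add, // NHWC input, YXCK weight, NHWK output, bias, residual
    N, K, C,
    {Hi, Wi}, {Y, X}, {Ho, Wo},        // input / filter / output spatial lengths
    {1, 1}, {1, 1}, {1, 1}, {1, 1},    // strides, dilations, left pads, right pads
    InElementwiseOperation{}, WeiElementwiseOperation{}, OutElementwiseOperation{});
if(conv.IsSupportedArgument(arg_ptr.get()))
    conv.MakeInvokerPointer()->Run(arg_ptr.get());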
include/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp

@@ -81,12 +81,8 @@ struct ThreadwiseGemmAvx2_MxN_6x16
         "movq (%[m_param]), %%rax\n"   // p_a
         "movq 8(%[m_param]), %%rbx\n"  // p_b
         "movq 24(%[m_param]), %%rsi\n" // Kr
-        ".if m_TransA != 0\n"
         "movq 32(%[m_param]), %%rcx\n" // lda
-        ".endif\n"
-        ".if m_TransB == 0\n"
         "movq 40(%[m_param]), %%rdx\n" // ldb
-        ".endif\n"
         ".macro vbroadcastss_%= r_base, r_stride, i_scale, i_offset, ymm\n"
         ".if \\i_scale != 0\n"

@@ -120,10 +116,14 @@ struct ThreadwiseGemmAvx2_MxN_6x16
         ".endif\n"
         ".endm\n"
-        ".macro vbroadcast_a%= i_k, i_m, ymm\n" // A in rax(r8, r9), lda in rcx
+        ".macro vbroadcast_a%= i_k, i_m, ymm\n" // A in rax(r8), lda in rcx
         ".if m_ABytes == 4\n"
         ".if m_TransA == 0\n"
-        "vbroadcastss_%= %%rax, 0, 0, ((\\i_m + \\i_k * m_Mr) * m_ABytes), \\ymm\n"
+        ".if (\\i_k == 0) || (\\i_k == 1) || (\\i_k == 2)\n"
+        "vbroadcastss_%= %%rax, %%rcx, \\i_k, (\\i_m * m_ABytes), \\ymm\n"
+        ".else\n"
+        "vbroadcastss_%= %%r8, %%rcx, (\\i_k-3), (\\i_m * m_ABytes), \\ymm\n"
+        ".endif\n"
         ".else\n"
         ".if (\\i_m == 0) || (\\i_m == 1) || (\\i_m == 2)\n"
         "vbroadcastss_%= %%rax, %%rcx, \\i_m, (\\i_k * m_ABytes), \\ymm\n"

@@ -133,7 +133,11 @@ struct ThreadwiseGemmAvx2_MxN_6x16
         ".endif\n"
         ".else\n"
         ".if m_TransA == 0\n"
-        "vpbroadcastw_%= %%rax, 0, 0, ((\\i_m + \\i_k * m_Mr) * m_ABytes), %%xmm15\n"
+        ".if (\\i_k == 0) || (\\i_k == 1) || (\\i_k == 2)\n"
+        "vpbroadcastw_%= %%rax, %%rcx, \\i_k, (\\i_m * m_ABytes), %%xmm15\n"
+        ".else\n"
+        "vpbroadcastw_%= %%r8, %%rcx, (\\i_k-3), (\\i_m * m_ABytes), %%xmm15\n"
+        ".endif\n"
         ".else\n"
         ".if (\\i_m == 0) || (\\i_m == 1) || (\\i_m == 2)\n"
         "vpbroadcastw_%= %%rax, %%rcx, \\i_m, (\\i_k * m_ABytes), %%xmm15\n"

@@ -145,18 +149,26 @@ struct ThreadwiseGemmAvx2_MxN_6x16
         ".endif\n"
         ".endm\n"
-        ".macro vload_b%= i_k, i_n, ymm\n" // B in rbx, lda in rdx, i_n should be 0, 1
+        ".macro vload_b%= i_k, i_n, ymm\n" // B in rbx(r9), lda in rdx, i_n should be 0, 1
         ".if m_BBytes == 4\n"
         ".if m_TransB == 0\n"
         "vmovups_%= %%rbx, %%rdx, \\i_n, (\\i_k*m_BBytes*8), \\ymm\n"
         ".else\n"
-        "vmovups_%= %%rbx, 0, 0, ((\\i_k*m_Nr + \\i_n*8)*m_BBytes), \\ymm\n"
+        ".if (\\i_k == 0) || (\\i_k == 1) || (\\i_k == 2)\n"
+        "vmovups_%= %%rbx, %%rdx, \\i_k, (\\i_n*m_BBytes*8), \\ymm\n"
+        ".else\n"
+        "vmovups_%= %%r9, %%rdx, (\\i_k-3), (\\i_n*m_BBytes*8), \\ymm\n"
+        ".endif\n"
         ".endif\n"
         ".else\n"
         ".if m_TransB == 0\n"
         "vcvtph2ps_%= %%rbx, %%rdx, \\i_n, (\\i_k*m_BBytes*8), \\ymm\n"
         ".else\n"
-        "vcvtph2ps_%= %%rbx, 0, 0, ((\\i_k*m_Nr + \\i_n*8)*m_BBytes), \\ymm\n"
+        ".if (\\i_k == 0) || (\\i_k == 1) || (\\i_k == 2)\n"
+        "vcvtph2ps_%= %%rbx, %%rdx, \\i_k, (\\i_n*m_BBytes*8), \\ymm\n"
+        ".else\n"
+        "vcvtph2ps_%= %%r9, %%rdx, (\\i_k-3), (\\i_n*m_BBytes*8), \\ymm\n"
+        ".endif\n"
         ".endif\n"
         ".endif\n"
         ".endm\n"

@@ -179,6 +191,13 @@ struct ThreadwiseGemmAvx2_MxN_6x16
         "lea (%%rcx, %%rcx, 2), %%r9\n"
         "lea (%%rax, %%r9), %%r8\n"
         ".endif\n"
+        ".else\n"
+        "lea (%%rcx, %%rcx, 2), %%r9\n"
+        "lea (%%rax, %%r9), %%r8\n"
+        ".endif\n"
+        ".if m_TransB != 0\n"
+        "lea (%%rdx, %%rdx, 2), %%rdi\n"
+        "lea (%%rbx, %%rdi), %%r9\n"
         ".endif\n"
         "cmp $4, %%rsi\n"

@@ -214,10 +233,12 @@ struct ThreadwiseGemmAvx2_MxN_6x16
         " lea 4*m_ABytes(%%rax), %%rax\n"
         ".if m_Mr > 3\n lea 4*m_ABytes(%%r8), %%r8\n .endif\n"
         ".else\n"
-        " lea m_Mr * 4 * m_ABytes(%%rax), %%rax\n"
+        " lea (%%rax, %%rcx, 4), %%rax\n"
+        " lea (%%r8, %%rcx, 4), %%r8\n"
         ".endif\n"
         ".if m_TransB != 0\n"
-        " lea m_Nr * 4 * m_BBytes(%%rbx), %%rbx\n"
+        " lea (%%rbx, %%rdx, 4), %%rbx\n"
+        " lea (%%r9, %%rdx, 4), %%r9\n"
         ".else\n"
         " lea 8 * 4 * m_BBytes(%%rbx), %%rbx\n"
         ".endif\n"

@@ -256,10 +277,12 @@ struct ThreadwiseGemmAvx2_MxN_6x16
         " lea m_ABytes(%%rax), %%rax\n"
         ".if m_Mr > 3\n lea m_ABytes(%%r8), %%r8\n .endif\n"
         ".else\n"
-        " lea m_Mr * m_ABytes(%%rax), %%rax\n"
+        " lea (%%rax, %%rcx, 1), %%rax\n"
+        " lea (%%r8, %%rcx, 1), %%r8\n"
         ".endif\n"
         ".if m_TransB != 0\n"
-        " lea m_Nr * m_BBytes(%%rbx), %%rbx\n"
+        " lea (%%rbx, %%rdx, 1), %%rbx\n"
+        " lea (%%r9, %%rdx, 1), %%r9\n"
         ".else\n"
         " lea 8*m_BBytes(%%rbx), %%rbx\n"
         ".endif\n"

@@ -381,7 +404,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
             }
             else
             {
-                ymm = _mm256_broadcast_ss(p_a + i_k * Mr + i_m);
+                ymm = _mm256_broadcast_ss(p_a + i_k * lda + i_m);
             }
         }
         else

@@ -396,7 +419,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
             }
             else
             {
-                ymm = _mm256_cvtph_ps(_mm_set1_epi16(*(p_a + i_k * Mr + i_m)));
+                ymm = _mm256_cvtph_ps(_mm_set1_epi16(*(p_a + i_k * lda + i_m)));
             }
         }
     };

@@ -406,7 +429,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
     {
         if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
         {
-            ymm = _mm256_loadu_ps(p_b + i_k * Nr + i_n * 8);
+            ymm = _mm256_loadu_ps(p_b + i_k * ldb + i_n * 8);
         }
         else
         {

@@ -418,7 +441,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
         if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
         {
             ymm = _mm256_cvtph_ps(_mm_loadu_si128(
-                reinterpret_cast<__m128i const*>(p_b + i_k * Nr + i_n * 8)));
+                reinterpret_cast<__m128i const*>(p_b + i_k * ldb + i_n * 8)));
         }
         else
         {

@@ -488,10 +511,10 @@ struct ThreadwiseGemmAvx2_MxN_6x16
             if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value){
                 p_a += 4;
             }else{
-                p_a += Mr * 4;
+                p_a += lda * 4;
             }
             if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value){
-                p_b += Nr * 4;
+                p_b += ldb * 4;
             }else{
                 p_b += 4 * 8;
             }

@@ -525,10 +548,10 @@ struct ThreadwiseGemmAvx2_MxN_6x16
             if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value){
                 p_a += 1;
             }else{
-                p_a += Mr * 1;
+                p_a += lda * 1;
             }
             if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value){
-                p_b += Nr * 1;
+                p_b += ldb * 1;
             }else{
                 p_b += 1 * 8;
             }

@@ -641,12 +664,8 @@ struct ThreadwiseGemmAvx2_MxN_4x24
         "movq (%[m_param]), %%rax\n"   // p_a
         "movq 8(%[m_param]), %%rbx\n"  // p_b
         "movq 24(%[m_param]), %%rsi\n" // Kr
-        ".if m_TransA != 0\n"
         "movq 32(%[m_param]), %%rcx\n" // lda
-        ".endif\n"
-        ".if m_TransB == 0\n"
         "movq 40(%[m_param]), %%rdx\n" // ldb
-        ".endif\n"
         ".macro vbroadcastss_%= r_base, r_stride, i_scale, i_offset, ymm\n"
         ".if \\i_scale != 0\n"

@@ -683,7 +702,11 @@ struct ThreadwiseGemmAvx2_MxN_4x24
         ".macro vbroadcast_a%= i_k, i_m, ymm\n" // A in rax(r8), lda in rcx
         ".if m_ABytes == 4\n"
         ".if m_TransA == 0\n"
-        "vbroadcastss_%= %%rax, 0, 0, ((\\i_m + \\i_k * m_Mr) * m_ABytes), \\ymm\n"
+        ".if (\\i_k == 0) || (\\i_k == 1)\n"
+        "vbroadcastss_%= %%rax, %%rcx, \\i_k, (\\i_m * m_ABytes), \\ymm\n"
+        ".else\n"
+        "vbroadcastss_%= %%r8, %%rcx, (\\i_k-2), (\\i_m * m_ABytes), \\ymm\n"
+        ".endif\n"
         ".else\n"
         ".if (\\i_m == 0) || (\\i_m == 1)\n"
         "vbroadcastss_%= %%rax, %%rcx, \\i_m, (\\i_k * m_ABytes), \\ymm\n"

@@ -693,7 +716,11 @@ struct ThreadwiseGemmAvx2_MxN_4x24
         ".endif\n"
         ".else\n"
         ".if m_TransA == 0\n"
-        "vpbroadcastw_%= %%rax, 0, 0, ((\\i_m + \\i_k * m_Mr) * m_ABytes), %%xmm15\n"
+        ".if (\\i_k == 0) || (\\i_k == 1)\n"
+        "vpbroadcastw_%= %%rax, %%rcx, \\i_k, (\\i_m * m_ABytes), %%xmm15\n"
+        ".else\n"
+        "vpbroadcastw_%= %%r8, %%rcx, (\\i_k-2), (\\i_m * m_ABytes), %%xmm15\n"
+        ".endif\n"
         ".else\n"
         ".if (\\i_m == 0) || (\\i_m == 1)\n"
         "vpbroadcastw_%= %%rax, %%rcx, \\i_m, (\\i_k * m_ABytes), %%xmm15\n"

@@ -710,13 +737,21 @@ struct ThreadwiseGemmAvx2_MxN_4x24
         ".if m_TransB == 0\n"
         "vmovups_%= %%rbx, %%rdx, \\i_n, (\\i_k*m_BBytes*8), \\ymm\n"
         ".else\n"
-        "vmovups_%= %%rbx, 0, 0, ((\\i_k*m_Nr + \\i_n*8)*m_BBytes), \\ymm\n"
+        ".if (\\i_k == 0) || (\\i_k == 1)\n"
+        "vmovups_%= %%rbx, %%rdx, \\i_k, (\\i_n*8*m_BBytes), \\ymm\n"
+        ".else\n"
+        "vmovups_%= %%rdi, %%rdx, (\\i_k-2), (\\i_n*8*m_BBytes), \\ymm\n"
+        ".endif\n"
         ".endif\n"
         ".else\n"
         ".if m_TransB == 0\n"
         "vcvtph2ps_%= %%rbx, %%rdx, \\i_n, (\\i_k*m_BBytes*8), \\ymm\n"
         ".else\n"
-        "vcvtph2ps_%= %%rbx, 0, 0, ((\\i_k*m_Nr + \\i_n*8)*m_BBytes), \\ymm\n"
+        ".if (\\i_k == 0) || (\\i_k == 1)\n"
+        "vcvtph2ps_%= %%rbx, %%rdx, \\i_k, (\\i_n*8*m_BBytes), \\ymm\n"
+        ".else\n"
+        "vcvtph2ps_%= %%rdi, %%rdx, (\\i_k-2), (\\i_n*8*m_BBytes), \\ymm\n"
+        ".endif\n"
         ".endif\n"
         ".endif\n"
         ".endm\n"

@@ -738,6 +773,11 @@ struct ThreadwiseGemmAvx2_MxN_4x24
         ".if m_Mr > 2\n"
         "lea (%%rax, %%rcx, 2), %%r8\n"
         ".endif\n"
+        ".else\n"
+        "lea (%%rax, %%rcx, 2), %%r8\n"
+        ".endif\n"
+        ".if m_TransB != 0\n"
+        "lea (%%rbx, %%rdx, 2), %%rdi\n"
         ".endif\n"
         "cmp $4, %%rsi\n"

@@ -773,10 +813,12 @@ struct ThreadwiseGemmAvx2_MxN_4x24
         " lea 4*m_ABytes(%%rax), %%rax\n"
         ".if m_Mr > 2\n lea 4*m_ABytes(%%r8), %%r8\n .endif\n"
         ".else\n"
-        " lea m_Mr * 4 * m_ABytes(%%rax), %%rax\n"
+        " lea (%%rax, %%rcx, 4), %%rax\n"
+        " lea (%%r8, %%rcx, 4), %%r8\n"
         ".endif\n"
         ".if m_TransB != 0\n"
-        " lea m_Nr * 4 * m_BBytes(%%rbx), %%rbx\n"
+        " lea (%%rbx, %%rdx, 4), %%rbx\n"
+        " lea (%%rdi, %%rdx, 4), %%rdi\n"
         ".else\n"
         " lea 8 * 4 * m_BBytes(%%rbx), %%rbx\n"
         ".endif\n"

@@ -815,10 +857,12 @@ struct ThreadwiseGemmAvx2_MxN_4x24
         " lea m_ABytes(%%rax), %%rax\n"
         ".if m_Mr > 3\n lea m_ABytes(%%r8), %%r8\n .endif\n"
         ".else\n"
-        " lea m_Mr * m_ABytes(%%rax), %%rax\n"
+        " lea (%%rax, %%rcx, 1), %%rax\n"
+        " lea (%%r8, %%rcx, 1), %%r8\n"
         ".endif\n"
         ".if m_TransB != 0\n"
-        " lea m_Nr * m_BBytes(%%rbx), %%rbx\n"
+        " lea (%%rbx, %%rdx, 1), %%rbx\n"
+        " lea (%%rdi, %%rdx, 1), %%rdi\n"
         ".else\n"
         " lea 8*m_BBytes(%%rbx), %%rbx\n"
         ".endif\n"

@@ -937,7 +981,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
             }
             else
             {
-                ymm = _mm256_broadcast_ss(p_a + i_k * Mr + i_m);
+                ymm = _mm256_broadcast_ss(p_a + i_k * lda + i_m);
             }
         }
         else

@@ -952,7 +996,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
             }
             else
             {
-                ymm = _mm256_cvtph_ps(_mm_set1_epi16(*(p_a + i_k * Mr + i_m)));
+                ymm = _mm256_cvtph_ps(_mm_set1_epi16(*(p_a + i_k * lda + i_m)));
             }
         }
     };

@@ -962,7 +1006,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
     {
         if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
         {
-            ymm = _mm256_loadu_ps(p_b + i_k * Nr + i_n * 8);
+            ymm = _mm256_loadu_ps(p_b + i_k * ldb + i_n * 8);
         }
         else
         {

@@ -974,7 +1018,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
         if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
         {
             ymm = _mm256_cvtph_ps(_mm_loadu_si128(
-                reinterpret_cast<__m128i const*>(p_b + i_k * Nr + i_n * 8)));
+                reinterpret_cast<__m128i const*>(p_b + i_k * ldb + i_n * 8)));
         }
         else
         {

@@ -1044,10 +1088,10 @@ struct ThreadwiseGemmAvx2_MxN_4x24
             if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value){
                 p_a += 4;
             }else{
-                p_a += Mr * 4;
+                p_a += lda * 4;
             }
             if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value){
-                p_b += Nr * 4;
+                p_b += ldb * 4;
             }else{
                 p_b += 4 * 8;
             }

@@ -1081,10 +1125,10 @@ struct ThreadwiseGemmAvx2_MxN_4x24
             if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value){
                 p_a += 1;
             }else{
-                p_a += Mr * 1;
+                p_a += lda * 1;
             }
             if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value){
-                p_b += Nr * 1;
+                p_b += ldb * 1;
             }else{
                 p_b += 1 * 8;
             }
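The recurring Mr to lda and Nr to ldb replacements in the hunks above let the micro-kernels walk A and B with caller-supplied leading dimensions instead of assuming packed Mr x K and K x Nr panels; the extra r8/r9/rdi base registers cover rows beyond the three reachable with base + stride * {0,1,2} addressing. In intrinsic form the A-side change amounts to this sketch (simplified from the fallback path in the diff; the helper name is illustrative):

#include <immintrin.h>

// Broadcast A(i_m, i_k) from a column-major panel with leading dimension lda,
// i.e. the new p_a[i_k * lda + i_m] addressing that replaces the packed
// p_a[i_k * Mr + i_m] form.
static inline __m256 broadcast_a(const float* p_a, int i_k, int i_m, int lda)
{
    return _mm256_broadcast_ss(p_a + i_k * lda + i_m);
}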
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp

@@ -1277,6 +1277,138 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8
     intptr_t src_offset;
 };

+template <typename SrcData,
+          typename DstData,
+          typename SrcDesc,
+          typename DstDesc,
+          typename ElementwiseOperation,
+          bool BypassTransfer,
+          ConvolutionForwardSpecialization_t ConvForwardSpecialization,
+          ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
+struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_YXCK
+{
+    static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
+
+    using Index = MultiIndex<nDim>;
+
+    constexpr ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_YXCK(
+        const SrcDesc& src_desc,
+        const Index&,
+        const DstDesc&,
+        const Index&,
+        const ElementwiseOperation& element_op)
+        : element_op_(element_op)
+    {
+        GemmK = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
+        GemmN = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
+    }
+
+    void SetSrcSliceOrigin(const SrcDesc&, const Index& src_slice_origin_idx)
+    {
+        ck::index_t idx_k = src_slice_origin_idx[Number<0>{}];
+        ck::index_t idx_n = src_slice_origin_idx[Number<1>{}];
+        src_offset        = idx_k * GemmN + idx_n;
+    }
+
+    void SetDstSliceOrigin(const DstDesc&, const Index&) {}
+
+    template <typename SrcBuffer, typename DstBuffer, typename SliceLengths>
+    void RunRead(const SrcDesc&,
+                 SrcBuffer& src_buf,
+                 const DstDesc& dst_desc,
+                 DstBuffer& dst_buf,
+                 const SliceLengths& slice_length)
+    {
+        if constexpr(BypassTransfer)
+        {
+            dst_buf.p_data_ = reinterpret_cast<float*>(src_buf.p_data_) + src_offset;
+        }
+        else
+        {
+            const ck::index_t k_per_block = slice_length[Number<0>{}];
+            const ck::index_t n_per_block = slice_length[Number<1>{}];
+
+            const float* p_src = reinterpret_cast<const float*>(src_buf.p_data_) + src_offset;
+            float* p_dst       = reinterpret_cast<float*>(dst_buf.p_data_);
+
+            // k * n
+            index_t i_k_itr = k_per_block;
+            while(i_k_itr >= 8)
+            {
+                avx2_util::memcpy32_avx2(p_dst + 0 * n_per_block, p_src + 0 * GemmN, n_per_block, element_op_);
+                avx2_util::memcpy32_avx2(p_dst + 1 * n_per_block, p_src + 1 * GemmN, n_per_block, element_op_);
+                avx2_util::memcpy32_avx2(p_dst + 2 * n_per_block, p_src + 2 * GemmN, n_per_block, element_op_);
+                avx2_util::memcpy32_avx2(p_dst + 3 * n_per_block, p_src + 3 * GemmN, n_per_block, element_op_);
+                avx2_util::memcpy32_avx2(p_dst + 4 * n_per_block, p_src + 4 * GemmN, n_per_block, element_op_);
+                avx2_util::memcpy32_avx2(p_dst + 5 * n_per_block, p_src + 5 * GemmN, n_per_block, element_op_);
+                avx2_util::memcpy32_avx2(p_dst + 6 * n_per_block, p_src + 6 * GemmN, n_per_block, element_op_);
+                avx2_util::memcpy32_avx2(p_dst + 7 * n_per_block, p_src + 7 * GemmN, n_per_block, element_op_);
+                i_k_itr -= 8;
+                p_dst += 8 * n_per_block;
+                p_src += 8 * GemmN;
+            }
+            if(i_k_itr & 4)
+            {
+                avx2_util::memcpy32_avx2(p_dst + 0 * n_per_block, p_src + 0 * GemmN, n_per_block, element_op_);
+                avx2_util::memcpy32_avx2(p_dst + 1 * n_per_block, p_src + 1 * GemmN, n_per_block, element_op_);
+                avx2_util::memcpy32_avx2(p_dst + 2 * n_per_block, p_src + 2 * GemmN, n_per_block, element_op_);
+                avx2_util::memcpy32_avx2(p_dst + 3 * n_per_block, p_src + 3 * GemmN, n_per_block, element_op_);
+                p_dst += 4 * n_per_block;
+                p_src += 4 * GemmN;
+            }
+            if(i_k_itr & 2)
+            {
+                avx2_util::memcpy32_avx2(p_dst + 0 * n_per_block, p_src + 0 * GemmN, n_per_block, element_op_);
+                avx2_util::memcpy32_avx2(p_dst + 1 * n_per_block, p_src + 1 * GemmN, n_per_block, element_op_);
+                p_dst += 2 * n_per_block;
+                p_src += 2 * GemmN;
+            }
+            if(i_k_itr & 1)
+            {
+                avx2_util::memcpy32_avx2(p_dst + 0 * n_per_block, p_src + 0 * GemmN, n_per_block, element_op_);
+            }
+        }
+    }
+
+    // src_slice_origin_step_idx needs to be known at compile-time, for performance reasons
+    void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& src_slice_origin_step_idx)
+    {
+        ck::index_t move_k = src_slice_origin_step_idx[Number<0>{}];
+        ck::index_t move_n = src_slice_origin_step_idx[Number<1>{}];
+        src_offset += move_k * GemmN + move_n;
+    }
+
+    // dst_slice_origin_step_idx needs to be known at compile-time, for performance reasons
+    void MoveDstSliceWindow(const DstDesc&, const Index&) {}
+
+    private:
+    const ElementwiseOperation element_op_;
+    ck::index_t GemmN;
+    ck::index_t GemmK;
+    intptr_t src_offset;
+};
+
 template <typename SrcData,
           typename DstData,
           typename SrcDesc,
...
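The new YXCK weight copy above treats the Y*X*C-by-K weight as a row-major [GemmK, GemmN] matrix, so each read is a set of strided row copies into a packed tile. A self-contained sketch of the same loop structure (plain memcpy stands in for avx2_util::memcpy32_avx2 and its fused element-wise op; the function name is illustrative):

#include <cstring>

// Copy k_per_block rows of n_per_block floats from a row-major [GemmK, GemmN]
// source into a packed [k_per_block, n_per_block] tile, unrolled by 8/4/2/1
// rows like RunRead above.
void copy_weight_tile(float* dst, const float* src,
                      int k_per_block, int n_per_block, int GemmN)
{
    int k = k_per_block;
    while(k >= 8)
    {
        for(int i = 0; i < 8; ++i)
            std::memcpy(dst + i * n_per_block, src + i * GemmN, n_per_block * sizeof(float));
        dst += 8 * n_per_block;
        src += 8 * GemmN;
        k -= 8;
    }
    for(int step = 4; step >= 1; step /= 2)
    {
        if(k & step)
        {
            for(int i = 0; i < step; ++i)
                std::memcpy(dst + i * n_per_block, src + i * GemmN, n_per_block * sizeof(float));
            dst += step * n_per_block;
            src += step * GemmN;
        }
    }
}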
library/src/tensor_operation_instance/cpu/conv2d_fwd/CMakeLists.txt

@@ -2,6 +2,7 @@
 set(DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE
     device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp
     device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_instance.cpp
+    device_conv2d_fwd_avx2_nhwc_yxck_nhwk_instance.cpp
 )
 add_library(device_conv2d_fwd_cpu_instance SHARED ${DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE})
 target_compile_features(device_conv2d_fwd_cpu_instance PUBLIC)
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_yxck_nhwk_instance.cpp
0 → 100644

#include <stdlib.h>
#include "config.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "device_convnd_fwd_avx2_nhwc_yxck_nhwk.hpp"
#include "element_wise_operation_cpu.hpp"
#include "device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
namespace device_conv2d_fwd_avx2_instance {

using InType  = float;
using WeiType = float;
using OutType = float;
using AccType = float;

static constexpr bool NonTemporalStore = false;

using PT   = ck::tensor_operation::cpu::element_wise::PassThrough;
using Relu = ck::tensor_operation::cpu::element_wise::Relu;

static constexpr auto ConvFwdDefault =
    ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Default;

static constexpr auto ConvFwd1x1P0 =
    ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0;

static constexpr auto ConvFwd1x1S1P0 =
    ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;

static constexpr auto DefaultGemmKLoop =
    ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop;

static constexpr auto GemmKLoopOverC =
    ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC;

static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK;
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
// clang-format off
#define DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf) \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC,   LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true,  true,  c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC,   LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true,  true,  c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true,  true,  c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC,   LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true,  false, c_local_buf>, \
    \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC,   LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true,  true,  c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC,   LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true,  true,  c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true,  true,  c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC,   LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true,  false, c_local_buf>
// clang-format on
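// Each use of the macro above contributes ten device-op variants to a tuple: five
// {filter-specialization, gemm-k-specialization, A/B local-buffer} combinations,
// each instantiated under both the LoopOver_MNK and LoopOver_MKN block orders,
// all sharing one block/thread tile configuration.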
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  256, 128,  64, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  256, 128, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  128, 256, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  512, 240, 128, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  512, 256, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  768, 320, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  896, 352, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, false)>;
// clang-format on
// use these in a single-threaded run when gemm_n is not a multiple of 8
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  256, 128,  64, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  256, 128, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  128, 256, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  512, 240, 128, 4, 24, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  512, 256, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  768, 320, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  896, 352, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true)>;
// clang-format on
// use these in a multi-thread environment (a local C buffer is needed to avoid
// cache-coherence traffic between threads, although sometimes no local C is faster...)
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_instances = std::tuple<
    // clang-format off
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,   24,  24, 256, 4, 24, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,   32,  24, 256, 4, 24, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,   40,  24, 256, 4, 24, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,   48,  24, 256, 4, 24, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,   48,  48, 256, 4, 24, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,   56,  24, 256, 4, 24, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,   72,  16, 128, 6, 16, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,   72,  16, 256, 6, 16, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,   72,  32, 128, 6, 16, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,   72,  32, 256, 6, 16, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,   96,  32, 128, 6, 16, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,   96,  64, 128, 6, 16, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  120,  32, 128, 6, 16, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  120,  64, 128, 6, 16, false),
        // DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  256, 128, 128, 6, 16, true),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  128, 256, 128, 6, 16, true),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  512, 240, 128, 4, 24, true),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  512, 256, 128, 6, 16, true),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  768, 320, 128, 6, 16, true),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  896, 352, 128, 6, 16, true),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true)
    >;
// clang-format on
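The cache-coherence remark above is about false sharing: when several threads accumulate into C tiles that share a cache line, that line ping-pongs between cores, and a thread-local C buffer keeps the hot accumulator private until the final store. A minimal standalone illustration of the effect (plain C++ with std::thread, not CK code; the struct names and iteration count are illustrative only):

#include <chrono>
#include <cstdio>
#include <thread>

// Two counters in the same cache line vs. padded onto separate lines.
struct Shared { volatile float a; volatile float b; };
struct Padded { volatile float a; char pad[64]; volatile float b; };

template <typename S>
double run()
{
    S s{};
    auto work_a = [&] { for(int i = 0; i < 10000000; ++i) s.a = s.a + 1.0f; };
    auto work_b = [&] { for(int i = 0; i < 10000000; ++i) s.b = s.b + 1.0f; };
    auto t0 = std::chrono::steady_clock::now();
    std::thread ta(work_a), tb(work_b);
    ta.join();
    tb.join();
    auto t1 = std::chrono::steady_clock::now();
    return std::chrono::duration<double>(t1 - t0).count();
}

int main()
{
    // The padded variant is typically several times faster: identical
    // arithmetic, but no cache line is shared between the two writers.
    std::printf("shared line: %.3fs\n", run<Shared>());
    std::printf("padded     : %.3fs\n", run<Padded>());
}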
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_relu_instances = std::tuple<
    // clang-format off
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  256, 128,  64, 6, 16, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  256, 128, 128, 6, 16, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  128, 256, 128, 6, 16, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  512, 240, 128, 4, 24, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  512, 256, 128, 6, 16, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  768, 320, 128, 6, 16, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  896, 352, 128, 6, 16, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, false)
    >;
// clang-format on
// use these in a single-thread run where gemm_n is not a multiple of 8
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_relu_instances = std::tuple<
    // clang-format off
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  256, 128,  64, 6, 16, true),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  256, 128, 128, 6, 16, true),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  128, 256, 128, 6, 16, true),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  512, 240, 128, 4, 24, true),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  512, 256, 128, 6, 16, true),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  768, 320, 128, 6, 16, true),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  896, 352, 128, 6, 16, true),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)
    >;
// clang-format on
// use these in a multi-thread environment (a local C buffer is needed to avoid
// cache-coherence traffic between threads, although sometimes no local C is faster...)
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_relu_instances = std::tuple<
    // clang-format off
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,   24,  24, 256, 4, 24, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,   32,  24, 256, 4, 24, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,   40,  24, 256, 4, 24, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,   24,  24, 256, 4, 24, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,   32,  24, 256, 4, 24, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,   40,  24, 256, 4, 24, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,   48,  24, 256, 4, 24, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,   48,  48, 256, 4, 24, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,   56,  24, 256, 4, 24, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,   72,  16, 128, 6, 16, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,   72,  16, 256, 6, 16, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,   72,  32, 128, 6, 16, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,   72,  32, 256, 6, 16, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,   96,  32, 128, 6, 16, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,   96,  64, 128, 6, 16, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  120,  32, 128, 6, 16, false),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  120,  64, 128, 6, 16, false),
        // DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  256, 128, 128, 6, 16, true),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  128, 256, 128, 6, 16, true),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  512, 240, 128, 4, 24, true),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  512, 256, 128, 6, 16, true),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  768, 320, 128, 6, 16, true),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  896, 352, 128, 6, 16, true),
        DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)
    >;
// clang-format on
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk(
    std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
    ck::tensor_operation::device::add_device_operation_instances(
        instances, device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_instances{});
}

void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c(
    std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
    ck::tensor_operation::device::add_device_operation_instances(
        instances, device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_instances{});
}

void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt(
    std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
    ck::tensor_operation::device::add_device_operation_instances(
        instances, device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_instances{});
}

void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_relu(
    std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
    ck::tensor_operation::device::add_device_operation_instances(
        instances, device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_relu_instances{});
}

void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c_relu(
    std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
    ck::tensor_operation::device::add_device_operation_instances(
        instances, device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_relu_instances{});
}

void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt_relu(
    std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
    ck::tensor_operation::device::add_device_operation_instances(
        instances, device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_relu_instances{});
}

} // namespace device_conv2d_fwd_avx2_instance
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
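Taken together, the registration functions above follow the usual composable_kernel instance-library pattern: a client collects type-erased instance pointers and probes each one for support. A minimal client-side sketch, assuming the CPU DeviceConvFwd base exposes GetTypeString() the way the GPU one does and that DeviceConvFwdPtr lives in ck::tensor_operation::cpu::device (both assumptions, not confirmed by this commit):

// Hypothetical enumeration sketch; only the add_device_* function names and the
// PassThrough alias come from this commit.
#include <iostream>
#include <vector>

using PT = ck::tensor_operation::cpu::element_wise::PassThrough;
using ck::tensor_operation::cpu::device::DeviceConvFwdPtr; // namespace placement assumed
namespace inst = ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance;

int main()
{
    std::vector<DeviceConvFwdPtr<PT, PT, PT>> instances;
    inst::add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk(instances);    // default list
    inst::add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt(instances); // multi-thread tuned list

    for(const auto& p : instances)
        std::cout << p->GetTypeString() << "\n"; // assumed API
}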
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/CMakeLists.txt
View file @ f9cf57d4
@@ -2,6 +2,7 @@
 set(DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE
     device_conv2d_bias_activation_add_avx2_nhwc_kyxc_nhwk_instance.cpp
     device_conv2d_bias_activation_add_avx2_nhwc_kyxck8_nhwk_instance.cpp
+    device_conv2d_bias_activation_add_avx2_nhwc_yxck_nhwk_instance.cpp
 )
 add_library(device_conv2d_fwd_bias_activation_add_cpu_instance SHARED ${DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE})
 target_compile_features(device_conv2d_fwd_bias_activation_add_cpu_instance PUBLIC)
...
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_bias_activation_add_avx2_nhwc_yxck_nhwk_instance.cpp
0 → 100644
View file @ f9cf57d4

#include <stdlib.h>
#include "config.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "device_convnd_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk.hpp"
#include "element_wise_operation_cpu.hpp"
#include "device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
namespace device_conv2d_fwd_bias_activation_add_avx2_instance {

using InType  = float;
using WeiType = float;
using OutType = float;
using AccType = float;

static constexpr bool NonTemporalStore = false;

using PT         = ck::tensor_operation::cpu::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::cpu::element_wise::AddReluAdd;

static constexpr auto ConvFwdDefault =
    ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Default;
static constexpr auto ConvFwd1x1P0 =
    ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 =
    ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;
static constexpr auto DefaultGemmKLoop =
    ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop;
static constexpr auto GemmKLoopOverC =
    ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC;
static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK;
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
// clang-format off
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float, float, float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC,   LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true,  true,  c_local_buf, bias_along_m>, \
    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float, float, float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC,   LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true,  true,  c_local_buf, bias_along_m>, \
    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float, float, float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true,  true,  c_local_buf, bias_along_m>, \
    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float, float, float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC,   LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>, \
    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float, float, float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true,  false, c_local_buf, bias_along_m>, \
    \
    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float, float, float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC,   LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true,  true,  c_local_buf, bias_along_m>, \
    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float, float, float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC,   LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true,  true,  c_local_buf, bias_along_m>, \
    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float, float, float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true,  true,  c_local_buf, bias_along_m>, \
    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float, float, float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC,   LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>, \
    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float, float, float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true,  false, c_local_buf, bias_along_m>
// clang-format on
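Each use of this macro therefore stamps out ten instantiations: five gemm-k/specialization combinations under LoopOver_MNK and the same five under LoopOver_MKN. For orientation, the first element produced by DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, false, false) expands as annotated below (the roles of the five float parameters are inferred from the InType/WeiType/OutType/AccType aliases above, and the two unnamed bool flags are left uninterpreted):

// Annotated expansion of the first instantiation; comments are inferred, not
// taken from the template's definition.
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<
    float, float, float, float, float, // in / wei / out / acc (+ one more) element types, roles assumed
    PT, PT, AddReluAdd,                // A / B / C element-wise ops
    ConvFwdDefault,                    // convolution specialization
    GemmKLoopOverC,                    // gemm-k loop specialization
    LoopOver_MNK,                      // block loop order
    2,                                 // number of spatial dims (2-D conv)
    256, 128, 64,                      // m/n/k per block
    6, 16,                             // m/n per thread (6x16 AVX2 ukernel)
    true, true,                        // two unnamed bool flags from the macro body
    false,                             // c_local_buf
    false>                             // bias_along_m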
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_instances = std::tuple<
    // clang-format off
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  256, 128,  64, 6, 16, false, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  256, 128, 128, 6, 16, false, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  128, 256, 128, 6, 16, false, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  512, 240, 128, 4, 24, false, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  512, 256, 128, 6, 16, false, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  768, 320, 128, 6, 16, false, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  896, 352, 128, 6, 16, false, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, false, false)
    >;
// clang-format on
// use these in a single-thread run where gemm_n is not a multiple of 8
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_local_c_instances = std::tuple<
    // clang-format off
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  256, 128,  64, 6, 16, true, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  256, 128, 128, 6, 16, true, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  128, 256, 128, 6, 16, true, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  512, 240, 128, 4, 24, true, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  512, 256, 128, 6, 16, true, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  768, 320, 128, 6, 16, true, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  896, 352, 128, 6, 16, true, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)
    >;
// clang-format on
// use these in a multi-thread environment (a local C buffer is needed to avoid
// cache-coherence traffic between threads, although sometimes no local C is faster...)
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_mt_instances = std::tuple<
    // clang-format off
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,   24,  24, 256, 4, 24, false, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,   32,  24, 256, 4, 24, false, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,   40,  24, 256, 4, 24, false, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,   48,  24, 256, 4, 24, false, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,   48,  48, 256, 4, 24, false, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,   56,  24, 256, 4, 24, false, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,   72,  16, 128, 6, 16, false, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,   72,  16, 256, 6, 16, false, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,   72,  32, 128, 6, 16, false, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,   72,  32, 256, 6, 16, false, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,   96,  32, 128, 6, 16, false, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,   96,  64, 128, 6, 16, false, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  120,  32, 128, 6, 16, false, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  120,  64, 128, 6, 16, false, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  256, 128, 128, 6, 16, true, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  128, 256, 128, 6, 16, true, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  512, 240, 128, 4, 24, true, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  512, 256, 128, 6, 16, true, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  768, 320, 128, 6, 16, true, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  896, 352, 128, 6, 16, true, false),
        DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)
    >;
// clang-format on
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk(
    std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
    ck::tensor_operation::device::add_device_operation_instances(
        instances, device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_instances{});
}

void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c(
    std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
    ck::tensor_operation::device::add_device_operation_instances(
        instances,
        device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_local_c_instances{});
}

void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_mt(
    std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
    ck::tensor_operation::device::add_device_operation_instances(
        instances, device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_mt_instances{});
}

} // namespace device_conv2d_fwd_bias_activation_add_avx2_instance
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
test/cpu_ukernel/cpu_gemm_uk.cpp
View file @ f9cf57d4
...
@@ -233,68 +233,30 @@ void test_ukernel(ukenrel_t uk,
     int max_threads = omp_get_max_threads();

     auto invoke_uk = [&](ck::cpu::ThreadwiseGemmParam& param, float* current_mat_c) {
-        if constexpr(std::is_same<Row, ALayout>::value && std::is_same<Row, BLayout>::value)
-        {
-            assert(m % uk.ThreadMr == 0 && n == uk.ThreadNr);
-            FloatA* p_a = mat_a;
-            float* p_c  = current_mat_c;
-            param.p_a   = p_a;
-            param.p_c   = p_c;
-            for(uint32_t i_m = 0; i_m < m; i_m += uk.ThreadMr)
-            {
-                uk.Run(&param);
-                p_a += uk.ThreadMr * k;
-                p_c += uk.ThreadMr * n;
-                param.p_a = p_a;
-                param.p_c = p_c;
-            }
-        }
-        else if constexpr(std::is_same<Row, ALayout>::value && std::is_same<Col, BLayout>::value)
-        {
-            assert(m % uk.ThreadMr == 0 && n % uk.ThreadNr == 0);
-            FloatA* p_a = mat_a;
-            float* p_c  = current_mat_c;
-            param.p_a   = p_a;
-            param.p_b   = mat_b;
-            param.p_c   = p_c;
-            for(uint32_t i_m = 0; i_m < m; i_m += uk.ThreadMr)
-            {
-                float* p_c_n  = p_c;
-                FloatB* p_b_n = mat_b;
-                for(uint32_t i_n = 0; i_n < n; i_n += uk.ThreadNr)
-                {
-                    uk.Run(&param);
-                    p_b_n += uk.ThreadNr * k; // ThreadNr/8*k*8
-                    p_c_n += uk.ThreadNr;
-                    param.p_b = p_b_n;
-                    param.p_c = p_c_n;
-                }
-                p_a += uk.ThreadMr * k;
-                p_c += uk.ThreadMr * n;
-                param.p_a = p_a;
-                param.p_b = mat_b;
-                param.p_c = p_c;
-            }
-        }
-        else if constexpr(std::is_same<Col, ALayout>::value && std::is_same<Row, BLayout>::value)
-        {
-            assert(m == uk.ThreadMr && n == uk.ThreadNr);
-            uk.Run(&param);
-        }
-        else
-        {
-            assert(m % uk.ThreadMr == 0 && n % uk.ThreadNr == 0);
-            FloatB* p_b = mat_b;
-            float* p_c  = current_mat_c;
-            param.p_b   = p_b;
-            param.p_c   = p_c;
-            for(uint32_t i_n = 0; i_n < n; i_n += uk.ThreadNr)
-            {
-                uk.Run(&param);
-                p_b += uk.ThreadNr * k; // ThreadNr/8*k*8
-                p_c += uk.ThreadNr;
-                param.p_b = p_b;
-                param.p_c = p_c;
-            }
-        }
+        assert(m % uk.ThreadMr == 0 && n % uk.ThreadNr == 0);
+        for(uint32_t i_m = 0; i_m < m; i_m += uk.ThreadMr)
+        {
+            if constexpr(std::is_same<Row, ALayout>::value)
+                param.p_a = mat_a + i_m * k;
+            else
+                param.p_a = mat_a + i_m;
+            for(uint32_t i_n = 0; i_n < n; i_n += uk.ThreadNr)
+            {
+                if constexpr(std::is_same<Row, BLayout>::value)
+                    param.p_b = mat_b + i_n;
+                else
+                    param.p_b = mat_b + i_n * k;
+                param.p_c = current_mat_c + i_m * n + i_n;
+                uk.Run(&param);
+            }
+        }
     };
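The rewritten lambda collapses the four layout-specific branches into one tiled loop; the only layout-dependent part left is the base-address arithmetic for each tile. A tiny standalone check of that arithmetic, mirroring the `i_m * k` versus `i_m` distinction above (plain C++, not repo code):

#include <cassert>
#include <cstdint>

// Offset of element (i_m, i_k) of an M x K matrix A.
// Row-major (Row layout): rows are contiguous, stride k between rows.
// Col-major (Col layout): columns are contiguous, stride m between columns.
uint32_t offset_row_major(uint32_t i_m, uint32_t i_k, uint32_t k) { return i_m * k + i_k; }
uint32_t offset_col_major(uint32_t i_m, uint32_t i_k, uint32_t m) { return i_k * m + i_m; }

int main()
{
    const uint32_t m = 8, k = 4;
    // Tile base at row i_m = 6, column 0: matches `mat_a + i_m * k` (Row)
    // versus `mat_a + i_m` (Col) in the lambda above.
    assert(offset_row_major(6, 0, k) == 6 * k);
    assert(offset_col_major(6, 0, m) == 6);
}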
...
@@ -358,7 +320,11 @@ void test_ukernel(ukenrel_t uk,
 }

 // implement small ukernel on L1
-template <typename FloatA, typename FloatB, typename ALayout, typename BLayout>
+template <typename FloatA,
+          typename FloatB,
+          typename ALayout,
+          typename BLayout,
+          typename thread_gemm_instance>
 void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k)
 {
     int max_threads = omp_get_max_threads();
...
@@ -382,17 +348,18 @@ void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k)
         k);

     // using thread_gemm_instance = thread_gemm_avx2_mxn_6x16_instances<ALayout, BLayout>;
-    using thread_gemm_instance = thread_gemm_avx2_mxn_4x24_instances<ALayout, BLayout>;
+    // using thread_gemm_instance = thread_gemm_avx2_mxn_4x24_instances<ALayout, BLayout>;

     bool found = false;
     ck::static_for<0, std::tuple_size_v<thread_gemm_instance>, 1>{}([&](auto i) {
         using uk_type = std::tuple_element_t<i, thread_gemm_instance>;
         if(m % uk_type::ThreadMr != 0 || n % uk_type::ThreadNr != 0)
             return;
-        if((m != uk_type::ThreadMr && std::is_same<typename uk_type::MatrixALayout, Col>::value) ||
-           (n != uk_type::ThreadNr && std::is_same<typename uk_type::MatrixBLayout, Row>::value))
-            // only if k is the fast-changing dim of A/B can we do multiple m, n
-            return;
+        // if((m != uk_type::ThreadMr && std::is_same<typename uk_type::MatrixALayout, Col>::value) ||
+        //    (n != uk_type::ThreadNr && std::is_same<typename uk_type::MatrixBLayout, Row>::value))
+        //     // only if k is the fast-changing dim of A/B can we do multiple m, n
+        //     return;
         if(found)
             return;
...
@@ -435,8 +402,21 @@ int main(int argc, char** argv)
     omp_set_num_threads(1);
     printf("max threads:%d\n", omp_get_max_threads());

-    test_cpu_ukernel<AType, BType, Row, Row>(alpha, m, n, k);
-    test_cpu_ukernel<AType, BType, Row, Col>(alpha, m, n, k);
-    test_cpu_ukernel<AType, BType, Col, Row>(alpha, m, n, k);
-    test_cpu_ukernel<AType, BType, Col, Col>(alpha, m, n, k);
+    test_cpu_ukernel<AType, BType, Row, Row, thread_gemm_avx2_mxn_4x24_instances<Row, Row>>(alpha, m, n, k);
+    test_cpu_ukernel<AType, BType, Row, Col, thread_gemm_avx2_mxn_4x24_instances<Row, Col>>(alpha, m, n, k);
+    test_cpu_ukernel<AType, BType, Col, Row, thread_gemm_avx2_mxn_4x24_instances<Col, Row>>(alpha, m, n, k);
+    test_cpu_ukernel<AType, BType, Col, Col, thread_gemm_avx2_mxn_4x24_instances<Col, Col>>(alpha, m, n, k);
+    test_cpu_ukernel<AType, BType, Row, Row, thread_gemm_avx2_mxn_6x16_instances<Row, Row>>(alpha, m, n, k);
+    test_cpu_ukernel<AType, BType, Row, Col, thread_gemm_avx2_mxn_6x16_instances<Row, Col>>(alpha, m, n, k);
+    test_cpu_ukernel<AType, BType, Col, Row, thread_gemm_avx2_mxn_6x16_instances<Col, Row>>(alpha, m, n, k);
+    test_cpu_ukernel<AType, BType, Col, Col, thread_gemm_avx2_mxn_6x16_instances<Col, Col>>(alpha, m, n, k);
 }
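The new fifth template parameter lets main() drive one test body over several micro-kernel tuples, enumerated at compile time with ck::static_for, std::tuple_size_v, and std::tuple_element_t. That enumeration trick is plain C++17 and can be reproduced without CK; a minimal self-contained version (the Uk4x24/Uk6x16 stand-ins and for_each_type helper are illustrative, not repo code):

#include <cstdio>
#include <tuple>
#include <utility>

// Stand-ins for two ukernel types; only the tile-shape constants matter here.
struct Uk4x24 { static constexpr int ThreadMr = 4, ThreadNr = 24; };
struct Uk6x16 { static constexpr int ThreadMr = 6, ThreadNr = 16; };

// Compile-time for-each over a tuple of types, same idea as ck::static_for.
template <typename Tuple, typename F, std::size_t... I>
void for_each_type_impl(F&& f, std::index_sequence<I...>)
{
    (f(std::integral_constant<std::size_t, I>{}), ...);
}
template <typename Tuple, typename F>
void for_each_type(F&& f)
{
    for_each_type_impl<Tuple>(std::forward<F>(f),
                              std::make_index_sequence<std::tuple_size_v<Tuple>>{});
}

int main()
{
    const int m = 48, n = 48;
    using instances = std::tuple<Uk4x24, Uk6x16>;
    for_each_type<instances>([&](auto i) {
        using uk_type = std::tuple_element_t<i, instances>;
        if(m % uk_type::ThreadMr != 0 || n % uk_type::ThreadNr != 0)
            return; // skip ukernels whose tile does not divide the problem
        std::printf("usable ukernel: %dx%d\n", uk_type::ThreadMr, uk_type::ThreadNr);
    });
}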