Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f9cf57d4
Commit
f9cf57d4
authored
Jun 07, 2022
by
carlushuang
Browse files
support YXCK filter
parent
71254ddd
Changes
12
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
3299 additions
and
677 deletions
+3299
-677
example/cpu_01_conv2d_fwd/cpu_conv2d_fwd.cpp
example/cpu_01_conv2d_fwd/cpu_conv2d_fwd.cpp
+658
-565
example/cpu_02_conv2d_fwd_bias_relu_add/cpu_conv2d_fwd_bias_relu_add.cpp
...conv2d_fwd_bias_relu_add/cpu_conv2d_fwd_bias_relu_add.cpp
+74
-1
include/ck/tensor_operation/cpu/block/blockwise_gemm_avx2.hpp
...ude/ck/tensor_operation/cpu/block/blockwise_gemm_avx2.hpp
+10
-4
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_yxck_nhwk.hpp
...tion/cpu/device/device_convnd_fwd_avx2_nhwc_yxck_nhwk.hpp
+927
-0
include/ck/tensor_operation/cpu/device/device_convnd_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk.hpp
...ce_convnd_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk.hpp
+992
-0
include/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp
...e/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp
+86
-42
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp
.../threadwise_tensor_slice_transfer_avx2_specialization.hpp
+132
-0
library/src/tensor_operation_instance/cpu/conv2d_fwd/CMakeLists.txt
...c/tensor_operation_instance/cpu/conv2d_fwd/CMakeLists.txt
+1
-0
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_yxck_nhwk_instance.cpp
...2d_fwd/device_conv2d_fwd_avx2_nhwc_yxck_nhwk_instance.cpp
+229
-0
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/CMakeLists.txt
...nstance/cpu/conv2d_fwd_bias_activation_add/CMakeLists.txt
+1
-0
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_bias_activation_add_avx2_nhwc_yxck_nhwk_instance.cpp
...nv2d_bias_activation_add_avx2_nhwc_yxck_nhwk_instance.cpp
+144
-0
test/cpu_ukernel/cpu_gemm_uk.cpp
test/cpu_ukernel/cpu_gemm_uk.cpp
+45
-65
No files found.
example/cpu_01_conv2d_fwd/cpu_conv2d_fwd.cpp
View file @
f9cf57d4
This diff is collapsed.
Click to expand it.
example/cpu_02_conv2d_fwd_bias_relu_add/cpu_conv2d_fwd_bias_relu_add.cpp
View file @
f9cf57d4
...
@@ -16,7 +16,8 @@
...
@@ -16,7 +16,8 @@
#define TEST_LAYOUT_NHWC_KYXC_NHWK 0
#define TEST_LAYOUT_NHWC_KYXC_NHWK 0
#define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1
#define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1
#define TEST_LAYOUT TEST_LAYOUT_NHWC_KYXCK8_NHWK
#define TEST_LAYOUT_NHWC_YXCK_NHWK 1
#define TEST_LAYOUT TEST_LAYOUT_NHWC_KYXC_NHWK
using
F32
=
float
;
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
...
@@ -30,6 +31,7 @@ namespace device_conv2d_fwd_bias_activation_add_avx2_instance {
...
@@ -30,6 +31,7 @@ namespace device_conv2d_fwd_bias_activation_add_avx2_instance {
using
PassThrough
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
using
AddReluAdd
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
AddReluAdd
;
using
AddReluAdd
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
AddReluAdd
;
// ------------------ nhwc-kyxc-nhwk
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk
(
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReluAdd
>>&
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReluAdd
>>&
instances
);
instances
);
...
@@ -42,6 +44,7 @@ void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_mt(
...
@@ -42,6 +44,7 @@ void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_mt(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReluAdd
>>&
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReluAdd
>>&
instances
);
instances
);
// ------------------ nhwc-kcyxk8-nhwk
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk
(
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReluAdd
>>&
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReluAdd
>>&
instances
);
instances
);
...
@@ -54,6 +57,19 @@ void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_mt(
...
@@ -54,6 +57,19 @@ void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_mt(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReluAdd
>>&
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReluAdd
>>&
instances
);
instances
);
// ------------------ nhwc-yxck-nhwk
// Declarations of the instance-registration entry points for the
// NHWC input / YXCK weight / NHWK output layout. Each one receives the
// caller's vector of DeviceConvFwdBiasActivationAddPtr<PassThrough,
// PassThrough, AddReluAdd> by non-const reference (presumably to append
// kernel instances to it -- the definitions are not in this file).
//
// Default variant: main() calls it together with the _mt variant when
// omp_get_max_threads() > 1, and on its own in the single-thread case
// when K % 8 == 0.
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReluAdd
>>&
instances
);
// "local_c" variant: main() selects it in the single-thread case when
// K % 8 != 0.
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReluAdd
>>&
instances
);
// "mt" variant: main() calls it only when omp_get_max_threads() > 1.
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_mt
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PassThrough
,
PassThrough
,
AddReluAdd
>>&
instances
);
}
// namespace device_conv2d_fwd_bias_activation_add_avx2_instance
}
// namespace device_conv2d_fwd_bias_activation_add_avx2_instance
}
// namespace device
}
// namespace device
}
// namespace cpu
}
// namespace cpu
...
@@ -141,6 +157,31 @@ void transpose_kyxc_2_kyxc8k(Tensor<T>& dst,
...
@@ -141,6 +157,31 @@ void transpose_kyxc_2_kyxc8k(Tensor<T>& dst,
}
}
}
}
// Re-lays out a weight tensor for the YXCK filter format.
//
// The source buffer is read as a row-major matrix with K rows and
// C * Y * X columns; the destination buffer is written as its transpose
// (C * Y * X rows, K columns). Only the flat `mData` storage of both
// tensors is touched; the tensors' own descriptors are not consulted.
//
// dst : destination tensor; must hold at least K * C * Y * X elements.
// src : source tensor; must hold at least K * C * Y * X elements.
// K   : number of rows of the source matrix.
// Y, X, C : their product is the number of columns of the source matrix.
template <typename T>
void transpose_kyxc_2_yxck(Tensor<T>& dst,
                           const Tensor<T>& src,
                           ck::index_t K,
                           ck::index_t Y,
                           ck::index_t X,
                           ck::index_t C)
{
    const ck::index_t num_rows = K;
    const ck::index_t num_cols = C * Y * X;

    // Walk the source in storage order and scatter into the transposed slot.
    for(ck::index_t r = 0; r < num_rows; ++r)
    {
        const ck::index_t src_row_base = r * num_cols;

        for(ck::index_t c = 0; c < num_cols; ++c)
        {
            dst.mData[c * num_rows + r] = src.mData[src_row_base + c];
        }
    }
}
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
int
data_type
=
0
;
int
data_type
=
0
;
...
@@ -243,6 +284,10 @@ int main(int argc, char* argv[])
...
@@ -243,6 +284,10 @@ int main(int argc, char* argv[])
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
Tensor
<
WeiDataType
>
wei_k_c_y_x_k8
(
Tensor
<
WeiDataType
>
wei_k_c_y_x_k8
(
f_host_tensor_descriptor
(
K
,
C
,
Y
,
X
));
// TODO: This is only to hold data
f_host_tensor_descriptor
(
K
,
C
,
Y
,
X
));
// TODO: This is only to hold data
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK
Tensor
<
WeiDataType
>
wei_y_x_c_k
(
f_host_tensor_descriptor
(
K
,
C
,
Y
,
X
));
// TODO: This is only to hold data
#endif
#endif
Tensor
<
OutDataType
>
out_n_k_ho_wo_host_result
(
f_host_tensor_descriptor
(
N
,
K
,
Ho
,
Wo
));
Tensor
<
OutDataType
>
out_n_k_ho_wo_host_result
(
f_host_tensor_descriptor
(
N
,
K
,
Ho
,
Wo
));
Tensor
<
OutDataType
>
out_n_k_ho_wo_device_result
(
f_host_tensor_descriptor
(
N
,
K
,
Ho
,
Wo
));
Tensor
<
OutDataType
>
out_n_k_ho_wo_device_result
(
f_host_tensor_descriptor
(
N
,
K
,
Ho
,
Wo
));
...
@@ -319,6 +364,10 @@ int main(int argc, char* argv[])
...
@@ -319,6 +364,10 @@ int main(int argc, char* argv[])
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
transpose_kyxc_2_kyxc8k
(
wei_k_c_y_x_k8
,
wei_k_c_y_x
,
K
,
Y
,
X
,
C
);
transpose_kyxc_2_kyxc8k
(
wei_k_c_y_x_k8
,
wei_k_c_y_x
,
K
,
Y
,
X
,
C
);
wei_device_buf
.
ToDevice
(
wei_k_c_y_x_k8
.
mData
.
data
());
wei_device_buf
.
ToDevice
(
wei_k_c_y_x_k8
.
mData
.
data
());
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK
transpose_kyxc_2_yxck
(
wei_y_x_c_k
,
wei_k_c_y_x
,
K
,
Y
,
X
,
C
);
wei_device_buf
.
ToDevice
(
wei_y_x_c_k
.
mData
.
data
());
#endif
#endif
bias_device_buf
.
ToDevice
(
bias
.
mData
.
data
());
bias_device_buf
.
ToDevice
(
bias
.
mData
.
data
());
resi_device_buf
.
ToDevice
(
residual
.
mData
.
data
());
resi_device_buf
.
ToDevice
(
residual
.
mData
.
data
());
...
@@ -404,6 +453,30 @@ int main(int argc, char* argv[])
...
@@ -404,6 +453,30 @@ int main(int argc, char* argv[])
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_local_c
(
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_local_c
(
conv_ptrs
);
conv_ptrs
);
}
}
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK
if
(
omp_get_max_threads
()
>
1
)
{
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_mt
(
conv_ptrs
);
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk
(
conv_ptrs
);
}
else
{
if
(
K
%
8
==
0
)
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk
(
conv_ptrs
);
else
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_bias_activation_add_avx2_instance
::
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c
(
conv_ptrs
);
}
#endif
#endif
}
}
...
...
include/ck/tensor_operation/cpu/block/blockwise_gemm_avx2.hpp
View file @
f9cf57d4
...
@@ -199,8 +199,6 @@ struct BlockwiseGemmAvx2_MxN
...
@@ -199,8 +199,6 @@ struct BlockwiseGemmAvx2_MxN
auto
ldb
=
GetBLeadingElement
(
b_block_desc
)
*
sizeof
(
FloatB
);
auto
ldb
=
GetBLeadingElement
(
b_block_desc
)
*
sizeof
(
FloatB
);
auto
ldc
=
GetCLeadingElement
(
c_desc
)
*
sizeof
(
FloatC
);
auto
ldc
=
GetCLeadingElement
(
c_desc
)
*
sizeof
(
FloatC
);
// printf("lda:%d, ldb:%d, ldc:%d\n", lda, ldb, ldc);
const
auto
k_per_block
=
a_slice_length
[
Number
<
1
>
{}];
const
auto
k_per_block
=
a_slice_length
[
Number
<
1
>
{}];
const
auto
m_per_block
=
c_slice_length
[
Number
<
0
>
{}];
const
auto
m_per_block
=
c_slice_length
[
Number
<
0
>
{}];
const
auto
n_per_block
=
c_slice_length
[
Number
<
1
>
{}];
const
auto
n_per_block
=
c_slice_length
[
Number
<
1
>
{}];
...
@@ -215,8 +213,16 @@ struct BlockwiseGemmAvx2_MxN
...
@@ -215,8 +213,16 @@ struct BlockwiseGemmAvx2_MxN
param
.
alpha
=
1.0
f
;
// TODO
param
.
alpha
=
1.0
f
;
// TODO
param
.
accmulate_c
=
is_accumulate_c
?
1
:
0
;
param
.
accmulate_c
=
is_accumulate_c
?
1
:
0
;
// printf("xxx lda:%u, ldb:%u, ldc:%u, mpb:%u, npb:%u, kpb:%u\n", lda, ldb, ldc,
// printf("xxx lda:%u, ldb:%u, ldc:%u, mpb:%u, npb:%u, kpb:%u, mpt:%u, npt:%u\n",
// m_per_block, n_per_block, k_per_block);
// lda,
// ldb,
// ldc,
// m_per_block,
// n_per_block,
// k_per_block,
// m_per_thread,
// n_per_thread);
// fflush(stdout);
if
constexpr
(
std
::
is_same
<
ThreadMNAccessOrder
,
ck
::
Sequence
<
0
,
1
>>::
value
)
if
constexpr
(
std
::
is_same
<
ThreadMNAccessOrder
,
ck
::
Sequence
<
0
,
1
>>::
value
)
{
{
...
...
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_yxck_nhwk.hpp
0 → 100644
View file @
f9cf57d4
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/cpu/device/device_convnd_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk.hpp
0 → 100644
View file @
f9cf57d4
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp
View file @
f9cf57d4
...
@@ -81,12 +81,8 @@ struct ThreadwiseGemmAvx2_MxN_6x16
...
@@ -81,12 +81,8 @@ struct ThreadwiseGemmAvx2_MxN_6x16
"movq (%[m_param]), %%rax
\n
"
// p_a
"movq (%[m_param]), %%rax
\n
"
// p_a
"movq 8(%[m_param]), %%rbx
\n
"
// p_b
"movq 8(%[m_param]), %%rbx
\n
"
// p_b
"movq 24(%[m_param]), %%rsi
\n
"
// Kr
"movq 24(%[m_param]), %%rsi
\n
"
// Kr
".if m_TransA != 0
\n
"
"movq 32(%[m_param]), %%rcx
\n
"
// lda
"movq 32(%[m_param]), %%rcx
\n
"
// lda
".endif
\n
"
".if m_TransB == 0
\n
"
"movq 40(%[m_param]), %%rdx
\n
"
// ldb
"movq 40(%[m_param]), %%rdx
\n
"
// ldb
".endif
\n
"
".macro vbroadcastss_%= r_base, r_stride, i_scale, i_offset, ymm
\n
"
".macro vbroadcastss_%= r_base, r_stride, i_scale, i_offset, ymm
\n
"
".if
\\
i_scale != 0
\n
"
".if
\\
i_scale != 0
\n
"
...
@@ -120,10 +116,14 @@ struct ThreadwiseGemmAvx2_MxN_6x16
...
@@ -120,10 +116,14 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".endif
\n
"
".endif
\n
"
".endm
\n
"
".endm
\n
"
".macro vbroadcast_a%= i_k, i_m, ymm
\n
"
// A in rax(r8
, r9
), lda in rcx
".macro vbroadcast_a%= i_k, i_m, ymm
\n
"
// A in rax(r8), lda in rcx
".if m_ABytes == 4
\n
"
".if m_ABytes == 4
\n
"
".if m_TransA == 0
\n
"
".if m_TransA == 0
\n
"
"vbroadcastss_%= %%rax, 0, 0, ((
\\
i_m +
\\
i_k * m_Mr) * m_ABytes),
\\
ymm
\n
"
".if (
\\
i_k == 0) || (
\\
i_k == 1) || (
\\
i_k == 2)
\n
"
"vbroadcastss_%= %%rax, %%rcx,
\\
i_k, (
\\
i_m * m_ABytes),
\\
ymm
\n
"
".else
\n
"
"vbroadcastss_%= %%r8, %%rcx, (
\\
i_k-3), (
\\
i_m * m_ABytes),
\\
ymm
\n
"
".endif
\n
"
".else
\n
"
".else
\n
"
".if (
\\
i_m == 0) || (
\\
i_m == 1) || (
\\
i_m == 2)
\n
"
".if (
\\
i_m == 0) || (
\\
i_m == 1) || (
\\
i_m == 2)
\n
"
"vbroadcastss_%= %%rax, %%rcx,
\\
i_m, (
\\
i_k * m_ABytes),
\\
ymm
\n
"
"vbroadcastss_%= %%rax, %%rcx,
\\
i_m, (
\\
i_k * m_ABytes),
\\
ymm
\n
"
...
@@ -133,7 +133,11 @@ struct ThreadwiseGemmAvx2_MxN_6x16
...
@@ -133,7 +133,11 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".endif
\n
"
".endif
\n
"
".else
\n
"
".else
\n
"
".if m_TransA == 0
\n
"
".if m_TransA == 0
\n
"
"vpbroadcastw_%= %%rax, 0, 0, ((
\\
i_m +
\\
i_k * m_Mr) * m_ABytes), %%xmm15
\n
"
".if (
\\
i_k == 0) || (
\\
i_k == 1) || (
\\
i_k == 2)
\n
"
"vpbroadcastw_%= %%rax, %%rcx,
\\
i_k, (
\\
i_m * m_ABytes), %%xmm15
\n
"
".else
\n
"
"vpbroadcastw_%= %%rax, %%rcx, (
\\
i_k-3), (
\\
i_m * m_ABytes), %%xmm15
\n
"
".endif
\n
"
".else
\n
"
".else
\n
"
".if (
\\
i_m == 0) || (
\\
i_m == 1) || (
\\
i_m == 2)
\n
"
".if (
\\
i_m == 0) || (
\\
i_m == 1) || (
\\
i_m == 2)
\n
"
"vpbroadcastw_%= %%rax, %%rcx,
\\
i_m, (
\\
i_k * m_ABytes), %%xmm15
\n
"
"vpbroadcastw_%= %%rax, %%rcx,
\\
i_m, (
\\
i_k * m_ABytes), %%xmm15
\n
"
...
@@ -145,18 +149,26 @@ struct ThreadwiseGemmAvx2_MxN_6x16
...
@@ -145,18 +149,26 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".endif
\n
"
".endif
\n
"
".endm
\n
"
".endm
\n
"
".macro vload_b%= i_k, i_n, ymm
\n
"
// B in rbx, lda in rdx, i_n should be 0, 1
".macro vload_b%= i_k, i_n, ymm
\n
"
// B in rbx
(r9)
, ldb in rdx, i_n should be 0, 1
".if m_BBytes == 4
\n
"
".if m_BBytes == 4
\n
"
".if m_TransB == 0
\n
"
".if m_TransB == 0
\n
"
"vmovups_%= %%rbx, %%rdx,
\\
i_n, (
\\
i_k*m_BBytes*8),
\\
ymm
\n
"
"vmovups_%= %%rbx, %%rdx,
\\
i_n, (
\\
i_k*m_BBytes*8),
\\
ymm
\n
"
".else
\n
"
".else
\n
"
"vmovups_%= %%rbx, 0, 0, ((
\\
i_k*m_Nr +
\\
i_n*8)*m_BBytes),
\\
ymm
\n
"
".if (
\\
i_k == 0) || (
\\
i_k == 1) || (
\\
i_k == 2)
\n
"
"vmovups_%= %%rbx, %%rdx,
\\
i_k, (
\\
i_n*m_BBytes*8),
\\
ymm
\n
"
".else
\n
"
"vmovups_%= %%r9, %%rdx, (
\\
i_k-3), (
\\
i_n*m_BBytes*8),
\\
ymm
\n
"
".endif
\n
"
".endif
\n
"
".endif
\n
"
".else
\n
"
".else
\n
"
".if m_TransB == 0
\n
"
".if m_TransB == 0
\n
"
"vcvtph2ps_%= %%rbx, %%rdx,
\\
i_n, (
\\
i_k*m_BBytes*8),
\\
ymm
\n
"
"vcvtph2ps_%= %%rbx, %%rdx,
\\
i_n, (
\\
i_k*m_BBytes*8),
\\
ymm
\n
"
".else
\n
"
".else
\n
"
"vcvtph2ps_%= %%rbx, 0, 0, ((
\\
i_k*m_Nr +
\\
i_n*8)*m_BBytes),
\\
ymm
\n
"
".if (
\\
i_k == 0) || (
\\
i_k == 1) || (
\\
i_k == 2)
\n
"
"vcvtph2ps_%= %%rbx, %%rdx,
\\
i_k, (
\\
i_n*m_BBytes*8),
\\
ymm
\n
"
".else
\n
"
"vcvtph2ps_%= %%r9, %%rdx, (
\\
i_k-3), (
\\
i_n*m_BBytes*8),
\\
ymm
\n
"
".endif
\n
"
".endif
\n
"
".endif
\n
"
".endif
\n
"
".endif
\n
"
".endm
\n
"
".endm
\n
"
...
@@ -179,6 +191,13 @@ struct ThreadwiseGemmAvx2_MxN_6x16
...
@@ -179,6 +191,13 @@ struct ThreadwiseGemmAvx2_MxN_6x16
"lea (%%rcx, %%rcx, 2), %%r9
\n
"
"lea (%%rcx, %%rcx, 2), %%r9
\n
"
"lea (%%rax, %%r9), %%r8
\n
"
"lea (%%rax, %%r9), %%r8
\n
"
".endif
\n
"
".endif
\n
"
".else
\n
"
"lea (%%rcx, %%rcx, 2), %%r9
\n
"
"lea (%%rax, %%r9), %%r8
\n
"
".endif
\n
"
".if m_TransB != 0
\n
"
"lea (%%rdx, %%rdx, 2), %%rdi
\n
"
"lea (%%rbx, %%rdi), %%r9
\n
"
".endif
\n
"
".endif
\n
"
"cmp $4, %%rsi
\n
"
"cmp $4, %%rsi
\n
"
...
@@ -214,10 +233,12 @@ struct ThreadwiseGemmAvx2_MxN_6x16
...
@@ -214,10 +233,12 @@ struct ThreadwiseGemmAvx2_MxN_6x16
" lea 4*m_ABytes(%%rax), %%rax
\n
"
" lea 4*m_ABytes(%%rax), %%rax
\n
"
".if m_Mr > 3
\n
lea 4*m_ABytes(%%r8), %%r8
\n
.endif
\n
"
".if m_Mr > 3
\n
lea 4*m_ABytes(%%r8), %%r8
\n
.endif
\n
"
".else
\n
"
".else
\n
"
" lea m_Mr * 4 * m_ABytes(%%rax), %%rax
\n
"
" lea (%%rax, %%rcx, 4), %%rax
\n
"
" lea (%%r8, %%rcx, 4), %%r8
\n
"
".endif
\n
"
".endif
\n
"
".if m_TransB != 0
\n
"
".if m_TransB != 0
\n
"
" lea m_Nr * 4 * m_BBytes(%%rbx), %%rbx
\n
"
" lea (%%rbx, %%rdx, 4), %%rbx
\n
"
" lea (%%r9, %%rdx, 4), %%r9
\n
"
".else
\n
"
".else
\n
"
" lea 8 * 4 * m_BBytes(%%rbx), %%rbx
\n
"
" lea 8 * 4 * m_BBytes(%%rbx), %%rbx
\n
"
".endif
\n
"
".endif
\n
"
...
@@ -256,10 +277,12 @@ struct ThreadwiseGemmAvx2_MxN_6x16
...
@@ -256,10 +277,12 @@ struct ThreadwiseGemmAvx2_MxN_6x16
" lea m_ABytes(%%rax), %%rax
\n
"
" lea m_ABytes(%%rax), %%rax
\n
"
".if m_Mr > 3
\n
lea m_ABytes(%%r8), %%r8
\n
.endif
\n
"
".if m_Mr > 3
\n
lea m_ABytes(%%r8), %%r8
\n
.endif
\n
"
".else
\n
"
".else
\n
"
" lea m_Mr * m_ABytes(%%rax), %%rax
\n
"
" lea (%%rax, %%rcx, 1), %%rax
\n
"
" lea (%%r8, %%rcx, 1), %%r8
\n
"
".endif
\n
"
".endif
\n
"
".if m_TransB != 0
\n
"
".if m_TransB != 0
\n
"
" lea m_Nr * m_BBytes(%%rbx), %%rbx
\n
"
" lea (%%rbx, %%rdx, 1), %%rbx
\n
"
" lea (%%r9, %%rdx, 1), %%r9
\n
"
".else
\n
"
".else
\n
"
" lea 8*m_BBytes(%%rbx), %%rbx
\n
"
" lea 8*m_BBytes(%%rbx), %%rbx
\n
"
".endif
\n
"
".endif
\n
"
...
@@ -381,7 +404,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
...
@@ -381,7 +404,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
}
}
else
else
{
{
ymm
=
_mm256_broadcast_ss
(
p_a
+
i_k
*
Mr
+
i_m
);
ymm
=
_mm256_broadcast_ss
(
p_a
+
i_k
*
lda
+
i_m
);
}
}
}
}
else
else
...
@@ -396,7 +419,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
...
@@ -396,7 +419,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
}
}
else
else
{
{
ymm
=
_mm256_cvtph_ps
(
_mm_set1_epi16
(
*
(
p_a
+
i_k
*
Mr
+
i_m
)));
ymm
=
_mm256_cvtph_ps
(
_mm_set1_epi16
(
*
(
p_a
+
i_k
*
lda
+
i_m
)));
}
}
}
}
};
};
...
@@ -406,7 +429,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
...
@@ -406,7 +429,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
{
{
if
constexpr
(
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
if
constexpr
(
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
{
ymm
=
_mm256_loadu_ps
(
p_b
+
i_k
*
Nr
+
i_n
*
8
);
ymm
=
_mm256_loadu_ps
(
p_b
+
i_k
*
ldb
+
i_n
*
8
);
}
}
else
else
{
{
...
@@ -418,7 +441,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
...
@@ -418,7 +441,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
if
constexpr
(
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
if
constexpr
(
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
{
ymm
=
_mm256_cvtph_ps
(
_mm_loadu_si128
(
ymm
=
_mm256_cvtph_ps
(
_mm_loadu_si128
(
reinterpret_cast
<
__m128i
const
*>
(
p_b
+
i_k
*
Nr
+
i_n
*
8
)));
reinterpret_cast
<
__m128i
const
*>
(
p_b
+
i_k
*
ldb
+
i_n
*
8
)));
}
}
else
else
{
{
...
@@ -488,10 +511,10 @@ struct ThreadwiseGemmAvx2_MxN_6x16
...
@@ -488,10 +511,10 @@ struct ThreadwiseGemmAvx2_MxN_6x16
if
constexpr
(
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
){
if
constexpr
(
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
){
p_a
+=
4
;
p_a
+=
4
;
}
else
{
}
else
{
p_a
+=
Mr
*
4
;
p_a
+=
lda
*
4
;
}
}
if
constexpr
(
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
){
if
constexpr
(
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
){
p_b
+=
Nr
*
4
;
p_b
+=
ldb
*
4
;
}
else
{
}
else
{
p_b
+=
4
*
8
;
p_b
+=
4
*
8
;
}
}
...
@@ -525,10 +548,10 @@ struct ThreadwiseGemmAvx2_MxN_6x16
...
@@ -525,10 +548,10 @@ struct ThreadwiseGemmAvx2_MxN_6x16
if
constexpr
(
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
){
if
constexpr
(
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
){
p_a
+=
1
;
p_a
+=
1
;
}
else
{
}
else
{
p_a
+=
Mr
*
1
;
p_a
+=
lda
*
1
;
}
}
if
constexpr
(
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
){
if
constexpr
(
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
){
p_b
+=
Nr
*
1
;
p_b
+=
ldb
*
1
;
}
else
{
}
else
{
p_b
+=
1
*
8
;
p_b
+=
1
*
8
;
}
}
...
@@ -641,12 +664,8 @@ struct ThreadwiseGemmAvx2_MxN_4x24
...
@@ -641,12 +664,8 @@ struct ThreadwiseGemmAvx2_MxN_4x24
"movq (%[m_param]), %%rax
\n
"
// p_a
"movq (%[m_param]), %%rax
\n
"
// p_a
"movq 8(%[m_param]), %%rbx
\n
"
// p_b
"movq 8(%[m_param]), %%rbx
\n
"
// p_b
"movq 24(%[m_param]), %%rsi
\n
"
// Kr
"movq 24(%[m_param]), %%rsi
\n
"
// Kr
".if m_TransA != 0
\n
"
"movq 32(%[m_param]), %%rcx
\n
"
// lda
"movq 32(%[m_param]), %%rcx
\n
"
// lda
".endif
\n
"
".if m_TransB == 0
\n
"
"movq 40(%[m_param]), %%rdx
\n
"
// ldb
"movq 40(%[m_param]), %%rdx
\n
"
// ldb
".endif
\n
"
".macro vbroadcastss_%= r_base, r_stride, i_scale, i_offset, ymm
\n
"
".macro vbroadcastss_%= r_base, r_stride, i_scale, i_offset, ymm
\n
"
".if
\\
i_scale != 0
\n
"
".if
\\
i_scale != 0
\n
"
...
@@ -683,7 +702,11 @@ struct ThreadwiseGemmAvx2_MxN_4x24
...
@@ -683,7 +702,11 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".macro vbroadcast_a%= i_k, i_m, ymm
\n
"
// A in rax(r8), lda in rcx
".macro vbroadcast_a%= i_k, i_m, ymm
\n
"
// A in rax(r8), lda in rcx
".if m_ABytes == 4
\n
"
".if m_ABytes == 4
\n
"
".if m_TransA == 0
\n
"
".if m_TransA == 0
\n
"
"vbroadcastss_%= %%rax, 0, 0, ((
\\
i_m +
\\
i_k * m_Mr) * m_ABytes),
\\
ymm
\n
"
".if (
\\
i_k == 0) || (
\\
i_k == 1)
\n
"
"vbroadcastss_%= %%rax, %%rcx,
\\
i_k, (
\\
i_m * m_ABytes),
\\
ymm
\n
"
".else
\n
"
"vbroadcastss_%= %%r8, %%rcx, (
\\
i_k-2), (
\\
i_m * m_ABytes),
\\
ymm
\n
"
".endif
\n
"
".else
\n
"
".else
\n
"
".if (
\\
i_m == 0) || (
\\
i_m == 1)
\n
"
".if (
\\
i_m == 0) || (
\\
i_m == 1)
\n
"
"vbroadcastss_%= %%rax, %%rcx,
\\
i_m, (
\\
i_k * m_ABytes),
\\
ymm
\n
"
"vbroadcastss_%= %%rax, %%rcx,
\\
i_m, (
\\
i_k * m_ABytes),
\\
ymm
\n
"
...
@@ -693,7 +716,11 @@ struct ThreadwiseGemmAvx2_MxN_4x24
...
@@ -693,7 +716,11 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".endif
\n
"
".endif
\n
"
".else
\n
"
".else
\n
"
".if m_TransA == 0
\n
"
".if m_TransA == 0
\n
"
"vpbroadcastw_%= %%rax, 0, 0, ((
\\
i_m +
\\
i_k * m_Mr) * m_ABytes), %%xmm15
\n
"
".if (
\\
i_k == 0) || (
\\
i_k == 1)
\n
"
"vpbroadcastw_%= %%rax, %%rcx,
\\
i_k, (
\\
i_m * m_ABytes), %%xmm15
\n
"
".else
\n
"
"vpbroadcastw_%= %%r8, %%rcx, (
\\
i_k-2), (
\\
i_m * m_ABytes), %%xmm15
\n
"
".endif
\n
"
".else
\n
"
".else
\n
"
".if (
\\
i_m == 0) || (
\\
i_m == 1)
\n
"
".if (
\\
i_m == 0) || (
\\
i_m == 1)
\n
"
"vpbroadcastw_%= %%rax, %%rcx,
\\
i_m, (
\\
i_k * m_ABytes), %%xmm15
\n
"
"vpbroadcastw_%= %%rax, %%rcx,
\\
i_m, (
\\
i_k * m_ABytes), %%xmm15
\n
"
...
@@ -710,13 +737,21 @@ struct ThreadwiseGemmAvx2_MxN_4x24
...
@@ -710,13 +737,21 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".if m_TransB == 0
\n
"
".if m_TransB == 0
\n
"
"vmovups_%= %%rbx, %%rdx,
\\
i_n, (
\\
i_k*m_BBytes*8),
\\
ymm
\n
"
"vmovups_%= %%rbx, %%rdx,
\\
i_n, (
\\
i_k*m_BBytes*8),
\\
ymm
\n
"
".else
\n
"
".else
\n
"
"vmovups_%= %%rbx, 0, 0, ((
\\
i_k*m_Nr +
\\
i_n*8)*m_BBytes),
\\
ymm
\n
"
".if (
\\
i_k == 0) || (
\\
i_k == 1)
\n
"
"vmovups_%= %%rbx, %%rdx,
\\
i_k, (
\\
i_n*8*m_BBytes),
\\
ymm
\n
"
".else
\n
"
"vmovups_%= %%rdi, %%rdx, (
\\
i_k-2), (
\\
i_n*8*m_BBytes),
\\
ymm
\n
"
".endif
\n
"
".endif
\n
"
".endif
\n
"
".else
\n
"
".else
\n
"
".if m_TransB == 0
\n
"
".if m_TransB == 0
\n
"
"vcvtph2ps_%= %%rbx, %%rdx,
\\
i_n, (
\\
i_k*m_BBytes*8),
\\
ymm
\n
"
"vcvtph2ps_%= %%rbx, %%rdx,
\\
i_n, (
\\
i_k*m_BBytes*8),
\\
ymm
\n
"
".else
\n
"
".else
\n
"
"vcvtph2ps_%= %%rbx, 0, 0, ((
\\
i_k*m_Nr +
\\
i_n*8)*m_BBytes),
\\
ymm
\n
"
".if (
\\
i_k == 0) || (
\\
i_k == 1)
\n
"
"vcvtph2ps_%= %%rbx, %%rdx,
\\
i_k, (
\\
i_n*8*m_BBytes),
\\
ymm
\n
"
".else
\n
"
"vcvtph2ps_%= %%rdi, %%rdx, (
\\
i_k-2), (
\\
i_n*8*m_BBytes),
\\
ymm
\n
"
".endif
\n
"
".endif
\n
"
".endif
\n
"
".endif
\n
"
".endif
\n
"
".endm
\n
"
".endm
\n
"
...
@@ -738,6 +773,11 @@ struct ThreadwiseGemmAvx2_MxN_4x24
...
@@ -738,6 +773,11 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".if m_Mr > 2
\n
"
".if m_Mr > 2
\n
"
"lea (%%rax, %%rcx, 2), %%r8
\n
"
"lea (%%rax, %%rcx, 2), %%r8
\n
"
".endif
\n
"
".endif
\n
"
".else
\n
"
"lea (%%rax, %%rcx, 2), %%r8
\n
"
".endif
\n
"
".if m_TransB != 0
\n
"
"lea (%%rbx, %%rdx, 2), %%rdi
\n
"
".endif
\n
"
".endif
\n
"
"cmp $4, %%rsi
\n
"
"cmp $4, %%rsi
\n
"
...
@@ -773,10 +813,12 @@ struct ThreadwiseGemmAvx2_MxN_4x24
...
@@ -773,10 +813,12 @@ struct ThreadwiseGemmAvx2_MxN_4x24
" lea 4*m_ABytes(%%rax), %%rax
\n
"
" lea 4*m_ABytes(%%rax), %%rax
\n
"
".if m_Mr > 2
\n
lea 4*m_ABytes(%%r8), %%r8
\n
.endif
\n
"
".if m_Mr > 2
\n
lea 4*m_ABytes(%%r8), %%r8
\n
.endif
\n
"
".else
\n
"
".else
\n
"
" lea m_Mr * 4 * m_ABytes(%%rax), %%rax
\n
"
" lea (%%rax, %%rcx, 4), %%rax
\n
"
" lea (%%r8, %%rcx, 4), %%r8
\n
"
".endif
\n
"
".endif
\n
"
".if m_TransB != 0
\n
"
".if m_TransB != 0
\n
"
" lea m_Nr * 4 * m_BBytes(%%rbx), %%rbx
\n
"
" lea (%%rbx, %%rdx, 4), %%rbx
\n
"
" lea (%%rdi, %%rdx, 4), %%rdi
\n
"
".else
\n
"
".else
\n
"
" lea 8 * 4 * m_BBytes(%%rbx), %%rbx
\n
"
" lea 8 * 4 * m_BBytes(%%rbx), %%rbx
\n
"
".endif
\n
"
".endif
\n
"
...
@@ -815,10 +857,12 @@ struct ThreadwiseGemmAvx2_MxN_4x24
...
@@ -815,10 +857,12 @@ struct ThreadwiseGemmAvx2_MxN_4x24
" lea m_ABytes(%%rax), %%rax
\n
"
" lea m_ABytes(%%rax), %%rax
\n
"
".if m_Mr > 3
\n
lea m_ABytes(%%r8), %%r8
\n
.endif
\n
"
".if m_Mr > 3
\n
lea m_ABytes(%%r8), %%r8
\n
.endif
\n
"
".else
\n
"
".else
\n
"
" lea m_Mr * m_ABytes(%%rax), %%rax
\n
"
" lea (%%rax, %%rcx, 1), %%rax
\n
"
" lea (%%r8, %%rcx, 1), %%r8
\n
"
".endif
\n
"
".endif
\n
"
".if m_TransB != 0
\n
"
".if m_TransB != 0
\n
"
" lea m_Nr * m_BBytes(%%rbx), %%rbx
\n
"
" lea (%%rbx, %%rdx, 1), %%rbx
\n
"
" lea (%%rdi, %%rdx, 1), %%rdi
\n
"
".else
\n
"
".else
\n
"
" lea 8*m_BBytes(%%rbx), %%rbx
\n
"
" lea 8*m_BBytes(%%rbx), %%rbx
\n
"
".endif
\n
"
".endif
\n
"
...
@@ -937,7 +981,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
...
@@ -937,7 +981,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
}
}
else
else
{
{
ymm
=
_mm256_broadcast_ss
(
p_a
+
i_k
*
Mr
+
i_m
);
ymm
=
_mm256_broadcast_ss
(
p_a
+
i_k
*
lda
+
i_m
);
}
}
}
}
else
else
...
@@ -952,7 +996,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
...
@@ -952,7 +996,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
}
}
else
else
{
{
ymm
=
_mm256_cvtph_ps
(
_mm_set1_epi16
(
*
(
p_a
+
i_k
*
Mr
+
i_m
)));
ymm
=
_mm256_cvtph_ps
(
_mm_set1_epi16
(
*
(
p_a
+
i_k
*
lda
+
i_m
)));
}
}
}
}
};
};
...
@@ -962,7 +1006,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
...
@@ -962,7 +1006,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
{
{
if
constexpr
(
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
if
constexpr
(
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
{
ymm
=
_mm256_loadu_ps
(
p_b
+
i_k
*
Nr
+
i_n
*
8
);
ymm
=
_mm256_loadu_ps
(
p_b
+
i_k
*
ldb
+
i_n
*
8
);
}
}
else
else
{
{
...
@@ -974,7 +1018,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
...
@@ -974,7 +1018,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
if
constexpr
(
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
if
constexpr
(
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
{
ymm
=
_mm256_cvtph_ps
(
_mm_loadu_si128
(
ymm
=
_mm256_cvtph_ps
(
_mm_loadu_si128
(
reinterpret_cast
<
__m128i
const
*>
(
p_b
+
i_k
*
Nr
+
i_n
*
8
)));
reinterpret_cast
<
__m128i
const
*>
(
p_b
+
i_k
*
ldb
+
i_n
*
8
)));
}
}
else
else
{
{
...
@@ -1044,10 +1088,10 @@ struct ThreadwiseGemmAvx2_MxN_4x24
...
@@ -1044,10 +1088,10 @@ struct ThreadwiseGemmAvx2_MxN_4x24
if
constexpr
(
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
){
if
constexpr
(
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
){
p_a
+=
4
;
p_a
+=
4
;
}
else
{
}
else
{
p_a
+=
Mr
*
4
;
p_a
+=
lda
*
4
;
}
}
if
constexpr
(
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
){
if
constexpr
(
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
){
p_b
+=
Nr
*
4
;
p_b
+=
ldb
*
4
;
}
else
{
}
else
{
p_b
+=
4
*
8
;
p_b
+=
4
*
8
;
}
}
...
@@ -1081,10 +1125,10 @@ struct ThreadwiseGemmAvx2_MxN_4x24
...
@@ -1081,10 +1125,10 @@ struct ThreadwiseGemmAvx2_MxN_4x24
if
constexpr
(
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
){
if
constexpr
(
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
){
p_a
+=
1
;
p_a
+=
1
;
}
else
{
}
else
{
p_a
+=
Mr
*
1
;
p_a
+=
lda
*
1
;
}
}
if
constexpr
(
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
){
if
constexpr
(
std
::
is_same
<
ck
::
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
){
p_b
+=
Nr
*
1
;
p_b
+=
ldb
*
1
;
}
else
{
}
else
{
p_b
+=
1
*
8
;
p_b
+=
1
*
8
;
}
}
...
...
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp
View file @
f9cf57d4
...
@@ -1277,6 +1277,138 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8
...
@@ -1277,6 +1277,138 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8
intptr_t
src_offset
;
intptr_t
src_offset
;
};
};
template
<
typename
SrcData
,
typename
DstData
,
typename
SrcDesc
,
typename
DstDesc
,
typename
ElementwiseOperation
,
bool
BypassTransfer
,
ConvolutionForwardSpecialization_t
ConvForwardSpecialization
,
ConvolutionForwardGemmKSpecialization_t
GemmKSpecialization
>
struct
ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_YXCK
{
static
constexpr
ck
::
index_t
nDim
=
SrcDesc
::
GetNumOfDimension
();
using
Index
=
MultiIndex
<
nDim
>
;
constexpr
ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_YXCK
(
const
SrcDesc
&
src_desc
,
const
Index
&
,
const
DstDesc
&
,
const
Index
&
,
const
ElementwiseOperation
&
element_op
)
:
element_op_
(
element_op
)
{
GemmK
=
src_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
0
>
{}];
GemmN
=
src_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}];
}
void
SetSrcSliceOrigin
(
const
SrcDesc
&
,
const
Index
&
src_slice_origin_idx
)
{
ck
::
index_t
idx_k
=
src_slice_origin_idx
[
Number
<
0
>
{}];
ck
::
index_t
idx_n
=
src_slice_origin_idx
[
Number
<
1
>
{}];
src_offset
=
idx_k
*
GemmN
+
idx_n
;
}
void
SetDstSliceOrigin
(
const
DstDesc
&
,
const
Index
&
)
{}
template
<
typename
SrcBuffer
,
typename
DstBuffer
,
typename
SliceLengths
>
void
RunRead
(
const
SrcDesc
&
,
SrcBuffer
&
src_buf
,
const
DstDesc
&
dst_desc
,
DstBuffer
&
dst_buf
,
const
SliceLengths
&
slice_length
)
{
if
constexpr
(
BypassTransfer
)
{
dst_buf
.
p_data_
=
reinterpret_cast
<
float
*>
(
src_buf
.
p_data_
)
+
src_offset
;
}
else
{
const
ck
::
index_t
k_per_block
=
slice_length
[
Number
<
0
>
{}];
const
ck
::
index_t
n_per_block
=
slice_length
[
Number
<
1
>
{}];
const
float
*
p_src
=
reinterpret_cast
<
const
float
*>
(
src_buf
.
p_data_
)
+
src_offset
;
float
*
p_dst
=
reinterpret_cast
<
float
*>
(
dst_buf
.
p_data_
);
// k * n
index_t
i_k_itr
=
k_per_block
;
while
(
i_k_itr
>=
8
)
{
avx2_util
::
memcpy32_avx2
(
p_dst
+
0
*
n_per_block
,
p_src
+
0
*
GemmN
,
n_per_block
,
element_op_
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
1
*
n_per_block
,
p_src
+
1
*
GemmN
,
n_per_block
,
element_op_
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
2
*
n_per_block
,
p_src
+
2
*
GemmN
,
n_per_block
,
element_op_
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
3
*
n_per_block
,
p_src
+
3
*
GemmN
,
n_per_block
,
element_op_
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
4
*
n_per_block
,
p_src
+
4
*
GemmN
,
n_per_block
,
element_op_
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
5
*
n_per_block
,
p_src
+
5
*
GemmN
,
n_per_block
,
element_op_
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
6
*
n_per_block
,
p_src
+
6
*
GemmN
,
n_per_block
,
element_op_
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
7
*
n_per_block
,
p_src
+
7
*
GemmN
,
n_per_block
,
element_op_
);
i_k_itr
-=
8
;
p_dst
+=
8
*
n_per_block
;
p_src
+=
8
*
GemmN
;
}
if
(
i_k_itr
&
4
)
{
avx2_util
::
memcpy32_avx2
(
p_dst
+
0
*
n_per_block
,
p_src
+
0
*
GemmN
,
n_per_block
,
element_op_
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
1
*
n_per_block
,
p_src
+
1
*
GemmN
,
n_per_block
,
element_op_
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
2
*
n_per_block
,
p_src
+
2
*
GemmN
,
n_per_block
,
element_op_
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
3
*
n_per_block
,
p_src
+
3
*
GemmN
,
n_per_block
,
element_op_
);
p_dst
+=
4
*
n_per_block
;
p_src
+=
4
*
GemmN
;
}
if
(
i_k_itr
&
2
)
{
avx2_util
::
memcpy32_avx2
(
p_dst
+
0
*
n_per_block
,
p_src
+
0
*
GemmN
,
n_per_block
,
element_op_
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
1
*
n_per_block
,
p_src
+
1
*
GemmN
,
n_per_block
,
element_op_
);
p_dst
+=
2
*
n_per_block
;
p_src
+=
2
*
GemmN
;
}
if
(
i_k_itr
&
1
)
{
avx2_util
::
memcpy32_avx2
(
p_dst
+
0
*
n_per_block
,
p_src
+
0
*
GemmN
,
n_per_block
,
element_op_
);
}
}
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
void
MoveSrcSliceWindow
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin_step_idx
)
{
ck
::
index_t
move_k
=
src_slice_origin_step_idx
[
Number
<
0
>
{}];
ck
::
index_t
move_n
=
src_slice_origin_step_idx
[
Number
<
1
>
{}];
src_offset
+=
move_k
*
GemmN
+
move_n
;
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
void
MoveDstSliceWindow
(
const
DstDesc
&
,
const
Index
&
)
{}
private:
const
ElementwiseOperation
element_op_
;
ck
::
index_t
GemmN
;
ck
::
index_t
GemmK
;
intptr_t
src_offset
;
};
template
<
typename
SrcData
,
template
<
typename
SrcData
,
typename
DstData
,
typename
DstData
,
typename
SrcDesc
,
typename
SrcDesc
,
...
...
library/src/tensor_operation_instance/cpu/conv2d_fwd/CMakeLists.txt
View file @
f9cf57d4
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
set
(
DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE
set
(
DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE
device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp
device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp
device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_instance.cpp
device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_instance.cpp
device_conv2d_fwd_avx2_nhwc_yxck_nhwk_instance.cpp
)
)
add_library
(
device_conv2d_fwd_cpu_instance SHARED
${
DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE
}
)
add_library
(
device_conv2d_fwd_cpu_instance SHARED
${
DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE
}
)
target_compile_features
(
device_conv2d_fwd_cpu_instance PUBLIC
)
target_compile_features
(
device_conv2d_fwd_cpu_instance PUBLIC
)
...
...
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_yxck_nhwk_instance.cpp
0 → 100644
View file @
f9cf57d4
#include <stdlib.h>
#include "config.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "device_convnd_fwd_avx2_nhwc_yxck_nhwk.hpp"
#include "element_wise_operation_cpu.hpp"
#include "device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
cpu
{
namespace
device
{
namespace
device_conv2d_fwd_avx2_instance
{
// Scalar types used by the FP32 instances in this translation unit.
using InType  = float;
using WeiType = float;
using OutType = float;
using AccType = float;

static constexpr bool NonTemporalStore = false;

// Short names for the element-wise operators plugged into the instances below.
using PT   = ck::tensor_operation::cpu::element_wise::PassThrough;
using Relu = ck::tensor_operation::cpu::element_wise::Relu;
static
constexpr
auto
ConvFwdDefault
=
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardSpecialization_t
::
Default
;
static
constexpr
auto
ConvFwd1x1P0
=
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardSpecialization_t
::
Filter1x1Pad0
;
static
constexpr
auto
ConvFwd1x1S1P0
=
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
;
static
constexpr
auto
DefaultGemmKLoop
=
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardGemmKSpecialization_t
::
DefaultGemmKLoop
;
static
constexpr
auto
GemmKLoopOverC
=
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardGemmKSpecialization_t
::
NHWC_GemmKLoopOverC
;
static
constexpr
auto
LoopOver_MNK
=
ck
::
tensor_operation
::
cpu
::
device
::
LoopOver_MNK
;
static
constexpr
auto
LoopOver_MKN
=
ck
::
tensor_operation
::
cpu
::
device
::
LoopOver_MKN
;
// clang-format off
// Helper: the five device-op types generated for one tile configuration and one block loop
// order. They differ in (conv specialization, GEMM-K specialization, two boolean switches):
//   Default  + GemmKLoopOverC   (true , true )
//   1x1S1P0  + GemmKLoopOverC   (true , true )
//   Default  + DefaultGemmKLoop (true , true )
//   1x1S1P0  + GemmKLoopOverC   (false, false)
//   Default  + DefaultGemmKLoop (true , false)
// NOTE(review): the meaning of the two booleans and of the literal 2 is defined by
// DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K - confirm in its header.
#define DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32_LOOP_ORDER(loop_order, a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf) \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC  , loop_order, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC  , loop_order, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, loop_order, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC  , loop_order, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, loop_order, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , false, c_local_buf>

// Expands to ten comma-separated device-op types for one tile configuration: the five
// variants above with LoopOver_MNK first, followed by the same five with LoopOver_MKN.
// Intended to be used as (part of) a std::tuple<> template argument list.
#define DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf) \
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32_LOOP_ORDER(LoopOver_MNK, a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf), \
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32_LOOP_ORDER(LoopOver_MKN, a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf)
// clang-format on
// Default FP32 instance list (PassThrough on A, B and C): eight tile configurations, all
// built without a local C buffer (c_local_buf = false).
// Macro arguments: a_op, b_op, c_op, m/n/k_per_block, m/n_per_thread, c_local_buf.
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  256, 128,  64, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  256, 128, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  128, 256, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  512, 240, 128, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  512, 256, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  768, 320, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  896, 352, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, false)>;
// clang-format on
// Single-threaded variant for problems whose gemm_n is not a multiple of 8: the same eight
// tile configurations as the default list, but with the local C buffer enabled.
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  256, 128,  64, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  256, 128, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  128, 256, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  512, 240, 128, 4, 24, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  512, 256, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  768, 320, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT,  896, 352, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true)>;
// clang-format on
// use this in multi thread environment (need local C buffer to avoid cache coherence, although some
// time no local c is better...)
using
device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_instances
=
std
::
tuple
<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
24
,
24
,
256
,
4
,
24
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
32
,
24
,
256
,
4
,
24
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
40
,
24
,
256
,
4
,
24
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
48
,
24
,
256
,
4
,
24
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
48
,
48
,
256
,
4
,
24
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
56
,
24
,
256
,
4
,
24
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
72
,
16
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
72
,
16
,
256
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
72
,
32
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
72
,
32
,
256
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
96
,
32
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
96
,
64
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
120
,
32
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
120
,
64
,
128
,
6
,
16
,
false
),
// DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
256
,
128
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
128
,
256
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
512
,
240
,
128
,
4
,
24
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
512
,
256
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
768
,
320
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
896
,
352
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
1024
,
416
,
128
,
6
,
16
,
true
)
>
;
// clang-format on
// Default FP32 instance list with a fused ReLU on the output (C op = Relu): eight tile
// configurations, all built without a local C buffer (c_local_buf = false).
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_relu_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  256, 128,  64, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  256, 128, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  128, 256, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  512, 240, 128, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  512, 256, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  768, 320, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  896, 352, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, false)>;
// clang-format on
// Single-threaded ReLU variant for problems whose gemm_n is not a multiple of 8: the same
// eight tile configurations as the default ReLU list, but with the local C buffer enabled.
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_relu_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  256, 128,  64, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  256, 128, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  128, 256, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  512, 240, 128, 4, 24, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  512, 256, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  768, 320, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  896, 352, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)>;
// clang-format on
// ReLU instance list for multi-threaded use. Upstream note: a local C buffer is needed to
// avoid cache-coherence traffic between threads, although sometimes running without one is
// faster. Accordingly the small tiles are built without a local C buffer and the large ones
// with it.
//
// Each tile configuration appears exactly once, mirroring the non-ReLU multi-thread list.
// (The 24/32/40 x 24 x 256 configurations used to be listed twice, which registered the same
// thirty instances a second time.)
using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_relu_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,   24,  24, 256, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,   32,  24, 256, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,   40,  24, 256, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,   48,  24, 256, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,   48,  48, 256, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,   56,  24, 256, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,   72,  16, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,   72,  16, 256, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,   72,  32, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,   72,  32, 256, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,   96,  32, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,   96,  64, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  120,  32, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  120,  64, 128, 6, 16, false),

    // DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 256, 128, 64, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  256, 128, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  128, 256, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  512, 240, 128, 4, 24, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  512, 256, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  768, 320, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu,  896, 352, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)>;
// clang-format on
// Registers the default FP32 instance list (no local C buffer, PassThrough output op) by
// forwarding a default-constructed tuple of instances to add_device_operation_instances().
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk(
    std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
    using InstanceList = device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_instances;

    ck::tensor_operation::device::add_device_operation_instances(instances, InstanceList{});
}
// Registers the FP32 instance list that uses a local C buffer (PassThrough output op) by
// forwarding a default-constructed tuple of instances to add_device_operation_instances().
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c(
    std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
    using InstanceList = device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_instances;

    ck::tensor_operation::device::add_device_operation_instances(instances, InstanceList{});
}
// Registers the multi-thread FP32 instance list (PassThrough output op) by forwarding a
// default-constructed tuple of instances to add_device_operation_instances().
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt(
    std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
    using InstanceList = device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_instances;

    ck::tensor_operation::device::add_device_operation_instances(instances, InstanceList{});
}
// Registers the default FP32 instance list with a ReLU output op (no local C buffer) by
// forwarding a default-constructed tuple of instances to add_device_operation_instances().
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_relu(
    std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
    using InstanceList = device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_relu_instances;

    ck::tensor_operation::device::add_device_operation_instances(instances, InstanceList{});
}
// Registers the FP32 instance list that uses a local C buffer and a ReLU output op by
// forwarding a default-constructed tuple of instances to add_device_operation_instances().
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c_relu(
    std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
    using InstanceList = device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_relu_instances;

    ck::tensor_operation::device::add_device_operation_instances(instances, InstanceList{});
}
// Registers the multi-thread FP32 instance list with a ReLU output op by forwarding a
// default-constructed tuple of instances to add_device_operation_instances().
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt_relu(
    std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
    using InstanceList = device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_relu_instances;

    ck::tensor_operation::device::add_device_operation_instances(instances, InstanceList{});
}
}
// namespace device_conv2d_fwd_avx2_instance
}
// namespace device
}
// namespace cpu
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/CMakeLists.txt
View file @
f9cf57d4
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
set
(
DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE
set
(
DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE
device_conv2d_bias_activation_add_avx2_nhwc_kyxc_nhwk_instance.cpp
device_conv2d_bias_activation_add_avx2_nhwc_kyxc_nhwk_instance.cpp
device_conv2d_bias_activation_add_avx2_nhwc_kyxck8_nhwk_instance.cpp
device_conv2d_bias_activation_add_avx2_nhwc_kyxck8_nhwk_instance.cpp
device_conv2d_bias_activation_add_avx2_nhwc_yxck_nhwk_instance.cpp
)
)
add_library
(
device_conv2d_fwd_bias_activation_add_cpu_instance SHARED
${
DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE
}
)
add_library
(
device_conv2d_fwd_bias_activation_add_cpu_instance SHARED
${
DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE
}
)
target_compile_features
(
device_conv2d_fwd_bias_activation_add_cpu_instance PUBLIC
)
target_compile_features
(
device_conv2d_fwd_bias_activation_add_cpu_instance PUBLIC
)
...
...
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_bias_activation_add_avx2_nhwc_yxck_nhwk_instance.cpp
0 → 100644
View file @
f9cf57d4
#include <stdlib.h>
#include "config.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "device_convnd_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk.hpp"
#include "element_wise_operation_cpu.hpp"
#include "device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
cpu
{
namespace
device
{
namespace
device_conv2d_fwd_bias_activation_add_avx2_instance
{
// Scalar types used by the FP32 instances in this translation unit.
using InType  = float;
using WeiType = float;
using OutType = float;
using AccType = float;

static constexpr bool NonTemporalStore = false;

// Short names for the element-wise operators plugged into the instances below.
using PT         = ck::tensor_operation::cpu::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::cpu::element_wise::AddReluAdd;
static
constexpr
auto
ConvFwdDefault
=
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardSpecialization_t
::
Default
;
static
constexpr
auto
ConvFwd1x1P0
=
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardSpecialization_t
::
Filter1x1Pad0
;
static
constexpr
auto
ConvFwd1x1S1P0
=
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
;
static
constexpr
auto
DefaultGemmKLoop
=
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardGemmKSpecialization_t
::
DefaultGemmKLoop
;
static
constexpr
auto
GemmKLoopOverC
=
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardGemmKSpecialization_t
::
NHWC_GemmKLoopOverC
;
static
constexpr
auto
LoopOver_MNK
=
ck
::
tensor_operation
::
cpu
::
device
::
LoopOver_MNK
;
static
constexpr
auto
LoopOver_MKN
=
ck
::
tensor_operation
::
cpu
::
device
::
LoopOver_MKN
;
// clang-format off
// Helper: the five device-op types generated for one tile configuration and one block loop
// order. They differ in (conv specialization, GEMM-K specialization, two boolean switches):
//   Default  + GemmKLoopOverC   (true , true )
//   1x1S1P0  + GemmKLoopOverC   (true , true )
//   Default  + DefaultGemmKLoop (true , true )
//   1x1S1P0  + GemmKLoopOverC   (false, false)
//   Default  + DefaultGemmKLoop (true , false)
// NOTE(review): the meaning of the two booleans and of the literal 2 is defined by
// DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K -
// confirm in its header.
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32_LOOP_ORDER(loop_order, a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float, float, float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC  , loop_order, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float, float, float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC  , loop_order, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float, float, float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, loop_order, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float, float, float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC  , loop_order, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>, \
    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K<float, float, float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, loop_order, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , false, c_local_buf, bias_along_m>

// Expands to ten comma-separated device-op types for one tile configuration: the five
// variants above with LoopOver_MNK first, followed by the same five with LoopOver_MKN.
// Intended to be used as (part of) a std::tuple<> template argument list.
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32_LOOP_ORDER(LoopOver_MNK, a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m), \
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32_LOOP_ORDER(LoopOver_MKN, a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m)
// clang-format on
// Default FP32 bias + activation + add instance list (C op = AddReluAdd): eight tile
// configurations, all with c_local_buf = false and bias_along_m = false.
// Macro arguments: a_op, b_op, c_op, m/n/k_per_block, m/n_per_thread, c_local_buf, bias_along_m.
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  256, 128,  64, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  256, 128, 128, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  128, 256, 128, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  512, 240, 128, 4, 24, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  512, 256, 128, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  768, 320, 128, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  896, 352, 128, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, false, false)>;
// clang-format on
// Single-threaded variant for problems whose gemm_n is not a multiple of 8: the same eight
// tile configurations as the default list, but with the local C buffer enabled.
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_local_c_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  256, 128,  64, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  256, 128, 128, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  128, 256, 128, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  512, 240, 128, 4, 24, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  512, 256, 128, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  768, 320, 128, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  896, 352, 128, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)>;
// clang-format on
// Instance list for multi-threaded use. Upstream note: a local C buffer is needed to avoid
// cache-coherence traffic between threads, although sometimes running without one is faster.
// Accordingly the small tiles are built without a local C buffer and the large ones with it.
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_mt_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,   24,  24, 256, 4, 24, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,   32,  24, 256, 4, 24, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,   40,  24, 256, 4, 24, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,   48,  24, 256, 4, 24, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,   48,  48, 256, 4, 24, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,   56,  24, 256, 4, 24, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,   72,  16, 128, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,   72,  16, 256, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,   72,  32, 128, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,   72,  32, 256, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,   96,  32, 128, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,   96,  64, 128, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  120,  32, 128, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  120,  64, 128, 6, 16, false, false),

    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  256, 128, 128, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  128, 256, 128, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  512, 240, 128, 4, 24, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  512, 256, 128, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  768, 320, 128, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd,  896, 352, 128, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)>;
// clang-format on
// Registers the default FP32 bias + activation + add instance list (no local C buffer) by
// forwarding a default-constructed tuple of instances to add_device_operation_instances().
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk(
    std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
    using InstanceList = device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_instances;

    ck::tensor_operation::device::add_device_operation_instances(instances, InstanceList{});
}
// Registers the FP32 bias + activation + add instance list that uses a local C buffer by
// forwarding a default-constructed tuple of instances to add_device_operation_instances().
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c(
    std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
    using InstanceList =
        device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_local_c_instances;

    ck::tensor_operation::device::add_device_operation_instances(instances, InstanceList{});
}
// Registers the multi-thread FP32 bias + activation + add instance list by forwarding a
// default-constructed tuple of instances to add_device_operation_instances().
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_mt(
    std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
    using InstanceList =
        device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_mt_instances;

    ck::tensor_operation::device::add_device_operation_instances(instances, InstanceList{});
}
}
// namespace device_conv2d_fwd_bias_activation_add_avx2_instance
}
// namespace device
}
// namespace cpu
}
// namespace tensor_operation
}
// namespace ck
test/cpu_ukernel/cpu_gemm_uk.cpp
View file @
f9cf57d4
...
@@ -233,68 +233,30 @@ void test_ukernel(ukenrel_t uk,
...
@@ -233,68 +233,30 @@ void test_ukernel(ukenrel_t uk,
int
max_threads
=
omp_get_max_threads
();
int
max_threads
=
omp_get_max_threads
();
auto
invoke_uk
=
[
&
](
ck
::
cpu
::
ThreadwiseGemmParam
&
param
,
float
*
current_mat_c
)
{
auto
invoke_uk
=
[
&
](
ck
::
cpu
::
ThreadwiseGemmParam
&
param
,
float
*
current_mat_c
)
{
if
constexpr
(
std
::
is_same
<
Row
,
ALayout
>::
value
&&
std
::
is_same
<
Row
,
BLayout
>::
value
)
assert
(
m
%
uk
.
ThreadMr
==
0
&&
n
%
uk
.
ThreadNr
==
0
);
for
(
uint32_t
i_m
=
0
;
i_m
<
m
;
i_m
+=
uk
.
ThreadMr
)
{
{
assert
(
m
%
uk
.
ThreadMr
==
0
&&
n
==
uk
.
ThreadNr
);
if
constexpr
(
std
::
is_same
<
Row
,
ALayout
>::
value
)
FloatA
*
p_a
=
mat_a
;
float
*
p_c
=
current_mat_c
;
param
.
p_a
=
p_a
;
param
.
p_c
=
p_c
;
for
(
uint32_t
i_m
=
0
;
i_m
<
m
;
i_m
+=
uk
.
ThreadMr
)
{
{
uk
.
Run
(
&
param
);
param
.
p_a
=
mat_a
+
i_m
*
k
;
p_a
+=
uk
.
ThreadMr
*
k
;
p_c
+=
uk
.
ThreadMr
*
n
;
param
.
p_a
=
p_a
;
param
.
p_c
=
p_c
;
}
}
}
else
else
if
constexpr
(
std
::
is_same
<
Row
,
ALayout
>::
value
&&
std
::
is_same
<
Col
,
BLayout
>::
value
)
{
assert
(
m
%
uk
.
ThreadMr
==
0
&&
n
%
uk
.
ThreadNr
==
0
);
FloatA
*
p_a
=
mat_a
;
float
*
p_c
=
current_mat_c
;
param
.
p_a
=
p_a
;
param
.
p_b
=
mat_b
;
param
.
p_c
=
p_c
;
for
(
uint32_t
i_m
=
0
;
i_m
<
m
;
i_m
+=
uk
.
ThreadMr
)
{
{
float
*
p_c_n
=
p_c
;
param
.
p_a
=
mat_a
+
i_m
;
FloatB
*
p_b_n
=
mat_b
;
for
(
uint32_t
i_n
=
0
;
i_n
<
n
;
i_n
+=
uk
.
ThreadNr
)
{
uk
.
Run
(
&
param
);
p_b_n
+=
uk
.
ThreadNr
*
k
;
// ThreadNr/8*k*8
p_c_n
+=
uk
.
ThreadNr
;
param
.
p_b
=
p_b_n
;
param
.
p_c
=
p_c_n
;
}
p_a
+=
uk
.
ThreadMr
*
k
;
p_c
+=
uk
.
ThreadMr
*
n
;
param
.
p_a
=
p_a
;
param
.
p_b
=
mat_b
;
param
.
p_c
=
p_c
;
}
}
}
else
if
constexpr
(
std
::
is_same
<
Col
,
ALayout
>::
value
&&
std
::
is_same
<
Row
,
BLayout
>::
value
)
{
assert
(
m
==
uk
.
ThreadMr
&&
n
==
uk
.
ThreadNr
);
uk
.
Run
(
&
param
);
}
else
{
assert
(
m
%
uk
.
ThreadMr
==
0
&&
n
%
uk
.
ThreadNr
==
0
);
FloatB
*
p_b
=
mat_b
;
float
*
p_c
=
current_mat_c
;
param
.
p_b
=
p_b
;
param
.
p_c
=
p_c
;
for
(
uint32_t
i_n
=
0
;
i_n
<
n
;
i_n
+=
uk
.
ThreadNr
)
for
(
uint32_t
i_n
=
0
;
i_n
<
n
;
i_n
+=
uk
.
ThreadNr
)
{
{
if
constexpr
(
std
::
is_same
<
Row
,
BLayout
>::
value
)
{
param
.
p_b
=
mat_b
+
i_n
;
}
else
{
param
.
p_b
=
mat_b
+
i_n
*
k
;
}
param
.
p_c
=
current_mat_c
+
i_m
*
n
+
i_n
;
uk
.
Run
(
&
param
);
uk
.
Run
(
&
param
);
p_b
+=
uk
.
ThreadNr
*
k
;
// ThreadNr/8*k*8
p_c
+=
uk
.
ThreadNr
;
param
.
p_b
=
p_b
;
param
.
p_c
=
p_c
;
}
}
}
}
};
};
...
@@ -358,7 +320,11 @@ void test_ukernel(ukenrel_t uk,
...
@@ -358,7 +320,11 @@ void test_ukernel(ukenrel_t uk,
}
}
// implement small ukernel on L1
// implement small ukernel on L1
template
<
typename
FloatA
,
typename
FloatB
,
typename
ALayout
,
typename
BLayout
>
template
<
typename
FloatA
,
typename
FloatB
,
typename
ALayout
,
typename
BLayout
,
typename
thread_gemm_instance
>
void
test_cpu_ukernel
(
float
alpha
,
uint32_t
m
,
uint32_t
n
,
uint32_t
k
)
void
test_cpu_ukernel
(
float
alpha
,
uint32_t
m
,
uint32_t
n
,
uint32_t
k
)
{
{
int
max_threads
=
omp_get_max_threads
();
int
max_threads
=
omp_get_max_threads
();
...
@@ -382,17 +348,18 @@ void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k)
...
@@ -382,17 +348,18 @@ void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k)
k
);
k
);
// using thread_gemm_instance = thread_gemm_avx2_mxn_6x16_instances<ALayout, BLayout>;
// using thread_gemm_instance = thread_gemm_avx2_mxn_6x16_instances<ALayout, BLayout>;
using
thread_gemm_instance
=
thread_gemm_avx2_mxn_4x24_instances
<
ALayout
,
BLayout
>
;
//
using thread_gemm_instance = thread_gemm_avx2_mxn_4x24_instances<ALayout, BLayout>;
bool
found
=
false
;
bool
found
=
false
;
ck
::
static_for
<
0
,
std
::
tuple_size_v
<
thread_gemm_instance
>
,
1
>
{}([
&
](
auto
i
)
{
ck
::
static_for
<
0
,
std
::
tuple_size_v
<
thread_gemm_instance
>
,
1
>
{}([
&
](
auto
i
)
{
using
uk_type
=
std
::
tuple_element_t
<
i
,
thread_gemm_instance
>
;
using
uk_type
=
std
::
tuple_element_t
<
i
,
thread_gemm_instance
>
;
if
(
m
%
uk_type
::
ThreadMr
!=
0
||
n
%
uk_type
::
ThreadNr
!=
0
)
if
(
m
%
uk_type
::
ThreadMr
!=
0
||
n
%
uk_type
::
ThreadNr
!=
0
)
return
;
return
;
if
((
m
!=
uk_type
::
ThreadMr
&&
std
::
is_same
<
typename
uk_type
::
MatrixALayout
,
Col
>::
value
)
||
// if((m != uk_type::ThreadMr && std::is_same<typename uk_type::MatrixALayout, Col>::value)
(
n
!=
uk_type
::
ThreadNr
&&
std
::
is_same
<
typename
uk_type
::
MatrixBLayout
,
Row
>::
value
))
// ||
// only k is the fast changing dim of A/B can we do muldiplt m, n
// (n != uk_type::ThreadNr && std::is_same<typename uk_type::MatrixBLayout, Row>::value))
return
;
// // only k is the fast changing dim of A/B can we do muldiplt m, n
// return;
if
(
found
)
if
(
found
)
return
;
return
;
...
@@ -435,8 +402,21 @@ int main(int argc, char** argv)
...
@@ -435,8 +402,21 @@ int main(int argc, char** argv)
omp_set_num_threads
(
1
);
omp_set_num_threads
(
1
);
printf
(
"max threads:%d
\n
"
,
omp_get_max_threads
());
printf
(
"max threads:%d
\n
"
,
omp_get_max_threads
());
test_cpu_ukernel
<
AType
,
BType
,
Row
,
Row
>
(
alpha
,
m
,
n
,
k
);
test_cpu_ukernel
<
AType
,
BType
,
Row
,
Row
,
thread_gemm_avx2_mxn_4x24_instances
<
Row
,
Row
>>
(
test_cpu_ukernel
<
AType
,
BType
,
Row
,
Col
>
(
alpha
,
m
,
n
,
k
);
alpha
,
m
,
n
,
k
);
test_cpu_ukernel
<
AType
,
BType
,
Col
,
Row
>
(
alpha
,
m
,
n
,
k
);
test_cpu_ukernel
<
AType
,
BType
,
Row
,
Col
,
thread_gemm_avx2_mxn_4x24_instances
<
Row
,
Col
>>
(
test_cpu_ukernel
<
AType
,
BType
,
Col
,
Col
>
(
alpha
,
m
,
n
,
k
);
alpha
,
m
,
n
,
k
);
test_cpu_ukernel
<
AType
,
BType
,
Col
,
Row
,
thread_gemm_avx2_mxn_4x24_instances
<
Col
,
Row
>>
(
alpha
,
m
,
n
,
k
);
test_cpu_ukernel
<
AType
,
BType
,
Col
,
Col
,
thread_gemm_avx2_mxn_4x24_instances
<
Col
,
Col
>>
(
alpha
,
m
,
n
,
k
);
test_cpu_ukernel
<
AType
,
BType
,
Row
,
Row
,
thread_gemm_avx2_mxn_6x16_instances
<
Row
,
Row
>>
(
alpha
,
m
,
n
,
k
);
test_cpu_ukernel
<
AType
,
BType
,
Row
,
Col
,
thread_gemm_avx2_mxn_6x16_instances
<
Row
,
Col
>>
(
alpha
,
m
,
n
,
k
);
test_cpu_ukernel
<
AType
,
BType
,
Col
,
Row
,
thread_gemm_avx2_mxn_6x16_instances
<
Col
,
Row
>>
(
alpha
,
m
,
n
,
k
);
test_cpu_ukernel
<
AType
,
BType
,
Col
,
Col
,
thread_gemm_avx2_mxn_6x16_instances
<
Col
,
Col
>>
(
alpha
,
m
,
n
,
k
);
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment