Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ad09ebdb
Commit
ad09ebdb
authored
May 17, 2022
by
carlushuang
Browse files
add kyxck8
parent
d6d37ea9
Changes
7
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
2335 additions
and
927 deletions
+2335
-927
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp
...tion/cpu/device/device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp
+919
-919
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxck8_nhwk.hpp
...on/cpu/device/device_convnd_fwd_avx2_nhwc_kyxck8_nhwk.hpp
+899
-0
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp
.../threadwise_tensor_slice_transfer_avx2_specialization.hpp
+194
-6
library/src/tensor_operation_instance/cpu/conv2d_fwd/CMakeLists.txt
...c/tensor_operation_instance/cpu/conv2d_fwd/CMakeLists.txt
+1
-0
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_instance.cpp
..._fwd/device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_instance.cpp
+206
-0
profiler/include/profile_conv_fwd_cpu_impl.hpp
profiler/include/profile_conv_fwd_cpu_impl.hpp
+18
-0
test/convnd_fwd_cpu/conv2d_fwd_cpu.cpp
test/convnd_fwd_cpu/conv2d_fwd_cpu.cpp
+98
-2
No files found.
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp
View file @
ad09ebdb
...
@@ -569,7 +569,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -569,7 +569,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
GemmKSpecialization
>
;
GemmKSpecialization
>
;
using
BThreadwiseCopy
=
using
BThreadwiseCopy
=
ck
::
cpu
::
ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_
NHW
C
<
ck
::
cpu
::
ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_
KYX
C
<
WeiDataType
,
WeiDataType
,
WeiDataType
,
WeiDataType
,
BGridDesc
,
BGridDesc
,
...
...
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxck8_nhwk.hpp
0 → 100644
View file @
ad09ebdb
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp
View file @
ad09ebdb
...
@@ -484,8 +484,10 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
...
@@ -484,8 +484,10 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
while
(
i_m_itr
>
0
)
while
(
i_m_itr
>
0
)
{
{
if
((
*
reinterpret_cast
<
uint32_t
*>
(
&
i_hi_itr
)
<
Hi
)
&&
if
((
*
reinterpret_cast
<
uint32_t
*>
(
&
i_hi_itr
)
<
(
*
reinterpret_cast
<
uint32_t
*>
(
&
i_wi_itr
)
<
Wi
))
*
reinterpret_cast
<
uint32_t
*>
(
&
Hi
))
&&
(
*
reinterpret_cast
<
uint32_t
*>
(
&
i_wi_itr
)
<
*
reinterpret_cast
<
uint32_t
*>
(
&
Wi
)))
avx2_util
::
memcpy32_avx2
(
p_dst
,
p_src
,
k_per_block
,
element_op_
);
avx2_util
::
memcpy32_avx2
(
p_dst
,
p_src
,
k_per_block
,
element_op_
);
else
else
avx2_util
::
memset32_avx2
(
p_dst
,
0
,
k_per_block
);
avx2_util
::
memset32_avx2
(
p_dst
,
0
,
k_per_block
);
...
@@ -543,8 +545,10 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
...
@@ -543,8 +545,10 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
// printf("current_k_block_along_c:%d, i_c_itr_k:%d, k_per_block:%d\n",
// printf("current_k_block_along_c:%d, i_c_itr_k:%d, k_per_block:%d\n",
// current_k_block_along_c, i_c_itr_k,k_per_block); fflush(stdout);
// current_k_block_along_c, i_c_itr_k,k_per_block); fflush(stdout);
if
((
*
reinterpret_cast
<
uint32_t
*>
(
&
i_hi_itr_k
)
<
Hi
)
&&
if
((
*
reinterpret_cast
<
uint32_t
*>
(
&
i_hi_itr_k
)
<
(
*
reinterpret_cast
<
uint32_t
*>
(
&
i_wi_itr_k
)
<
Wi
))
*
reinterpret_cast
<
uint32_t
*>
(
&
Hi
))
&&
(
*
reinterpret_cast
<
uint32_t
*>
(
&
i_wi_itr_k
)
<
*
reinterpret_cast
<
uint32_t
*>
(
&
Wi
)))
avx2_util
::
memcpy32_avx2
(
avx2_util
::
memcpy32_avx2
(
p_dst_k
,
p_src_k
,
current_k_block_along_c
,
element_op_
);
p_dst_k
,
p_src_k
,
current_k_block_along_c
,
element_op_
);
else
else
...
@@ -715,7 +719,7 @@ template <typename SrcData,
...
@@ -715,7 +719,7 @@ template <typename SrcData,
bool
BypassTransfer
,
bool
BypassTransfer
,
ConvolutionForwardSpecialization_t
ConvForwardSpecialization
,
ConvolutionForwardSpecialization_t
ConvForwardSpecialization
,
ConvolutionForwardGemmKSpecialization_t
GemmKSpecialization
>
ConvolutionForwardGemmKSpecialization_t
GemmKSpecialization
>
struct
ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_
NHW
C
struct
ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_
KYX
C
{
{
static
constexpr
ck
::
index_t
nDim
=
SrcDesc
::
GetNumOfDimension
();
static
constexpr
ck
::
index_t
nDim
=
SrcDesc
::
GetNumOfDimension
();
using
Index
=
MultiIndex
<
nDim
>
;
using
Index
=
MultiIndex
<
nDim
>
;
...
@@ -723,7 +727,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC
...
@@ -723,7 +727,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC
// using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
// using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
// using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
// using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
constexpr
ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_
NHW
C
(
constexpr
ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_
KYX
C
(
const
SrcDesc
&
src_desc
,
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin
,
const
Index
&
src_slice_origin
,
const
DstDesc
&
dst_desc
,
const
DstDesc
&
dst_desc
,
...
@@ -927,6 +931,190 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC
...
@@ -927,6 +931,190 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC
intptr_t
src_offset
;
intptr_t
src_offset
;
};
};
template
<
typename
SrcData
,
typename
DstData
,
typename
SrcDesc
,
typename
DstDesc
,
typename
ElementwiseOperation
,
bool
BypassTransfer
,
ConvolutionForwardSpecialization_t
ConvForwardSpecialization
,
ConvolutionForwardGemmKSpecialization_t
GemmKSpecialization
>
struct
ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8
{
static
constexpr
ck
::
index_t
nDim
=
SrcDesc
::
GetNumOfDimension
();
using
Index
=
MultiIndex
<
nDim
>
;
constexpr
ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin
,
const
DstDesc
&
dst_desc
,
const
Index
&
dst_slice_origin
,
const
ElementwiseOperation
&
element_op
)
:
element_op_
(
element_op
)
{
GemmN1
=
src_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
2
>
{}];
// Need to be 8
GemmN
=
src_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
0
>
{}];
GemmK
=
src_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}];
}
void
SetSrcSliceOrigin
(
const
SrcDesc
&
,
const
Index
&
src_slice_origin_idx
)
{
ck
::
index_t
idx_n0
=
src_slice_origin_idx
[
Number
<
0
>
{}];
ck
::
index_t
idx_k
=
src_slice_origin_idx
[
Number
<
1
>
{}];
ck
::
index_t
idx_n1
=
src_slice_origin_idx
[
Number
<
2
>
{}];
src_offset
=
idx_n0
*
GemmK
*
GemmN1
+
idx_k
*
GemmN1
+
idx_n1
;
// printf("xxxx i_gemm_n:%d, i_gemm_k:%d, src_offset:%d\n", i_gemm_n, i_gemm_k,
// src_offset);
}
void
SetDstSliceOrigin
(
const
DstDesc
&
,
const
Index
&
)
{}
template
<
typename
SrcBuffer
,
typename
DstBuffer
,
typename
SliceLengths
>
void
RunRead
(
const
SrcDesc
&
,
const
SrcBuffer
&
src_buf
,
const
DstDesc
&
dst_desc
,
DstBuffer
&
dst_buf
,
const
SliceLengths
&
slice_length
)
{
if
constexpr
(
BypassTransfer
)
{}
else
{
const
ck
::
index_t
n0_per_block
=
slice_length
[
Number
<
0
>
{}];
const
ck
::
index_t
k_n1_per_block
=
slice_length
[
Number
<
1
>
{}]
*
slice_length
[
Number
<
2
>
{}];
const
ck
::
index_t
SrcStride_K_N1
=
GemmK
*
slice_length
[
Number
<
2
>
{}];
// printf(" >>>> %d, %d, %d -> %d(%dx%d), %d\n", GemmN, GemmK, GemmN1, n_per_block,
// dst_desc.GetTransforms()[Number<0>{}]
// .GetUpperLengths()[Number<0>{}],
// dst_desc.GetTransforms()[Number<0>{}]
// .GetUpperLengths()[Number<2>{}],
// k_per_block);
const
float
*
p_src
=
reinterpret_cast
<
const
float
*>
(
src_buf
.
p_data_
)
+
src_offset
;
float
*
p_dst
=
reinterpret_cast
<
float
*>
(
dst_buf
.
p_data_
);
// n0 * k * n1
index_t
i_n0_itr
=
n0_per_block
;
while
(
i_n0_itr
>=
8
)
{
avx2_util
::
memcpy32_avx2
(
p_dst
+
0
*
k_n1_per_block
,
p_src
+
0
*
SrcStride_K_N1
,
k_n1_per_block
,
element_op_
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
1
*
k_n1_per_block
,
p_src
+
1
*
SrcStride_K_N1
,
k_n1_per_block
,
element_op_
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
2
*
k_n1_per_block
,
p_src
+
2
*
SrcStride_K_N1
,
k_n1_per_block
,
element_op_
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
3
*
k_n1_per_block
,
p_src
+
3
*
SrcStride_K_N1
,
k_n1_per_block
,
element_op_
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
4
*
k_n1_per_block
,
p_src
+
4
*
SrcStride_K_N1
,
k_n1_per_block
,
element_op_
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
5
*
k_n1_per_block
,
p_src
+
5
*
SrcStride_K_N1
,
k_n1_per_block
,
element_op_
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
6
*
k_n1_per_block
,
p_src
+
6
*
SrcStride_K_N1
,
k_n1_per_block
,
element_op_
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
7
*
k_n1_per_block
,
p_src
+
7
*
SrcStride_K_N1
,
k_n1_per_block
,
element_op_
);
i_n0_itr
-=
8
;
p_dst
+=
8
*
k_n1_per_block
;
p_src
+=
8
*
SrcStride_K_N1
;
}
if
(
i_n0_itr
&
4
)
{
avx2_util
::
memcpy32_avx2
(
p_dst
+
0
*
k_n1_per_block
,
p_src
+
0
*
SrcStride_K_N1
,
k_n1_per_block
,
element_op_
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
1
*
k_n1_per_block
,
p_src
+
1
*
SrcStride_K_N1
,
k_n1_per_block
,
element_op_
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
2
*
k_n1_per_block
,
p_src
+
2
*
SrcStride_K_N1
,
k_n1_per_block
,
element_op_
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
3
*
k_n1_per_block
,
p_src
+
3
*
SrcStride_K_N1
,
k_n1_per_block
,
element_op_
);
p_dst
+=
4
*
k_n1_per_block
;
p_src
+=
4
*
SrcStride_K_N1
;
}
if
(
i_n0_itr
&
2
)
{
avx2_util
::
memcpy32_avx2
(
p_dst
+
0
*
k_n1_per_block
,
p_src
+
0
*
SrcStride_K_N1
,
k_n1_per_block
,
element_op_
);
avx2_util
::
memcpy32_avx2
(
p_dst
+
1
*
k_n1_per_block
,
p_src
+
1
*
SrcStride_K_N1
,
k_n1_per_block
,
element_op_
);
p_dst
+=
2
*
k_n1_per_block
;
p_src
+=
2
*
SrcStride_K_N1
;
}
if
(
i_n0_itr
&
1
)
{
avx2_util
::
memcpy32_avx2
(
p_dst
+
0
*
k_n1_per_block
,
p_src
+
0
*
SrcStride_K_N1
,
k_n1_per_block
,
element_op_
);
}
}
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
void
MoveSrcSliceWindow
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin_step_idx
)
{
ck
::
index_t
move_n0
=
src_slice_origin_step_idx
[
Number
<
0
>
{}];
ck
::
index_t
move_k
=
src_slice_origin_step_idx
[
Number
<
1
>
{}];
ck
::
index_t
move_n1
=
src_slice_origin_step_idx
[
Number
<
2
>
{}];
// i_gemm_k += move_k;
// printf("wei move:%d\n", move_k); fflush(stdout);
src_offset
+=
move_n0
*
GemmK
*
GemmN1
+
move_k
*
GemmN1
+
move_n1
;
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
void
MoveDstSliceWindow
(
const
DstDesc
&
,
const
Index
&
)
{}
private:
const
ElementwiseOperation
element_op_
;
ck
::
index_t
i_gemm_n
;
// ck::index_t i_gemm_k;
// ck::index_t GemmN0;
ck
::
index_t
GemmN1
;
ck
::
index_t
GemmN
;
ck
::
index_t
GemmK
;
intptr_t
src_offset
;
};
template
<
typename
SrcData
,
template
<
typename
SrcData
,
typename
DstData
,
typename
DstData
,
typename
SrcDesc
,
typename
SrcDesc
,
...
...
library/src/tensor_operation_instance/cpu/conv2d_fwd/CMakeLists.txt
View file @
ad09ebdb
# device_conv2d_fwd_cpu_instance
# device_conv2d_fwd_cpu_instance
set
(
DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE
set
(
DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE
device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp
device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp
device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_instance.cpp
)
)
add_library
(
device_conv2d_fwd_cpu_instance SHARED
${
DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE
}
)
add_library
(
device_conv2d_fwd_cpu_instance SHARED
${
DEVICE_CONV2D_FWD_CPU_INSTANCE_SOURCE
}
)
target_compile_features
(
device_conv2d_fwd_cpu_instance PUBLIC
)
target_compile_features
(
device_conv2d_fwd_cpu_instance PUBLIC
)
...
...
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_instance.cpp
0 → 100644
View file @
ad09ebdb
#include <stdlib.h>
#include "convolution_forward_specialization_cpu.hpp"
#include "config.hpp"
#include "device_convnd_fwd_avx2_nhwc_kyxck8_nhwk.hpp"
#include "element_wise_operation_cpu.hpp"
#include "device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
cpu
{
namespace
device
{
namespace
device_conv2d_fwd_avx2_instance
{
using
InType
=
float
;
using
WeiType
=
float
;
using
OutType
=
float
;
using
AccType
=
float
;
using
InLayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
// NHWC
using
WeiLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
// KYXCK8
static
constexpr
bool
NonTemporalStore
=
false
;
using
PT
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
;
using
Relu
=
ck
::
tensor_operation
::
cpu
::
element_wise
::
Relu
;
static
constexpr
auto
ConvFwdDefault
=
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardSpecialization_t
::
Default
;
static
constexpr
auto
ConvFwd1x1P0
=
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardSpecialization_t
::
Filter1x1Pad0
;
static
constexpr
auto
ConvFwd1x1S1P0
=
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
;
static
constexpr
auto
DefaultGemmKLoop
=
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardGemmKSpecialization_t
::
DefaultGemmKLoop
;
static
constexpr
auto
GemmKLoopOverC
=
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardGemmKSpecialization_t
::
NHWC_GemmKLoopOverC
;
static
constexpr
auto
LoopOver_MNK
=
ck
::
tensor_operation
::
cpu
::
device
::
LoopOver_MNK
;
static
constexpr
auto
LoopOver_MKN
=
ck
::
tensor_operation
::
cpu
::
device
::
LoopOver_MKN
;
// clang-format off
#define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf) \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
GemmKLoopOverC
,
LoopOver_MNK
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
a_local_buf
,
b_local_buf
,
c_local_buf
>
,
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
GemmKLoopOverC
,
LoopOver_MNK
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
a_local_buf
,
b_local_buf
,
c_local_buf
>
,
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
DefaultGemmKLoop
,
LoopOver_MNK
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
a_local_buf
,
b_local_buf
,
c_local_buf
>
,
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
DefaultGemmKLoop
,
LoopOver_MNK
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
a_local_buf
,
b_local_buf
,
c_local_buf
>
,
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
GemmKLoopOverC
,
LoopOver_MKN
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
a_local_buf
,
b_local_buf
,
c_local_buf
>
,
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
GemmKLoopOverC
,
LoopOver_MKN
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
a_local_buf
,
b_local_buf
,
c_local_buf
>
,
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
DefaultGemmKLoop
,
LoopOver_MKN
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
a_local_buf
,
b_local_buf
,
c_local_buf
>
,
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
DefaultGemmKLoop
,
LoopOver_MKN
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
a_local_buf
,
b_local_buf
,
c_local_buf
>
// clang-format on
using
device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_instances
=
std
::
tuple
<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
256
,
128
,
64
,
6
,
16
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
false
)
>
;
// clang-format on
// use this in single thread, but gemm_n is not multiple of 8
using
device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_instances
=
std
::
tuple
<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
256
,
128
,
64
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
true
)
>
;
// clang-format on
// use this in multi thread environment (need local C buffer to avoid cache coherence, although some
// time no local c is better...)
using
device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_instances
=
std
::
tuple
<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
48
,
24
,
128
,
4
,
24
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
72
,
16
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
72
,
32
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
96
,
32
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
96
,
64
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
120
,
32
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
120
,
64
,
128
,
6
,
16
,
true
,
true
,
true
),
// DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
true
)
>
;
// clang-format on
using
device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_relu_instances
=
std
::
tuple
<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
256
,
128
,
64
,
6
,
16
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
false
)
>
;
// clang-format on
// use this in single thread, but gemm_n is not multiple of 8
using
device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_relu_instances
=
std
::
tuple
<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
256
,
128
,
64
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
true
)
>
;
// clang-format on
// use this in multi thread environment (need local C buffer to avoid cache coherence, although some
// time no local c is better...)
using
device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_relu_instances
=
std
::
tuple
<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
48
,
24
,
128
,
4
,
24
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
72
,
16
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
72
,
32
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
96
,
32
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
96
,
64
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
120
,
32
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
120
,
64
,
128
,
6
,
16
,
true
,
true
,
true
),
// DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
true
)
>
;
// clang-format on
void
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
PT
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_instances
{});
}
void
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
PT
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_instances
{});
}
void
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
PT
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_instances
{});
}
void
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_relu
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
Relu
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_relu_instances
{});
}
void
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c_relu
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
Relu
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_relu_instances
{});
}
void
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt_relu
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
Relu
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_relu_instances
{});
}
}
// namespace device_conv2d_fwd_avx2_instance
}
// namespace device
}
// namespace cpu
}
// namespace tensor_operation
}
// namespace ck
profiler/include/profile_conv_fwd_cpu_impl.hpp
View file @
ad09ebdb
...
@@ -33,6 +33,24 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(
...
@@ -33,6 +33,24 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(
void
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu
(
void
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu
(
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
Relu
>>&
instances
);
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
Relu
>>&
instances
);
void
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk
(
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
);
void
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
);
void
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt
(
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
);
void
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_relu
(
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
Relu
>>&
instances
);
void
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c_relu
(
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
Relu
>>&
instances
);
void
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt_relu
(
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
Relu
>>&
instances
);
}
// namespace device_conv2d_fwd_avx2_instance
}
// namespace device_conv2d_fwd_avx2_instance
}
// namespace device
}
// namespace device
}
// namespace cpu
}
// namespace cpu
...
...
test/convnd_fwd_cpu/conv2d_fwd_cpu.cpp
View file @
ad09ebdb
...
@@ -16,7 +16,11 @@
...
@@ -16,7 +16,11 @@
#define TEST_FUSION_PASSTHROUGH 0
#define TEST_FUSION_PASSTHROUGH 0
#define TEST_FUSION_RELU 1
#define TEST_FUSION_RELU 1
#define TEST_FUSION TEST_FUSION_RELU
#define TEST_FUSION TEST_FUSION_PASSTHROUGH
#define TEST_LAYOUT_NHWC_KYXC_NHWK 0
#define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1
#define TEST_LAYOUT TEST_LAYOUT_NHWC_KYXCK8_NHWK
using
F32
=
float
;
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
...
@@ -48,6 +52,24 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(
...
@@ -48,6 +52,24 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(
void
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu
(
void
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu
(
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
Relu
>>&
instances
);
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
Relu
>>&
instances
);
void
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk
(
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
);
void
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
);
void
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt
(
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
);
void
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_relu
(
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
Relu
>>&
instances
);
void
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c_relu
(
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
Relu
>>&
instances
);
void
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt_relu
(
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
Relu
>>&
instances
);
}
// namespace device_conv2d_fwd_avx2_instance
}
// namespace device_conv2d_fwd_avx2_instance
}
// namespace device
}
// namespace device
}
// namespace cpu
}
// namespace cpu
...
@@ -115,6 +137,31 @@ check_out(const Tensor<T>& ref, const Tensor<T>& result, double nrms, int per_pi
...
@@ -115,6 +137,31 @@ check_out(const Tensor<T>& ref, const Tensor<T>& result, double nrms, int per_pi
float
calculate_gflops
()
{}
float
calculate_gflops
()
{}
template
<
typename
T
>
void
transpose_kyxc_2_kyxc8k
(
Tensor
<
T
>&
dst
,
const
Tensor
<
T
>&
src
,
ck
::
index_t
K
,
ck
::
index_t
Y
,
ck
::
index_t
X
,
ck
::
index_t
C
)
{
ck
::
index_t
batch
=
K
/
8
;
ck
::
index_t
row
=
8
;
ck
::
index_t
col
=
C
*
Y
*
X
;
for
(
auto
i_b
=
0
;
i_b
<
batch
;
i_b
++
)
{
for
(
auto
i_r
=
0
;
i_r
<
row
;
i_r
++
)
{
for
(
auto
i_c
=
0
;
i_c
<
col
;
i_c
++
)
{
ck
::
index_t
src_idx
=
i_b
*
row
*
col
+
i_r
*
col
+
i_c
;
ck
::
index_t
dst_idx
=
i_b
*
col
*
row
+
i_c
*
row
+
i_r
;
dst
.
mData
[
dst_idx
]
=
src
.
mData
[
src_idx
];
}
}
}
}
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
int
data_type
=
0
;
int
data_type
=
0
;
...
@@ -213,6 +260,10 @@ int main(int argc, char* argv[])
...
@@ -213,6 +260,10 @@ int main(int argc, char* argv[])
Tensor
<
InDataType
>
in_n_c_hi_wi
(
f_host_tensor_descriptor
(
N
,
C
,
Hi
,
Wi
));
Tensor
<
InDataType
>
in_n_c_hi_wi
(
f_host_tensor_descriptor
(
N
,
C
,
Hi
,
Wi
));
Tensor
<
WeiDataType
>
wei_k_c_y_x
(
f_host_tensor_descriptor
(
K
,
C
,
Y
,
X
));
Tensor
<
WeiDataType
>
wei_k_c_y_x
(
f_host_tensor_descriptor
(
K
,
C
,
Y
,
X
));
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
Tensor
<
WeiDataType
>
wei_k_c_y_x_k8
(
f_host_tensor_descriptor
(
K
,
C
,
Y
,
X
));
// TODO: This is only to hold data
#endif
Tensor
<
OutDataType
>
out_n_k_ho_wo_host_result
(
f_host_tensor_descriptor
(
N
,
K
,
Ho
,
Wo
));
Tensor
<
OutDataType
>
out_n_k_ho_wo_host_result
(
f_host_tensor_descriptor
(
N
,
K
,
Ho
,
Wo
));
Tensor
<
OutDataType
>
out_n_k_ho_wo_device_result
(
f_host_tensor_descriptor
(
N
,
K
,
Ho
,
Wo
));
Tensor
<
OutDataType
>
out_n_k_ho_wo_device_result
(
f_host_tensor_descriptor
(
N
,
K
,
Ho
,
Wo
));
...
@@ -296,8 +347,13 @@ int main(int argc, char* argv[])
...
@@ -296,8 +347,13 @@ int main(int argc, char* argv[])
AVX2_DATA_ALIGNMENT
);
AVX2_DATA_ALIGNMENT
);
in_device_buf
.
ToDevice
(
in_n_c_hi_wi
.
mData
.
data
());
in_device_buf
.
ToDevice
(
in_n_c_hi_wi
.
mData
.
data
());
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXC_NHWK
wei_device_buf
.
ToDevice
(
wei_k_c_y_x
.
mData
.
data
());
wei_device_buf
.
ToDevice
(
wei_k_c_y_x
.
mData
.
data
());
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
transpose_kyxc_2_kyxc8k
(
wei_k_c_y_x_k8
,
wei_k_c_y_x
,
K
,
Y
,
X
,
C
);
wei_device_buf
.
ToDevice
(
wei_k_c_y_x_k8
.
mData
.
data
());
#endif
// get host result
// get host result
{
{
auto
ref_conv
=
ReferenceConvFwdInstance
{};
auto
ref_conv
=
ReferenceConvFwdInstance
{};
...
@@ -334,6 +390,7 @@ int main(int argc, char* argv[])
...
@@ -334,6 +390,7 @@ int main(int argc, char* argv[])
ck
::
is_same_v
<
ck
::
remove_cv_t
<
WeiDataType
>
,
float
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
WeiDataType
>
,
float
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
OutDataType
>
,
float
>
)
ck
::
is_same_v
<
ck
::
remove_cv_t
<
OutDataType
>
,
float
>
)
{
{
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXC_NHWK
#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
if
(
omp_get_max_threads
()
>
1
)
if
(
omp_get_max_threads
()
>
1
)
{
{
...
@@ -369,6 +426,45 @@ int main(int argc, char* argv[])
...
@@ -369,6 +426,45 @@ int main(int argc, char* argv[])
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_avx2_instance
::
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_avx2_instance
::
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu
(
conv_ptrs
);
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu
(
conv_ptrs
);
}
}
#endif
#endif
#if TEST_LAYOUT == TEST_LAYOUT_NHWC_KYXCK8_NHWK
#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
if
(
omp_get_max_threads
()
>
1
)
{
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_avx2_instance
::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt
(
conv_ptrs
);
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_avx2_instance
::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk
(
conv_ptrs
);
}
else
{
if
(
K
%
8
==
0
)
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_avx2_instance
::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk
(
conv_ptrs
);
else
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_avx2_instance
::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c
(
conv_ptrs
);
}
#endif
#if TEST_FUSION == TEST_FUSION_RELU
if
(
omp_get_max_threads
()
>
1
)
{
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_avx2_instance
::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt_relu
(
conv_ptrs
);
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_avx2_instance
::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_relu
(
conv_ptrs
);
}
else
{
if
(
K
%
8
==
0
)
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_avx2_instance
::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_relu
(
conv_ptrs
);
else
ck
::
tensor_operation
::
cpu
::
device
::
device_conv2d_fwd_avx2_instance
::
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c_relu
(
conv_ptrs
);
}
#endif
#endif
#endif
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment