Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
9cefc261
Commit
9cefc261
authored
Jun 13, 2022
by
carlushuang
Browse files
refactor device instance to use less template, more dynamic tunable
parent
6dfb4e78
Changes
18
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
1006 additions
and
1016 deletions
+1006
-1016
example/cpu_01_conv2d_fwd/cpu_conv2d_fwd.cpp
example/cpu_01_conv2d_fwd/cpu_conv2d_fwd.cpp
+1
-1
example/cpu_02_conv2d_fwd_bias_relu_add/cpu_conv2d_fwd_bias_relu_add.cpp
...conv2d_fwd_bias_relu_add/cpu_conv2d_fwd_bias_relu_add.cpp
+1
-1
include/ck/tensor_operation/cpu/block/blockwise_gemm_avx2.hpp
...ude/ck/tensor_operation/cpu/block/blockwise_gemm_avx2.hpp
+5
-4
include/ck/tensor_operation/cpu/device/device_conv_fwd_cpu.hpp
...de/ck/tensor_operation/cpu/device/device_conv_fwd_cpu.hpp
+19
-0
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp
...tion/cpu/device/device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp
+61
-63
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxck8_nhwk.hpp
...on/cpu/device/device_convnd_fwd_avx2_nhwc_kyxck8_nhwk.hpp
+28
-30
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_yxck_nhwk.hpp
...tion/cpu/device/device_convnd_fwd_avx2_nhwc_yxck_nhwk.hpp
+24
-28
include/ck/tensor_operation/cpu/device/device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk.hpp
...ce_convnd_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk.hpp
+55
-63
include/ck/tensor_operation/cpu/device/device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk.hpp
..._convnd_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk.hpp
+55
-63
include/ck/tensor_operation/cpu/device/device_convnd_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk.hpp
...ce_convnd_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk.hpp
+55
-60
include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp
include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp
+37
-25
include/ck/tensor_operation/cpu/grid/gridwise_gemm_bias_activation_add_avx2.hpp
...ation/cpu/grid/gridwise_gemm_bias_activation_add_avx2.hpp
+41
-29
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp
...2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp
+134
-137
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_instance.cpp
..._fwd/device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_instance.cpp
+137
-144
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_yxck_nhwk_instance.cpp
...2d_fwd/device_conv2d_fwd_avx2_nhwc_yxck_nhwk_instance.cpp
+139
-144
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_bias_activation_add_avx2_nhwc_kyxc_nhwk_instance.cpp
...nv2d_bias_activation_add_avx2_nhwc_kyxc_nhwk_instance.cpp
+72
-73
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_bias_activation_add_avx2_nhwc_kyxck8_nhwk_instance.cpp
...2d_bias_activation_add_avx2_nhwc_kyxck8_nhwk_instance.cpp
+71
-76
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_bias_activation_add_avx2_nhwc_yxck_nhwk_instance.cpp
...nv2d_bias_activation_add_avx2_nhwc_yxck_nhwk_instance.cpp
+71
-75
No files found.
example/cpu_01_conv2d_fwd/cpu_conv2d_fwd.cpp
View file @
9cefc261
...
@@ -21,7 +21,7 @@
...
@@ -21,7 +21,7 @@
#define TEST_LAYOUT_NHWC_KYXC_NHWK 0
#define TEST_LAYOUT_NHWC_KYXC_NHWK 0
#define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1
#define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1
#define TEST_LAYOUT_NHWC_YXCK_NHWK 2
#define TEST_LAYOUT_NHWC_YXCK_NHWK 2
#define TEST_LAYOUT TEST_LAYOUT_NHWC_YXCK_NHWK
#define TEST_LAYOUT TEST_LAYOUT_NHWC_
K
YXCK
8
_NHWK
using
F32
=
float
;
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
...
...
example/cpu_02_conv2d_fwd_bias_relu_add/cpu_conv2d_fwd_bias_relu_add.cpp
View file @
9cefc261
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
#define TEST_LAYOUT_NHWC_KYXC_NHWK 0
#define TEST_LAYOUT_NHWC_KYXC_NHWK 0
#define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1
#define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1
#define TEST_LAYOUT_NHWC_YXCK_NHWK 2
#define TEST_LAYOUT_NHWC_YXCK_NHWK 2
#define TEST_LAYOUT TEST_LAYOUT_NHWC_YXCK_NHWK
#define TEST_LAYOUT TEST_LAYOUT_NHWC_
K
YXCK
8
_NHWK
using
F32
=
float
;
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
...
...
include/ck/tensor_operation/cpu/block/blockwise_gemm_avx2.hpp
View file @
9cefc261
...
@@ -18,8 +18,6 @@ template <typename FloatA,
...
@@ -18,8 +18,6 @@ template <typename FloatA,
typename
BBlockDesc
,
typename
BBlockDesc
,
typename
CDesc
,
typename
CDesc
,
ck
::
index_t
KPerBlock
,
typename
ThreadwiseGemm_Dispatch
,
typename
ThreadwiseGemm_Dispatch
,
typename
ThreadMNAccessOrder
// how we acces gemm MN to utilize micro kernel
typename
ThreadMNAccessOrder
// how we acces gemm MN to utilize micro kernel
>
>
...
@@ -83,8 +81,11 @@ struct BlockwiseGemmAvx2_MxN
...
@@ -83,8 +81,11 @@ struct BlockwiseGemmAvx2_MxN
else
else
{
{
// N/8 * K * 8
// N/8 * K * 8
return
b_block_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}]
*
// return b_block_desc.GetTransforms()[Number<BBlockDesc::GetNumOfTransform() -
b_block_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
2
>
{}];
// 1>{}].GetUpperLengths()[Number<1>{}] *
// b_block_desc.GetTransforms()[Number<BBlockDesc::GetNumOfTransform() -
// 1>{}].GetUpperLengths()[Number<2>{}];
return
b_block_desc
.
GetLength
(
Number
<
1
>
{})
*
b_block_desc
.
GetLength
(
Number
<
2
>
{});
}
}
}
}
...
...
include/ck/tensor_operation/cpu/device/device_conv_fwd_cpu.hpp
View file @
9cefc261
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include <iostream>
#include <iostream>
#include "device_base_cpu.hpp"
#include "device_base_cpu.hpp"
#include "convolution_forward_specialization_cpu.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -77,6 +78,24 @@ using DeviceConvFwdBiasActivationAddPtr =
...
@@ -77,6 +78,24 @@ using DeviceConvFwdBiasActivationAddPtr =
WeiElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>>
;
OutElementwiseOperation
>>
;
struct
DeviceConvFwdDynamicTunable
{
ck
::
index_t
m_per_block
;
ck
::
index_t
n_per_block
;
ck
::
index_t
k_per_block
;
// ck::index_t m_per_thread;
// ck::index_t n_per_thread;
// bool use_a_local_buffer;
// bool use_b_local_buffer;
// bool use_c_local_buffer;
// ConvolutionForwardSpecialization_t forward_spec;
// ConvolutionForwardGemmKSpecialization_t gemm_k_spec;
ConvolutionForwardBlockLoopOverSpecialization_t
loop_over_spec
;
};
}
// namespace device
}
// namespace device
}
// namespace cpu
}
// namespace cpu
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp
View file @
9cefc261
...
@@ -30,11 +30,7 @@ template <typename InDataType,
...
@@ -30,11 +30,7 @@ template <typename InDataType,
typename
OutElementwiseOperation
,
typename
OutElementwiseOperation
,
ConvolutionForwardSpecialization_t
ConvForwardSpecialization
,
ConvolutionForwardSpecialization_t
ConvForwardSpecialization
,
ConvolutionForwardGemmKSpecialization_t
GemmKSpecialization
,
ConvolutionForwardGemmKSpecialization_t
GemmKSpecialization
,
ConvolutionForwardBlockLoopOverSpecialization_t
BlockLoopOverSpecialization
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
MPerBlock
,
// block means data are designed to fit in cache (L1/L2/L3)
ck
::
index_t
NPerBlock
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
MPerThread
,
ck
::
index_t
MPerThread
,
ck
::
index_t
NPerThread
,
ck
::
index_t
NPerThread
,
bool
UseALocalBuffer
,
bool
UseALocalBuffer
,
...
@@ -65,17 +61,12 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -65,17 +61,12 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
static
constexpr
bool
NonTemporalStore
=
false
;
static
constexpr
bool
NonTemporalStore
=
false
;
static
constexpr
auto
GetBlockMNKAccessOrder
()
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
(
const
DeviceConvFwdDynamicTunable
&
dtune
)
:
gridwise_gemm
(
dtune
)
{
{
if
constexpr
(
BlockLoopOverSpecialization
==
DefaultBlockLoopOver
||
BlockLoopOverSpecialization
==
LoopOver_MNK
)
return
ck
::
Sequence
<
0
,
1
,
2
>
{};
else
if
constexpr
(
BlockLoopOverSpecialization
==
LoopOver_MKN
)
return
ck
::
Sequence
<
0
,
2
,
1
>
{};
}
}
using
BlockMNKAccessOrder
=
decltype
(
GetBlockMNKAccessOrder
());
static
constexpr
auto
GetThreadwiseGemm_Dispatch
()
static
constexpr
auto
GetThreadwiseGemm_Dispatch
()
{
{
if
constexpr
(
MPerThread
==
4
&&
NPerThread
==
24
)
if
constexpr
(
MPerThread
==
4
&&
NPerThread
==
24
)
...
@@ -106,45 +97,6 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -106,45 +97,6 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
using
ThreadwiseGemm_Dispatch
=
decltype
(
GetThreadwiseGemm_Dispatch
());
using
ThreadwiseGemm_Dispatch
=
decltype
(
GetThreadwiseGemm_Dispatch
());
static
constexpr
auto
GetInputBlockDescriptor
()
{
if
constexpr
(
UseALocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
KPerBlock
));
}
else
{
return
AGridDesc
{};
}
}
static
constexpr
auto
GetWeightBlockDescriptor
()
{
if
constexpr
(
UseBLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
math
::
integer_divide_ceil
(
NPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
KPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
));
}
else
{
return
BGridDesc
{};
}
}
static
constexpr
auto
GetOutputBlockDescriptor
()
{
if
constexpr
(
UseCLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
NPerBlock
));
}
else
{
return
CGridDesc
{};
}
}
static
auto
GetWeightTensorDescriptor
(
ck
::
index_t
gemm_k
,
ck
::
index_t
gemm_n
)
static
auto
GetWeightTensorDescriptor
(
ck
::
index_t
gemm_k
,
ck
::
index_t
gemm_n
)
{
{
ck
::
index_t
gemm_n_padded
=
ck
::
index_t
gemm_n_padded
=
...
@@ -576,6 +528,48 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -576,6 +528,48 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
using
BGridDesc
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I1
])
>
;
using
BGridDesc
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I1
])
>
;
using
CGridDesc
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I2
])
>
;
using
CGridDesc
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I2
])
>
;
static
constexpr
auto
GetInputBlockDescriptor
()
{
if
constexpr
(
UseALocalBuffer
)
{
// return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
0
,
0
));
}
else
{
return
AGridDesc
{};
}
}
static
constexpr
auto
GetWeightBlockDescriptor
()
{
if
constexpr
(
UseBLocalBuffer
)
{
// return make_naive_tensor_descriptor_packed(make_tuple(
// math::integer_divide_ceil(NPerBlock,
// ThreadwiseGemm_Dispatch::MatrixBMinVectorSize), KPerBlock,
// ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
0
,
0
,
0
));
}
else
{
return
BGridDesc
{};
}
}
static
constexpr
auto
GetOutputBlockDescriptor
()
{
if
constexpr
(
UseCLocalBuffer
)
{
// return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
0
,
0
));
}
else
{
return
CGridDesc
{};
}
}
// static constexpr bool UseCLocalBuffer = false;
// static constexpr bool UseCLocalBuffer = false;
using
AThreadwiseCopy
=
using
AThreadwiseCopy
=
...
@@ -620,20 +614,18 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -620,20 +614,18 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
AElementwiseOperation
,
// AElementwiseOperation,
AElementwiseOperation
,
// AElementwiseOperation,
BElementwiseOperation
,
// BElementwiseOperation,
BElementwiseOperation
,
// BElementwiseOperation,
CElementwiseOperation
,
// CElementwiseOperation,
CElementwiseOperation
,
// CElementwiseOperation,
MPerBlock
,
// MPerBlock,
NPerBlock
,
// NPerBlock,
KPerBlock
,
// KPerBlock,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
AThreadwiseCopy
,
// AThreadwiseCopy
AThreadwiseCopy
,
// AThreadwiseCopy
BThreadwiseCopy
,
// BThreadwiseCopy
BThreadwiseCopy
,
// BThreadwiseCopy
CThreadwiseCopy
,
// CThreadwiseCopy
CThreadwiseCopy
,
// CThreadwiseCopy
BlockMNKAccessOrder
,
// BlockMNKAccessOrder,
ck
::
Sequence
<
0
,
1
>
,
// ThreadMNAccessOrder
ck
::
Sequence
<
0
,
1
>
,
// ThreadMNAccessOrder
UseALocalBuffer
,
// UseALocalBuffer
UseALocalBuffer
,
// UseALocalBuffer
UseBLocalBuffer
,
// UseBLocalBuffer
UseBLocalBuffer
,
// UseBLocalBuffer
UseCLocalBuffer
// UseCLocalBuffer
UseCLocalBuffer
// UseCLocalBuffer
>
;
>
;
GridwiseGemm
gridwise_gemm
;
// Argument
// Argument
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
...
@@ -711,11 +703,15 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -711,11 +703,15 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
{
{
using
Argument
=
DeviceOp
::
Argument
;
using
Argument
=
DeviceOp
::
Argument
;
GridwiseGemm
gridwise_gemm
;
Invoker
(
const
GridwiseGemm
&
gridwise_gemm_
)
:
gridwise_gemm
(
gridwise_gemm_
)
{}
float
Run
(
const
Argument
&
arg
,
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{},
const
StreamConfig
&
stream_config
=
StreamConfig
{},
int
nrepeat
=
1
)
int
nrepeat
=
1
)
{
{
if
(
!
G
ridwise
G
emm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
))
if
(
!
g
ridwise
_g
emm
.
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
))
{
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemmAvx2_MxN has invalid setting"
);
throw
std
::
runtime_error
(
"wrong! GridwiseGemmAvx2_MxN has invalid setting"
);
}
}
...
@@ -738,6 +734,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -738,6 +734,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
if
(
nrepeat
!=
1
)
if
(
nrepeat
!=
1
)
ave_time
=
launch_and_time_cpu_kernel
(
kernel
,
ave_time
=
launch_and_time_cpu_kernel
(
kernel
,
nrepeat
,
nrepeat
,
gridwise_gemm
,
arg
.
p_a_grid_
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
...
@@ -753,6 +750,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -753,6 +750,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
memset
(
arg
.
p_c_grid_
,
0
,
arg
.
c_grid_desc_
.
GetElementSpaceSize
());
memset
(
arg
.
p_c_grid_
,
0
,
arg
.
c_grid_desc_
.
GetElementSpaceSize
());
launch_cpu_kernel
(
kernel
,
launch_cpu_kernel
(
kernel
,
gridwise_gemm
,
arg
.
p_a_grid_
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
...
@@ -780,7 +778,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -780,7 +778,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
return
true
;
return
true
;
}
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
constexpr
(
ConvForwardSpecialization
==
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
...
@@ -811,7 +809,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -811,7 +809,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
ConvForwardSpecialization
!=
ConvForwardSpecialization
!=
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
{
{
if
(
!
(
arg
.
Conv_C_
%
KPerB
lock
==
0
))
if
(
!
(
arg
.
Conv_C_
%
gridwise_gemm
.
dynamic_tunable
.
k_per_b
lock
==
0
))
return
false
;
return
false
;
}
}
...
@@ -825,7 +823,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -825,7 +823,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
}
}
// Gridwise GEMM size
// Gridwise GEMM size
return
G
ridwise
G
emm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
return
g
ridwise
_g
emm
.
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
}
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
...
@@ -868,7 +866,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -868,7 +866,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
out_element_op
};
out_element_op
};
}
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
auto
MakeInvoker
()
{
return
Invoker
{
gridwise_gemm
};
}
std
::
unique_ptr
<
BaseArgument
>
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in_grid
,
MakeArgumentPointer
(
const
void
*
p_in_grid
,
...
@@ -908,7 +906,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -908,7 +906,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
return
std
::
make_unique
<
Invoker
>
(
Invoker
{
gridwise_gemm
});
}
}
std
::
string
GetTypeString
()
const
override
std
::
string
GetTypeString
()
const
override
...
@@ -925,8 +923,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -925,8 +923,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<<
"DFwdAvx2_NHWC_KYXC"
<<
"DFwdAvx2_NHWC_KYXC"
<<
"_FS"
<<
static_cast
<
int
>
(
ConvForwardSpecialization
)
<<
"_FS"
<<
static_cast
<
int
>
(
ConvForwardSpecialization
)
<<
"_KS"
<<
static_cast
<
int
>
(
GemmKSpecialization
)
<<
"_KS"
<<
static_cast
<
int
>
(
GemmKSpecialization
)
<<
"_BS"
<<
static_cast
<
int
>
(
BlockLoopO
ver
S
pec
ialization
)
<<
"_BS"
<<
static_cast
<
int
>
(
gridwise_gemm
.
dynamic_tunable
.
loop_o
ver
_s
pec
)
<<
"_BT"
<<
MPerBlock
<<
"x"
<<
NPerBlock
<<
"x"
<<
KPerB
lock
<<
"_BT"
<<
gridwise_gemm
.
dynamic_tunable
.
m_per_block
<<
"x"
<<
gridwise_gemm
.
dynamic_tunable
.
n_per_block
<<
"x"
<<
gridwise_gemm
.
dynamic_tunable
.
k_per_b
lock
<<
"_TT"
<<
MPerThread
<<
"x"
<<
NPerThread
<<
"_TT"
<<
MPerThread
<<
"x"
<<
NPerThread
<<
"_A"
<<
string_local_buffer
(
UseALocalBuffer
)
<<
"_A"
<<
string_local_buffer
(
UseALocalBuffer
)
<<
"_B"
<<
string_local_buffer
(
UseBLocalBuffer
)
<<
"_B"
<<
string_local_buffer
(
UseBLocalBuffer
)
...
...
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxck8_nhwk.hpp
View file @
9cefc261
...
@@ -30,11 +30,8 @@ template <typename InDataType,
...
@@ -30,11 +30,8 @@ template <typename InDataType,
typename
OutElementwiseOperation
,
typename
OutElementwiseOperation
,
ConvolutionForwardSpecialization_t
ConvForwardSpecialization
,
ConvolutionForwardSpecialization_t
ConvForwardSpecialization
,
ConvolutionForwardGemmKSpecialization_t
GemmKSpecialization
,
ConvolutionForwardGemmKSpecialization_t
GemmKSpecialization
,
ConvolutionForwardBlockLoopOverSpecialization_t
BlockLoopOverSpecialization
,
//
ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
MPerBlock
,
// block means data are designed to fit in cache (L1/L2/L3)
ck
::
index_t
NPerBlock
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
MPerThread
,
ck
::
index_t
MPerThread
,
ck
::
index_t
NPerThread
,
ck
::
index_t
NPerThread
,
bool
UseALocalBuffer
,
bool
UseALocalBuffer
,
...
@@ -65,17 +62,12 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
...
@@ -65,17 +62,12 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
static
constexpr
bool
NonTemporalStore
=
false
;
static
constexpr
bool
NonTemporalStore
=
false
;
static
constexpr
auto
GetBlockMNKAccessOrder
()
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
(
const
DeviceConvFwdDynamicTunable
&
dtune
)
:
gridwise_gemm
(
dtune
)
{
{
if
constexpr
(
BlockLoopOverSpecialization
==
DefaultBlockLoopOver
||
BlockLoopOverSpecialization
==
LoopOver_MNK
)
return
ck
::
Sequence
<
0
,
1
,
2
>
{};
else
if
constexpr
(
BlockLoopOverSpecialization
==
LoopOver_MKN
)
return
ck
::
Sequence
<
0
,
2
,
1
>
{};
}
}
using
BlockMNKAccessOrder
=
decltype
(
GetBlockMNKAccessOrder
());
static
constexpr
auto
GetThreadwiseGemm_Dispatch
()
static
constexpr
auto
GetThreadwiseGemm_Dispatch
()
{
{
if
constexpr
(
MPerThread
==
4
&&
NPerThread
==
24
)
if
constexpr
(
MPerThread
==
4
&&
NPerThread
==
24
)
...
@@ -518,7 +510,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
...
@@ -518,7 +510,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
{
{
if
constexpr
(
UseALocalBuffer
)
if
constexpr
(
UseALocalBuffer
)
{
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
KPerBlock
));
// return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
0
,
0
));
}
}
else
else
{
{
...
@@ -530,10 +523,11 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
...
@@ -530,10 +523,11 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
{
{
if
constexpr
(
UseBLocalBuffer
)
if
constexpr
(
UseBLocalBuffer
)
{
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
// return make_naive_tensor_descriptor_packed(make_tuple(
math
::
integer_divide_ceil
(
NPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
// math::integer_divide_ceil(NPerBlock,
KPerBlock
,
// ThreadwiseGemm_Dispatch::MatrixBMinVectorSize), KPerBlock,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
));
// ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
0
,
0
,
0
));
}
}
else
else
{
{
...
@@ -545,7 +539,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
...
@@ -545,7 +539,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
{
{
if
constexpr
(
UseCLocalBuffer
)
if
constexpr
(
UseCLocalBuffer
)
{
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
NPerBlock
));
// return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
0
,
0
));
}
}
else
else
{
{
...
@@ -597,20 +592,18 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
...
@@ -597,20 +592,18 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
AElementwiseOperation
,
// AElementwiseOperation,
AElementwiseOperation
,
// AElementwiseOperation,
BElementwiseOperation
,
// BElementwiseOperation,
BElementwiseOperation
,
// BElementwiseOperation,
CElementwiseOperation
,
// CElementwiseOperation,
CElementwiseOperation
,
// CElementwiseOperation,
MPerBlock
,
// MPerBlock,
NPerBlock
,
// NPerBlock,
KPerBlock
,
// KPerBlock,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
AThreadwiseCopy
,
// AThreadwiseCopy
AThreadwiseCopy
,
// AThreadwiseCopy
BThreadwiseCopy
,
// BThreadwiseCopy
BThreadwiseCopy
,
// BThreadwiseCopy
CThreadwiseCopy
,
// CThreadwiseCopy
CThreadwiseCopy
,
// CThreadwiseCopy
BlockMNKAccessOrder
,
// BlockMNKAccessOrder,
ck
::
Sequence
<
0
,
1
>
,
// ThreadMNAccessOrder
ck
::
Sequence
<
0
,
1
>
,
// ThreadMNAccessOrder
UseALocalBuffer
,
// UseALocalBuffer
UseALocalBuffer
,
// UseALocalBuffer
UseBLocalBuffer
,
// UseBLocalBuffer
UseBLocalBuffer
,
// UseBLocalBuffer
UseCLocalBuffer
// UseCLocalBuffer
UseCLocalBuffer
// UseCLocalBuffer
>
;
>
;
GridwiseGemm
gridwise_gemm
;
// Argument
// Argument
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
...
@@ -687,12 +680,15 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
...
@@ -687,12 +680,15 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
struct
Invoker
:
public
BaseInvoker
struct
Invoker
:
public
BaseInvoker
{
{
using
Argument
=
DeviceOp
::
Argument
;
using
Argument
=
DeviceOp
::
Argument
;
GridwiseGemm
gridwise_gemm
;
Invoker
(
const
GridwiseGemm
&
gridwise_gemm_
)
:
gridwise_gemm
(
gridwise_gemm_
)
{}
float
Run
(
const
Argument
&
arg
,
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{},
const
StreamConfig
&
stream_config
=
StreamConfig
{},
int
nrepeat
=
1
)
int
nrepeat
=
1
)
{
{
if
(
!
G
ridwise
G
emm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
))
if
(
!
g
ridwise
_g
emm
.
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
))
{
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemmAvx2_MxN has invalid setting"
);
throw
std
::
runtime_error
(
"wrong! GridwiseGemmAvx2_MxN has invalid setting"
);
}
}
...
@@ -715,6 +711,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
...
@@ -715,6 +711,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
if
(
nrepeat
!=
1
)
if
(
nrepeat
!=
1
)
ave_time
=
launch_and_time_cpu_kernel
(
kernel
,
ave_time
=
launch_and_time_cpu_kernel
(
kernel
,
nrepeat
,
nrepeat
,
gridwise_gemm
,
arg
.
p_a_grid_
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
...
@@ -730,6 +727,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
...
@@ -730,6 +727,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
memset
(
arg
.
p_c_grid_
,
0
,
arg
.
c_grid_desc_
.
GetElementSpaceSize
());
memset
(
arg
.
p_c_grid_
,
0
,
arg
.
c_grid_desc_
.
GetElementSpaceSize
());
launch_cpu_kernel
(
kernel
,
launch_cpu_kernel
(
kernel
,
gridwise_gemm
,
arg
.
p_a_grid_
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
...
@@ -757,7 +755,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
...
@@ -757,7 +755,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
return
true
;
return
true
;
}
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
constexpr
(
ConvForwardSpecialization
==
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
...
@@ -788,7 +786,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
...
@@ -788,7 +786,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
ConvForwardSpecialization
!=
ConvForwardSpecialization
!=
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
{
{
if
(
!
(
arg
.
Conv_C_
%
KPerB
lock
==
0
))
if
(
!
(
arg
.
Conv_C_
%
gridwise_gemm
.
dynamic_tunable
.
k_per_b
lock
==
0
))
return
false
;
return
false
;
}
}
...
@@ -805,7 +803,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
...
@@ -805,7 +803,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
}
}
// Gridwise GEMM size
// Gridwise GEMM size
return
G
ridwise
G
emm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
return
g
ridwise
_g
emm
.
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
}
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
...
@@ -848,7 +846,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
...
@@ -848,7 +846,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
out_element_op
};
out_element_op
};
}
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
auto
MakeInvoker
()
{
return
Invoker
{
gridwise_gemm
};
}
std
::
unique_ptr
<
BaseArgument
>
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in_grid
,
MakeArgumentPointer
(
const
void
*
p_in_grid
,
...
@@ -888,7 +886,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
...
@@ -888,7 +886,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
return
std
::
make_unique
<
Invoker
>
(
Invoker
{
gridwise_gemm
});
}
}
std
::
string
GetTypeString
()
const
override
std
::
string
GetTypeString
()
const
override
...
@@ -905,8 +903,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
...
@@ -905,8 +903,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
<<
"DFwdAvx2_NHWC_KYXCK8"
<<
"DFwdAvx2_NHWC_KYXCK8"
<<
"_FS"
<<
static_cast
<
int
>
(
ConvForwardSpecialization
)
<<
"_FS"
<<
static_cast
<
int
>
(
ConvForwardSpecialization
)
<<
"_KS"
<<
static_cast
<
int
>
(
GemmKSpecialization
)
<<
"_KS"
<<
static_cast
<
int
>
(
GemmKSpecialization
)
<<
"_BS"
<<
static_cast
<
int
>
(
BlockLoopO
ver
S
pec
ialization
)
<<
"_BS"
<<
static_cast
<
int
>
(
gridwise_gemm
.
dynamic_tunable
.
loop_o
ver
_s
pec
)
<<
"_BT"
<<
MPerBlock
<<
"x"
<<
NPerBlock
<<
"x"
<<
KPerB
lock
<<
"_BT"
<<
gridwise_gemm
.
dynamic_tunable
.
m_per_block
<<
"x"
<<
gridwise_gemm
.
dynamic_tunable
.
n_per_block
<<
"x"
<<
gridwise_gemm
.
dynamic_tunable
.
k_per_b
lock
<<
"_TT"
<<
MPerThread
<<
"x"
<<
NPerThread
<<
"_TT"
<<
MPerThread
<<
"x"
<<
NPerThread
<<
"_A"
<<
string_local_buffer
(
UseALocalBuffer
)
<<
"_A"
<<
string_local_buffer
(
UseALocalBuffer
)
<<
"_B"
<<
string_local_buffer
(
UseBLocalBuffer
)
<<
"_B"
<<
string_local_buffer
(
UseBLocalBuffer
)
...
...
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_yxck_nhwk.hpp
View file @
9cefc261
...
@@ -29,11 +29,7 @@ template <typename InDataType,
...
@@ -29,11 +29,7 @@ template <typename InDataType,
typename
OutElementwiseOperation
,
typename
OutElementwiseOperation
,
ConvolutionForwardSpecialization_t
ConvForwardSpecialization
,
ConvolutionForwardSpecialization_t
ConvForwardSpecialization
,
ConvolutionForwardGemmKSpecialization_t
GemmKSpecialization
,
ConvolutionForwardGemmKSpecialization_t
GemmKSpecialization
,
ConvolutionForwardBlockLoopOverSpecialization_t
BlockLoopOverSpecialization
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
MPerBlock
,
// block means data are designed to fit in cache (L1/L2/L3)
ck
::
index_t
NPerBlock
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
MPerThread
,
ck
::
index_t
MPerThread
,
ck
::
index_t
NPerThread
,
ck
::
index_t
NPerThread
,
bool
UseALocalBuffer
,
bool
UseALocalBuffer
,
...
@@ -64,17 +60,12 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
...
@@ -64,17 +60,12 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
static
constexpr
bool
NonTemporalStore
=
false
;
static
constexpr
bool
NonTemporalStore
=
false
;
static
constexpr
auto
GetBlockMNKAccessOrder
()
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
(
const
DeviceConvFwdDynamicTunable
&
dtune
)
:
gridwise_gemm
(
dtune
)
{
{
if
constexpr
(
BlockLoopOverSpecialization
==
DefaultBlockLoopOver
||
BlockLoopOverSpecialization
==
LoopOver_MNK
)
return
ck
::
Sequence
<
0
,
1
,
2
>
{};
else
if
constexpr
(
BlockLoopOverSpecialization
==
LoopOver_MKN
)
return
ck
::
Sequence
<
0
,
2
,
1
>
{};
}
}
using
BlockMNKAccessOrder
=
decltype
(
GetBlockMNKAccessOrder
());
static
constexpr
auto
GetThreadwiseGemm_Dispatch
()
static
constexpr
auto
GetThreadwiseGemm_Dispatch
()
{
{
if
constexpr
(
MPerThread
==
4
&&
NPerThread
==
24
)
if
constexpr
(
MPerThread
==
4
&&
NPerThread
==
24
)
...
@@ -514,7 +505,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
...
@@ -514,7 +505,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
{
{
if
constexpr
(
UseALocalBuffer
)
if
constexpr
(
UseALocalBuffer
)
{
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
KPerBlock
));
// return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
0
,
0
));
}
}
else
else
{
{
...
@@ -526,7 +518,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
...
@@ -526,7 +518,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
{
{
if
constexpr
(
UseBLocalBuffer
)
if
constexpr
(
UseBLocalBuffer
)
{
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
KPerBlock
,
NPerBlock
));
// return make_naive_tensor_descriptor_packed(make_tuple(KPerBlock, NPerBlock));
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
0
,
0
));
}
}
else
else
{
{
...
@@ -538,7 +531,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
...
@@ -538,7 +531,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
{
{
if
constexpr
(
UseCLocalBuffer
)
if
constexpr
(
UseCLocalBuffer
)
{
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
NPerBlock
));
// return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
0
,
0
));
}
}
else
else
{
{
...
@@ -590,20 +584,18 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
...
@@ -590,20 +584,18 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
AElementwiseOperation
,
// AElementwiseOperation,
AElementwiseOperation
,
// AElementwiseOperation,
BElementwiseOperation
,
// BElementwiseOperation,
BElementwiseOperation
,
// BElementwiseOperation,
CElementwiseOperation
,
// CElementwiseOperation,
CElementwiseOperation
,
// CElementwiseOperation,
MPerBlock
,
// MPerBlock,
NPerBlock
,
// NPerBlock,
KPerBlock
,
// KPerBlock,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
AThreadwiseCopy
,
// AThreadwiseCopy
AThreadwiseCopy
,
// AThreadwiseCopy
BThreadwiseCopy
,
// BThreadwiseCopy
BThreadwiseCopy
,
// BThreadwiseCopy
CThreadwiseCopy
,
// CThreadwiseCopy
CThreadwiseCopy
,
// CThreadwiseCopy
BlockMNKAccessOrder
,
// BlockMNKAccessOrder,
ck
::
Sequence
<
0
,
1
>
,
// ThreadMNAccessOrder
ck
::
Sequence
<
0
,
1
>
,
// ThreadMNAccessOrder
UseALocalBuffer
,
// UseALocalBuffer
UseALocalBuffer
,
// UseALocalBuffer
UseBLocalBuffer
,
// UseBLocalBuffer
UseBLocalBuffer
,
// UseBLocalBuffer
UseCLocalBuffer
// UseCLocalBuffer
UseCLocalBuffer
// UseCLocalBuffer
>
;
>
;
GridwiseGemm
gridwise_gemm
;
// Argument
// Argument
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
...
@@ -680,12 +672,15 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
...
@@ -680,12 +672,15 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
struct
Invoker
:
public
BaseInvoker
struct
Invoker
:
public
BaseInvoker
{
{
using
Argument
=
DeviceOp
::
Argument
;
using
Argument
=
DeviceOp
::
Argument
;
GridwiseGemm
gridwise_gemm
;
Invoker
(
const
GridwiseGemm
&
gridwise_gemm_
)
:
gridwise_gemm
(
gridwise_gemm_
)
{}
float
Run
(
const
Argument
&
arg
,
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{},
const
StreamConfig
&
stream_config
=
StreamConfig
{},
int
nrepeat
=
1
)
int
nrepeat
=
1
)
{
{
if
(
!
G
ridwise
G
emm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
))
if
(
!
g
ridwise
_g
emm
.
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
))
{
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemmAvx2_MxN has invalid setting"
);
throw
std
::
runtime_error
(
"wrong! GridwiseGemmAvx2_MxN has invalid setting"
);
}
}
...
@@ -708,6 +703,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
...
@@ -708,6 +703,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
if
(
nrepeat
!=
1
)
if
(
nrepeat
!=
1
)
ave_time
=
launch_and_time_cpu_kernel
(
kernel
,
ave_time
=
launch_and_time_cpu_kernel
(
kernel
,
nrepeat
,
nrepeat
,
gridwise_gemm
,
arg
.
p_a_grid_
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
...
@@ -723,6 +719,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
...
@@ -723,6 +719,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
memset
(
arg
.
p_c_grid_
,
0
,
arg
.
c_grid_desc_
.
GetElementSpaceSize
());
memset
(
arg
.
p_c_grid_
,
0
,
arg
.
c_grid_desc_
.
GetElementSpaceSize
());
launch_cpu_kernel
(
kernel
,
launch_cpu_kernel
(
kernel
,
gridwise_gemm
,
arg
.
p_a_grid_
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
...
@@ -750,7 +747,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
...
@@ -750,7 +747,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
return
true
;
return
true
;
}
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
constexpr
(
ConvForwardSpecialization
==
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
...
@@ -781,7 +778,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
...
@@ -781,7 +778,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
ConvForwardSpecialization
!=
ConvForwardSpecialization
!=
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
{
{
if
(
!
(
arg
.
Conv_C_
%
KPerB
lock
==
0
))
if
(
!
(
arg
.
Conv_C_
%
gridwise_gemm
.
dynamic_tunable
.
k_per_b
lock
==
0
))
return
false
;
return
false
;
}
}
...
@@ -801,7 +798,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
...
@@ -801,7 +798,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
}
}
// Gridwise GEMM size
// Gridwise GEMM size
return
G
ridwise
G
emm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
return
g
ridwise
_g
emm
.
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
}
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
...
@@ -844,7 +841,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
...
@@ -844,7 +841,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
out_element_op
};
out_element_op
};
}
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
auto
MakeInvoker
()
{
return
Invoker
{
gridwise_gemm
};
}
std
::
unique_ptr
<
BaseArgument
>
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in_grid
,
MakeArgumentPointer
(
const
void
*
p_in_grid
,
...
@@ -884,7 +881,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
...
@@ -884,7 +881,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
return
std
::
make_unique
<
Invoker
>
(
Invoker
{
gridwise_gemm
});
}
}
std
::
string
GetTypeString
()
const
override
std
::
string
GetTypeString
()
const
override
...
@@ -901,9 +898,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
...
@@ -901,9 +898,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<<
"DFwdAvx2_NHWC_YXCK"
<<
"DFwdAvx2_NHWC_YXCK"
<<
"_FS"
<<
static_cast
<
int
>
(
ConvForwardSpecialization
)
<<
"_FS"
<<
static_cast
<
int
>
(
ConvForwardSpecialization
)
<<
"_KS"
<<
static_cast
<
int
>
(
GemmKSpecialization
)
<<
"_KS"
<<
static_cast
<
int
>
(
GemmKSpecialization
)
<<
"_BS"
<<
static_cast
<
int
>
(
BlockLoopOverSpecialization
)
<<
"_BS"
<<
static_cast
<
int
>
(
gridwise_gemm
.
dynamic_tunable
.
loop_over_spec
)
<<
"_BT"
<<
MPerBlock
<<
"x"
<<
NPerBlock
<<
"x"
<<
KPerBlock
<<
"_BT"
<<
gridwise_gemm
.
dynamic_tunable
.
m_per_block
<<
"x"
<<
gridwise_gemm
.
dynamic_tunable
.
n_per_block
<<
"x"
<<
gridwise_gemm
.
dynamic_tunable
.
k_per_block
<<
"_TT"
<<
MPerThread
<<
"x"
<<
NPerThread
<<
"_A"
<<
string_local_buffer
(
UseALocalBuffer
)
<<
"_A"
<<
string_local_buffer
(
UseALocalBuffer
)
<<
"_B"
<<
string_local_buffer
(
UseBLocalBuffer
)
<<
"_B"
<<
string_local_buffer
(
UseBLocalBuffer
)
<<
"_C"
<<
string_local_buffer
(
UseCLocalBuffer
)
<<
"_C"
<<
string_local_buffer
(
UseCLocalBuffer
)
...
...
include/ck/tensor_operation/cpu/device/device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk.hpp
View file @
9cefc261
...
@@ -32,11 +32,7 @@ template <typename InDataType,
...
@@ -32,11 +32,7 @@ template <typename InDataType,
typename
OutElementwiseOperation
,
typename
OutElementwiseOperation
,
ConvolutionForwardSpecialization_t
ConvForwardSpecialization
,
ConvolutionForwardSpecialization_t
ConvForwardSpecialization
,
ConvolutionForwardGemmKSpecialization_t
GemmKSpecialization
,
ConvolutionForwardGemmKSpecialization_t
GemmKSpecialization
,
ConvolutionForwardBlockLoopOverSpecialization_t
BlockLoopOverSpecialization
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
MPerBlock
,
// block means data are designed to fit in cache (L1/L2/L3)
ck
::
index_t
NPerBlock
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
MPerThread
,
ck
::
index_t
MPerThread
,
ck
::
index_t
NPerThread
,
ck
::
index_t
NPerThread
,
bool
UseALocalBuffer
,
bool
UseALocalBuffer
,
...
@@ -73,17 +69,12 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
...
@@ -73,17 +69,12 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
static
constexpr
bool
NonTemporalStore
=
false
;
static
constexpr
bool
NonTemporalStore
=
false
;
static
constexpr
auto
GetBlockMNKAccessOrder
()
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
(
const
DeviceConvFwdDynamicTunable
&
dtune
)
:
gridwise_gemm
(
dtune
)
{
{
if
constexpr
(
BlockLoopOverSpecialization
==
DefaultBlockLoopOver
||
BlockLoopOverSpecialization
==
LoopOver_MNK
)
return
ck
::
Sequence
<
0
,
1
,
2
>
{};
else
if
constexpr
(
BlockLoopOverSpecialization
==
LoopOver_MKN
)
return
ck
::
Sequence
<
0
,
2
,
1
>
{};
}
}
using
BlockMNKAccessOrder
=
decltype
(
GetBlockMNKAccessOrder
());
static
constexpr
auto
GetThreadwiseGemm_Dispatch
()
static
constexpr
auto
GetThreadwiseGemm_Dispatch
()
{
{
if
constexpr
(
MPerThread
==
4
&&
NPerThread
==
24
)
if
constexpr
(
MPerThread
==
4
&&
NPerThread
==
24
)
...
@@ -114,45 +105,6 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
...
@@ -114,45 +105,6 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
using
ThreadwiseGemm_Dispatch
=
decltype
(
GetThreadwiseGemm_Dispatch
());
using
ThreadwiseGemm_Dispatch
=
decltype
(
GetThreadwiseGemm_Dispatch
());
static
constexpr
auto
GetInputBlockDescriptor
()
{
if
constexpr
(
UseALocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
KPerBlock
));
}
else
{
return
AGridDesc
{};
}
}
static
constexpr
auto
GetWeightBlockDescriptor
()
{
if
constexpr
(
UseBLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
math
::
integer_divide_ceil
(
NPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
KPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
));
}
else
{
return
BGridDesc
{};
}
}
static
constexpr
auto
GetOutputBlockDescriptor
()
{
if
constexpr
(
UseCLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
NPerBlock
));
}
else
{
return
CGridDesc
{};
}
}
static
auto
GetWeightTensorDescriptor
(
ck
::
index_t
gemm_k
,
ck
::
index_t
gemm_n
)
static
auto
GetWeightTensorDescriptor
(
ck
::
index_t
gemm_k
,
ck
::
index_t
gemm_n
)
{
{
ck
::
index_t
gemm_n_padded
=
ck
::
index_t
gemm_n_padded
=
...
@@ -598,6 +550,42 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
...
@@ -598,6 +550,42 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
using
C0GridDesc
=
remove_cvref_t
<
decltype
(
MakeBiasTensorDescriptor
(
1
,
1
))
>
;
using
C0GridDesc
=
remove_cvref_t
<
decltype
(
MakeBiasTensorDescriptor
(
1
,
1
))
>
;
using
C1GridDesc
=
CGridDesc
;
using
C1GridDesc
=
CGridDesc
;
static
constexpr
auto
GetInputBlockDescriptor
()
{
if
constexpr
(
UseALocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
0
,
0
));
}
else
{
return
AGridDesc
{};
}
}
static
constexpr
auto
GetWeightBlockDescriptor
()
{
if
constexpr
(
UseBLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
0
,
0
,
0
));
}
else
{
return
BGridDesc
{};
}
}
static
constexpr
auto
GetOutputBlockDescriptor
()
{
if
constexpr
(
UseCLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
0
,
0
));
}
else
{
return
CGridDesc
{};
}
}
// static constexpr bool UseCLocalBuffer = false;
// static constexpr bool UseCLocalBuffer = false;
using
AThreadwiseCopy
=
using
AThreadwiseCopy
=
...
@@ -650,20 +638,18 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
...
@@ -650,20 +638,18 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
AElementwiseOperation
,
// AElementwiseOperation,
AElementwiseOperation
,
// AElementwiseOperation,
BElementwiseOperation
,
// BElementwiseOperation,
BElementwiseOperation
,
// BElementwiseOperation,
CElementwiseOperation
,
// CElementwiseOperation,
CElementwiseOperation
,
// CElementwiseOperation,
MPerBlock
,
// MPerBlock,
NPerBlock
,
// NPerBlock,
KPerBlock
,
// KPerBlock,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
AThreadwiseCopy
,
// AThreadwiseCopy
AThreadwiseCopy
,
// AThreadwiseCopy
BThreadwiseCopy
,
// BThreadwiseCopy
BThreadwiseCopy
,
// BThreadwiseCopy
CThreadwiseCopy
,
// CThreadwiseCopy
CThreadwiseCopy
,
// CThreadwiseCopy
BlockMNKAccessOrder
,
// BlockMNKAccessOrder,
ck
::
Sequence
<
0
,
1
>
,
// ThreadMNAccessOrder
ck
::
Sequence
<
0
,
1
>
,
// ThreadMNAccessOrder
UseALocalBuffer
,
// UseALocalBuffer
UseALocalBuffer
,
// UseALocalBuffer
UseBLocalBuffer
,
// UseBLocalBuffer
UseBLocalBuffer
,
// UseBLocalBuffer
UseCLocalBuffer
// UseCLocalBuffer
UseCLocalBuffer
// UseCLocalBuffer
>
;
>
;
GridwiseGemm
gridwise_gemm
;
// Argument
// Argument
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
...
@@ -755,11 +741,15 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
...
@@ -755,11 +741,15 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
{
{
using
Argument
=
DeviceOp
::
Argument
;
using
Argument
=
DeviceOp
::
Argument
;
GridwiseGemm
gridwise_gemm
;
Invoker
(
const
GridwiseGemm
&
gridwise_gemm_
)
:
gridwise_gemm
(
gridwise_gemm_
)
{}
float
Run
(
const
Argument
&
arg
,
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{},
const
StreamConfig
&
stream_config
=
StreamConfig
{},
int
nrepeat
=
1
)
int
nrepeat
=
1
)
{
{
if
(
!
G
ridwise
G
emm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
))
if
(
!
g
ridwise
_g
emm
.
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
))
{
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemmAvx2_MxN has invalid setting"
);
throw
std
::
runtime_error
(
"wrong! GridwiseGemmAvx2_MxN has invalid setting"
);
}
}
...
@@ -787,6 +777,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
...
@@ -787,6 +777,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
if
(
nrepeat
!=
1
)
if
(
nrepeat
!=
1
)
ave_time
=
launch_and_time_cpu_kernel
(
kernel
,
ave_time
=
launch_and_time_cpu_kernel
(
kernel
,
nrepeat
,
nrepeat
,
gridwise_gemm
,
arg
.
p_a_grid_
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
...
@@ -806,6 +797,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
...
@@ -806,6 +797,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
memset
(
arg
.
p_c_grid_
,
0
,
arg
.
c_grid_desc_
.
GetElementSpaceSize
());
memset
(
arg
.
p_c_grid_
,
0
,
arg
.
c_grid_desc_
.
GetElementSpaceSize
());
launch_cpu_kernel
(
kernel
,
launch_cpu_kernel
(
kernel
,
gridwise_gemm
,
arg
.
p_a_grid_
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
...
@@ -837,7 +829,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
...
@@ -837,7 +829,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
return
true
;
return
true
;
}
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
constexpr
(
ConvForwardSpecialization
==
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
...
@@ -868,7 +860,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
...
@@ -868,7 +860,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
ConvForwardSpecialization
!=
ConvForwardSpecialization
!=
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
{
{
if
(
!
(
arg
.
Conv_C_
%
KPerB
lock
==
0
))
if
(
!
(
arg
.
Conv_C_
%
gridwise_gemm
.
dynamic_tunable
.
k_per_b
lock
==
0
))
return
false
;
return
false
;
}
}
...
@@ -882,7 +874,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
...
@@ -882,7 +874,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
}
}
// Gridwise GEMM size
// Gridwise GEMM size
return
G
ridwise
G
emm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
return
g
ridwise
_g
emm
.
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
}
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
...
@@ -929,7 +921,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
...
@@ -929,7 +921,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
out_element_op
};
out_element_op
};
}
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
auto
MakeInvoker
()
{
return
Invoker
{
gridwise_gemm
};
}
std
::
unique_ptr
<
BaseArgument
>
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in_grid
,
MakeArgumentPointer
(
const
void
*
p_in_grid
,
...
@@ -973,7 +965,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
...
@@ -973,7 +965,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
return
std
::
make_unique
<
Invoker
>
(
Invoker
{
gridwise_gemm
});
}
}
std
::
string
GetTypeString
()
const
override
std
::
string
GetTypeString
()
const
override
...
@@ -990,8 +982,8 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
...
@@ -990,8 +982,8 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
<<
"DFwd_BAA_Avx2_NHWC_KYXC"
<<
"DFwd_BAA_Avx2_NHWC_KYXC"
<<
"_FS"
<<
static_cast
<
int
>
(
ConvForwardSpecialization
)
<<
"_FS"
<<
static_cast
<
int
>
(
ConvForwardSpecialization
)
<<
"_KS"
<<
static_cast
<
int
>
(
GemmKSpecialization
)
<<
"_KS"
<<
static_cast
<
int
>
(
GemmKSpecialization
)
<<
"_BS"
<<
static_cast
<
int
>
(
BlockLoopO
ver
S
pec
ialization
)
<<
"_BS"
<<
static_cast
<
int
>
(
gridwise_gemm
.
dynamic_tunable
.
loop_o
ver
_s
pec
)
<<
"_BT"
<<
MPerBlock
<<
"x"
<<
NPerBlock
<<
"x"
<<
KPerB
lock
<<
"_BT"
<<
gridwise_gemm
.
dynamic_tunable
.
m_per_block
<<
"x"
<<
gridwise_gemm
.
dynamic_tunable
.
n_per_block
<<
"x"
<<
gridwise_gemm
.
dynamic_tunable
.
k_per_b
lock
<<
"_TT"
<<
MPerThread
<<
"x"
<<
NPerThread
<<
"_TT"
<<
MPerThread
<<
"x"
<<
NPerThread
<<
"_A"
<<
string_local_buffer
(
UseALocalBuffer
)
<<
"_A"
<<
string_local_buffer
(
UseALocalBuffer
)
<<
"_B"
<<
string_local_buffer
(
UseBLocalBuffer
)
<<
"_B"
<<
string_local_buffer
(
UseBLocalBuffer
)
...
...
include/ck/tensor_operation/cpu/device/device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk.hpp
View file @
9cefc261
...
@@ -32,11 +32,7 @@ template <typename InDataType,
...
@@ -32,11 +32,7 @@ template <typename InDataType,
typename
OutElementwiseOperation
,
typename
OutElementwiseOperation
,
ConvolutionForwardSpecialization_t
ConvForwardSpecialization
,
ConvolutionForwardSpecialization_t
ConvForwardSpecialization
,
ConvolutionForwardGemmKSpecialization_t
GemmKSpecialization
,
ConvolutionForwardGemmKSpecialization_t
GemmKSpecialization
,
ConvolutionForwardBlockLoopOverSpecialization_t
BlockLoopOverSpecialization
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
MPerBlock
,
// block means data are designed to fit in cache (L1/L2/L3)
ck
::
index_t
NPerBlock
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
MPerThread
,
ck
::
index_t
MPerThread
,
ck
::
index_t
NPerThread
,
ck
::
index_t
NPerThread
,
bool
UseALocalBuffer
,
bool
UseALocalBuffer
,
...
@@ -73,17 +69,12 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
...
@@ -73,17 +69,12 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
static
constexpr
bool
NonTemporalStore
=
false
;
static
constexpr
bool
NonTemporalStore
=
false
;
static
constexpr
auto
GetBlockMNKAccessOrder
()
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
(
const
DeviceConvFwdDynamicTunable
&
dtune
)
:
gridwise_gemm
(
dtune
)
{
{
if
constexpr
(
BlockLoopOverSpecialization
==
DefaultBlockLoopOver
||
BlockLoopOverSpecialization
==
LoopOver_MNK
)
return
ck
::
Sequence
<
0
,
1
,
2
>
{};
else
if
constexpr
(
BlockLoopOverSpecialization
==
LoopOver_MKN
)
return
ck
::
Sequence
<
0
,
2
,
1
>
{};
}
}
using
BlockMNKAccessOrder
=
decltype
(
GetBlockMNKAccessOrder
());
static
constexpr
auto
GetThreadwiseGemm_Dispatch
()
static
constexpr
auto
GetThreadwiseGemm_Dispatch
()
{
{
if
constexpr
(
MPerThread
==
4
&&
NPerThread
==
24
)
if
constexpr
(
MPerThread
==
4
&&
NPerThread
==
24
)
...
@@ -114,45 +105,6 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
...
@@ -114,45 +105,6 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
using
ThreadwiseGemm_Dispatch
=
decltype
(
GetThreadwiseGemm_Dispatch
());
using
ThreadwiseGemm_Dispatch
=
decltype
(
GetThreadwiseGemm_Dispatch
());
static
constexpr
auto
GetInputBlockDescriptor
()
{
if
constexpr
(
UseALocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
KPerBlock
));
}
else
{
return
AGridDesc
{};
}
}
static
constexpr
auto
GetWeightBlockDescriptor
()
{
if
constexpr
(
UseBLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
math
::
integer_divide_ceil
(
NPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
KPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
));
}
else
{
return
BGridDesc
{};
}
}
static
constexpr
auto
GetOutputBlockDescriptor
()
{
if
constexpr
(
UseCLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
NPerBlock
));
}
else
{
return
CGridDesc
{};
}
}
static
auto
GetWeightTensorDescriptor
(
ck
::
index_t
gemm_k
,
ck
::
index_t
gemm_n
)
static
auto
GetWeightTensorDescriptor
(
ck
::
index_t
gemm_k
,
ck
::
index_t
gemm_n
)
{
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
gemm_n
/
8
,
gemm_k
,
8
));
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
gemm_n
/
8
,
gemm_k
,
8
));
...
@@ -575,6 +527,42 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
...
@@ -575,6 +527,42 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
using
C0GridDesc
=
remove_cvref_t
<
decltype
(
MakeBiasTensorDescriptor
(
1
,
1
))
>
;
using
C0GridDesc
=
remove_cvref_t
<
decltype
(
MakeBiasTensorDescriptor
(
1
,
1
))
>
;
using
C1GridDesc
=
CGridDesc
;
using
C1GridDesc
=
CGridDesc
;
static
constexpr
auto
GetInputBlockDescriptor
()
{
if
constexpr
(
UseALocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
0
,
0
));
}
else
{
return
AGridDesc
{};
}
}
static
constexpr
auto
GetWeightBlockDescriptor
()
{
if
constexpr
(
UseBLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
0
,
0
,
0
));
}
else
{
return
BGridDesc
{};
}
}
static
constexpr
auto
GetOutputBlockDescriptor
()
{
if
constexpr
(
UseCLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
0
,
0
));
}
else
{
return
CGridDesc
{};
}
}
// static constexpr bool UseCLocalBuffer = false;
// static constexpr bool UseCLocalBuffer = false;
using
AThreadwiseCopy
=
using
AThreadwiseCopy
=
...
@@ -627,20 +615,18 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
...
@@ -627,20 +615,18 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
AElementwiseOperation
,
// AElementwiseOperation,
AElementwiseOperation
,
// AElementwiseOperation,
BElementwiseOperation
,
// BElementwiseOperation,
BElementwiseOperation
,
// BElementwiseOperation,
CElementwiseOperation
,
// CElementwiseOperation,
CElementwiseOperation
,
// CElementwiseOperation,
MPerBlock
,
// MPerBlock,
NPerBlock
,
// NPerBlock,
KPerBlock
,
// KPerBlock,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
AThreadwiseCopy
,
// AThreadwiseCopy
AThreadwiseCopy
,
// AThreadwiseCopy
BThreadwiseCopy
,
// BThreadwiseCopy
BThreadwiseCopy
,
// BThreadwiseCopy
CThreadwiseCopy
,
// CThreadwiseCopy
CThreadwiseCopy
,
// CThreadwiseCopy
BlockMNKAccessOrder
,
// BlockMNKAccessOrder,
ck
::
Sequence
<
0
,
1
>
,
// ThreadMNAccessOrder
ck
::
Sequence
<
0
,
1
>
,
// ThreadMNAccessOrder
UseALocalBuffer
,
// UseALocalBuffer
UseALocalBuffer
,
// UseALocalBuffer
UseBLocalBuffer
,
// UseBLocalBuffer
UseBLocalBuffer
,
// UseBLocalBuffer
UseCLocalBuffer
// UseCLocalBuffer
UseCLocalBuffer
// UseCLocalBuffer
>
;
>
;
GridwiseGemm
gridwise_gemm
;
// Argument
// Argument
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
...
@@ -732,11 +718,15 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
...
@@ -732,11 +718,15 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
{
{
using
Argument
=
DeviceOp
::
Argument
;
using
Argument
=
DeviceOp
::
Argument
;
GridwiseGemm
gridwise_gemm
;
Invoker
(
const
GridwiseGemm
&
gridwise_gemm_
)
:
gridwise_gemm
(
gridwise_gemm_
)
{}
float
Run
(
const
Argument
&
arg
,
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{},
const
StreamConfig
&
stream_config
=
StreamConfig
{},
int
nrepeat
=
1
)
int
nrepeat
=
1
)
{
{
if
(
!
G
ridwise
G
emm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
))
if
(
!
g
ridwise
_g
emm
.
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
))
{
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemmAvx2_MxN has invalid setting"
);
throw
std
::
runtime_error
(
"wrong! GridwiseGemmAvx2_MxN has invalid setting"
);
}
}
...
@@ -764,6 +754,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
...
@@ -764,6 +754,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
if
(
nrepeat
!=
1
)
if
(
nrepeat
!=
1
)
ave_time
=
launch_and_time_cpu_kernel
(
kernel
,
ave_time
=
launch_and_time_cpu_kernel
(
kernel
,
nrepeat
,
nrepeat
,
gridwise_gemm
,
arg
.
p_a_grid_
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
...
@@ -783,6 +774,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
...
@@ -783,6 +774,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
memset
(
arg
.
p_c_grid_
,
0
,
arg
.
c_grid_desc_
.
GetElementSpaceSize
());
memset
(
arg
.
p_c_grid_
,
0
,
arg
.
c_grid_desc_
.
GetElementSpaceSize
());
launch_cpu_kernel
(
kernel
,
launch_cpu_kernel
(
kernel
,
gridwise_gemm
,
arg
.
p_a_grid_
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
...
@@ -814,7 +806,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
...
@@ -814,7 +806,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
return
true
;
return
true
;
}
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
constexpr
(
ConvForwardSpecialization
==
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
...
@@ -845,7 +837,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
...
@@ -845,7 +837,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
ConvForwardSpecialization
!=
ConvForwardSpecialization
!=
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
{
{
if
(
!
(
arg
.
Conv_C_
%
KPerB
lock
==
0
))
if
(
!
(
arg
.
Conv_C_
%
gridwise_gemm
.
dynamic_tunable
.
k_per_b
lock
==
0
))
return
false
;
return
false
;
}
}
...
@@ -862,7 +854,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
...
@@ -862,7 +854,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
}
}
// Gridwise GEMM size
// Gridwise GEMM size
return
G
ridwise
G
emm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
return
g
ridwise
_g
emm
.
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
}
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
...
@@ -909,7 +901,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
...
@@ -909,7 +901,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
out_element_op
};
out_element_op
};
}
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
auto
MakeInvoker
()
{
return
Invoker
{
gridwise_gemm
};
}
std
::
unique_ptr
<
BaseArgument
>
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in_grid
,
MakeArgumentPointer
(
const
void
*
p_in_grid
,
...
@@ -953,7 +945,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
...
@@ -953,7 +945,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
return
std
::
make_unique
<
Invoker
>
(
Invoker
{
gridwise_gemm
});
}
}
std
::
string
GetTypeString
()
const
override
std
::
string
GetTypeString
()
const
override
...
@@ -970,8 +962,8 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
...
@@ -970,8 +962,8 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
<<
"DFwd_BAA_Avx2_NHWC_KYXCK8"
<<
"DFwd_BAA_Avx2_NHWC_KYXCK8"
<<
"_FS"
<<
static_cast
<
int
>
(
ConvForwardSpecialization
)
<<
"_FS"
<<
static_cast
<
int
>
(
ConvForwardSpecialization
)
<<
"_KS"
<<
static_cast
<
int
>
(
GemmKSpecialization
)
<<
"_KS"
<<
static_cast
<
int
>
(
GemmKSpecialization
)
<<
"_BS"
<<
static_cast
<
int
>
(
BlockLoopO
ver
S
pec
ialization
)
<<
"_BS"
<<
static_cast
<
int
>
(
gridwise_gemm
.
dynamic_tunable
.
loop_o
ver
_s
pec
)
<<
"_BT"
<<
MPerBlock
<<
"x"
<<
NPerBlock
<<
"x"
<<
KPerB
lock
<<
"_BT"
<<
gridwise_gemm
.
dynamic_tunable
.
m_per_block
<<
"x"
<<
gridwise_gemm
.
dynamic_tunable
.
n_per_block
<<
"x"
<<
gridwise_gemm
.
dynamic_tunable
.
k_per_b
lock
<<
"_TT"
<<
MPerThread
<<
"x"
<<
NPerThread
<<
"_TT"
<<
MPerThread
<<
"x"
<<
NPerThread
<<
"_A"
<<
string_local_buffer
(
UseALocalBuffer
)
<<
"_A"
<<
string_local_buffer
(
UseALocalBuffer
)
<<
"_B"
<<
string_local_buffer
(
UseBLocalBuffer
)
<<
"_B"
<<
string_local_buffer
(
UseBLocalBuffer
)
...
...
include/ck/tensor_operation/cpu/device/device_convnd_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk.hpp
View file @
9cefc261
...
@@ -31,11 +31,7 @@ template <typename InDataType,
...
@@ -31,11 +31,7 @@ template <typename InDataType,
typename
OutElementwiseOperation
,
typename
OutElementwiseOperation
,
ConvolutionForwardSpecialization_t
ConvForwardSpecialization
,
ConvolutionForwardSpecialization_t
ConvForwardSpecialization
,
ConvolutionForwardGemmKSpecialization_t
GemmKSpecialization
,
ConvolutionForwardGemmKSpecialization_t
GemmKSpecialization
,
ConvolutionForwardBlockLoopOverSpecialization_t
BlockLoopOverSpecialization
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
MPerBlock
,
// block means data are designed to fit in cache (L1/L2/L3)
ck
::
index_t
NPerBlock
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
MPerThread
,
ck
::
index_t
MPerThread
,
ck
::
index_t
NPerThread
,
ck
::
index_t
NPerThread
,
bool
UseALocalBuffer
,
bool
UseALocalBuffer
,
...
@@ -72,17 +68,12 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
...
@@ -72,17 +68,12 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
static
constexpr
bool
NonTemporalStore
=
false
;
static
constexpr
bool
NonTemporalStore
=
false
;
static
constexpr
auto
GetBlockMNKAccessOrder
()
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
(
const
DeviceConvFwdDynamicTunable
&
dtune
)
:
gridwise_gemm
(
dtune
)
{
{
if
constexpr
(
BlockLoopOverSpecialization
==
DefaultBlockLoopOver
||
BlockLoopOverSpecialization
==
LoopOver_MNK
)
return
ck
::
Sequence
<
0
,
1
,
2
>
{};
else
if
constexpr
(
BlockLoopOverSpecialization
==
LoopOver_MKN
)
return
ck
::
Sequence
<
0
,
2
,
1
>
{};
}
}
using
BlockMNKAccessOrder
=
decltype
(
GetBlockMNKAccessOrder
());
static
constexpr
auto
GetThreadwiseGemm_Dispatch
()
static
constexpr
auto
GetThreadwiseGemm_Dispatch
()
{
{
if
constexpr
(
MPerThread
==
4
&&
NPerThread
==
24
)
if
constexpr
(
MPerThread
==
4
&&
NPerThread
==
24
)
...
@@ -111,42 +102,6 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
...
@@ -111,42 +102,6 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
using
ThreadwiseGemm_Dispatch
=
decltype
(
GetThreadwiseGemm_Dispatch
());
using
ThreadwiseGemm_Dispatch
=
decltype
(
GetThreadwiseGemm_Dispatch
());
static
constexpr
auto
GetInputBlockDescriptor
()
{
if
constexpr
(
UseALocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
KPerBlock
));
}
else
{
return
AGridDesc
{};
}
}
static
constexpr
auto
GetWeightBlockDescriptor
()
{
if
constexpr
(
UseBLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
KPerBlock
,
NPerBlock
));
}
else
{
return
BGridDesc
{};
}
}
static
constexpr
auto
GetOutputBlockDescriptor
()
{
if
constexpr
(
UseCLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
NPerBlock
));
}
else
{
return
CGridDesc
{};
}
}
static
auto
GetWeightTensorDescriptor
(
ck
::
index_t
gemm_k
,
ck
::
index_t
gemm_n
)
static
auto
GetWeightTensorDescriptor
(
ck
::
index_t
gemm_k
,
ck
::
index_t
gemm_n
)
{
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
gemm_k
,
gemm_n
));
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
gemm_k
,
gemm_n
));
...
@@ -568,6 +523,42 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
...
@@ -568,6 +523,42 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
using
C0GridDesc
=
remove_cvref_t
<
decltype
(
MakeBiasTensorDescriptor
(
1
,
1
))
>
;
using
C0GridDesc
=
remove_cvref_t
<
decltype
(
MakeBiasTensorDescriptor
(
1
,
1
))
>
;
using
C1GridDesc
=
CGridDesc
;
using
C1GridDesc
=
CGridDesc
;
static
constexpr
auto
GetInputBlockDescriptor
()
{
if
constexpr
(
UseALocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
0
,
0
));
}
else
{
return
AGridDesc
{};
}
}
static
constexpr
auto
GetWeightBlockDescriptor
()
{
if
constexpr
(
UseBLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
0
,
0
));
}
else
{
return
BGridDesc
{};
}
}
static
constexpr
auto
GetOutputBlockDescriptor
()
{
if
constexpr
(
UseCLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
0
,
0
));
}
else
{
return
CGridDesc
{};
}
}
// static constexpr bool UseCLocalBuffer = false;
// static constexpr bool UseCLocalBuffer = false;
using
AThreadwiseCopy
=
using
AThreadwiseCopy
=
...
@@ -620,20 +611,18 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
...
@@ -620,20 +611,18 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
AElementwiseOperation
,
// AElementwiseOperation,
AElementwiseOperation
,
// AElementwiseOperation,
BElementwiseOperation
,
// BElementwiseOperation,
BElementwiseOperation
,
// BElementwiseOperation,
CElementwiseOperation
,
// CElementwiseOperation,
CElementwiseOperation
,
// CElementwiseOperation,
MPerBlock
,
// MPerBlock,
NPerBlock
,
// NPerBlock,
KPerBlock
,
// KPerBlock,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
AThreadwiseCopy
,
// AThreadwiseCopy
AThreadwiseCopy
,
// AThreadwiseCopy
BThreadwiseCopy
,
// BThreadwiseCopy
BThreadwiseCopy
,
// BThreadwiseCopy
CThreadwiseCopy
,
// CThreadwiseCopy
CThreadwiseCopy
,
// CThreadwiseCopy
BlockMNKAccessOrder
,
// BlockMNKAccessOrder,
ck
::
Sequence
<
0
,
1
>
,
// ThreadMNAccessOrder
ck
::
Sequence
<
0
,
1
>
,
// ThreadMNAccessOrder
UseALocalBuffer
,
// UseALocalBuffer
UseALocalBuffer
,
// UseALocalBuffer
UseBLocalBuffer
,
// UseBLocalBuffer
UseBLocalBuffer
,
// UseBLocalBuffer
UseCLocalBuffer
// UseCLocalBuffer
UseCLocalBuffer
// UseCLocalBuffer
>
;
>
;
GridwiseGemm
gridwise_gemm
;
// Argument
// Argument
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
...
@@ -725,11 +714,15 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
...
@@ -725,11 +714,15 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
{
{
using
Argument
=
DeviceOp
::
Argument
;
using
Argument
=
DeviceOp
::
Argument
;
GridwiseGemm
gridwise_gemm
;
Invoker
(
const
GridwiseGemm
&
gridwise_gemm_
)
:
gridwise_gemm
(
gridwise_gemm_
)
{}
float
Run
(
const
Argument
&
arg
,
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{},
const
StreamConfig
&
stream_config
=
StreamConfig
{},
int
nrepeat
=
1
)
int
nrepeat
=
1
)
{
{
if
(
!
G
ridwise
G
emm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
))
if
(
!
g
ridwise
_g
emm
.
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
))
{
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemmAvx2_MxN has invalid setting"
);
throw
std
::
runtime_error
(
"wrong! GridwiseGemmAvx2_MxN has invalid setting"
);
}
}
...
@@ -757,6 +750,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
...
@@ -757,6 +750,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
if
(
nrepeat
!=
1
)
if
(
nrepeat
!=
1
)
ave_time
=
launch_and_time_cpu_kernel
(
kernel
,
ave_time
=
launch_and_time_cpu_kernel
(
kernel
,
nrepeat
,
nrepeat
,
gridwise_gemm
,
arg
.
p_a_grid_
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
...
@@ -776,6 +770,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
...
@@ -776,6 +770,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
memset
(
arg
.
p_c_grid_
,
0
,
arg
.
c_grid_desc_
.
GetElementSpaceSize
());
memset
(
arg
.
p_c_grid_
,
0
,
arg
.
c_grid_desc_
.
GetElementSpaceSize
());
launch_cpu_kernel
(
kernel
,
launch_cpu_kernel
(
kernel
,
gridwise_gemm
,
arg
.
p_a_grid_
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
...
@@ -807,7 +802,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
...
@@ -807,7 +802,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
return
true
;
return
true
;
}
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
constexpr
(
ConvForwardSpecialization
==
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
...
@@ -838,7 +833,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
...
@@ -838,7 +833,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
ConvForwardSpecialization
!=
ConvForwardSpecialization
!=
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
{
{
if
(
!
(
arg
.
Conv_C_
%
KPerB
lock
==
0
))
if
(
!
(
arg
.
Conv_C_
%
gridwise_gemm
.
dynamic_tunable
.
k_per_b
lock
==
0
))
return
false
;
return
false
;
}
}
...
@@ -858,7 +853,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
...
@@ -858,7 +853,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
}
}
// Gridwise GEMM size
// Gridwise GEMM size
return
G
ridwise
G
emm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
return
g
ridwise
_g
emm
.
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
}
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
...
@@ -905,7 +900,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
...
@@ -905,7 +900,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
out_element_op
};
out_element_op
};
}
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
auto
MakeInvoker
()
{
return
Invoker
{
gridwise_gemm
};
}
std
::
unique_ptr
<
BaseArgument
>
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in_grid
,
MakeArgumentPointer
(
const
void
*
p_in_grid
,
...
@@ -949,7 +944,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
...
@@ -949,7 +944,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
return
std
::
make_unique
<
Invoker
>
(
Invoker
{
gridwise_gemm
});
}
}
std
::
string
GetTypeString
()
const
override
std
::
string
GetTypeString
()
const
override
...
@@ -966,8 +961,8 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
...
@@ -966,8 +961,8 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
<<
"DFwd_BAA_Avx2_NHWC_YXCK"
<<
"DFwd_BAA_Avx2_NHWC_YXCK"
<<
"_FS"
<<
static_cast
<
int
>
(
ConvForwardSpecialization
)
<<
"_FS"
<<
static_cast
<
int
>
(
ConvForwardSpecialization
)
<<
"_KS"
<<
static_cast
<
int
>
(
GemmKSpecialization
)
<<
"_KS"
<<
static_cast
<
int
>
(
GemmKSpecialization
)
<<
"_BS"
<<
static_cast
<
int
>
(
BlockLoopO
ver
S
pec
ialization
)
<<
"_BS"
<<
static_cast
<
int
>
(
gridwise_gemm
.
dynamic_tunable
.
loop_o
ver
_s
pec
)
<<
"_BT"
<<
MPerBlock
<<
"x"
<<
NPerBlock
<<
"x"
<<
KPerB
lock
<<
"_BT"
<<
gridwise_gemm
.
dynamic_tunable
.
m_per_block
<<
"x"
<<
gridwise_gemm
.
dynamic_tunable
.
n_per_block
<<
"x"
<<
gridwise_gemm
.
dynamic_tunable
.
k_per_b
lock
<<
"_TT"
<<
MPerThread
<<
"x"
<<
NPerThread
<<
"_TT"
<<
MPerThread
<<
"x"
<<
NPerThread
<<
"_A"
<<
string_local_buffer
(
UseALocalBuffer
)
<<
"_A"
<<
string_local_buffer
(
UseALocalBuffer
)
<<
"_B"
<<
string_local_buffer
(
UseBLocalBuffer
)
<<
"_B"
<<
string_local_buffer
(
UseBLocalBuffer
)
...
...
include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp
View file @
9cefc261
...
@@ -28,7 +28,8 @@ template <typename GridwiseGemm,
...
@@ -28,7 +28,8 @@ template <typename GridwiseGemm,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
typename
CElementwiseOperation
>
void
kernel_gemm_avx_mxn
(
const
FloatA
*
__restrict__
p_a_grid
,
void
kernel_gemm_avx_mxn
(
const
GridwiseGemm
&
gridwise_gemm
,
const
FloatA
*
__restrict__
p_a_grid
,
const
FloatB
*
__restrict__
p_b_grid
,
const
FloatB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
AGridDesc
&
a_grid_desc
,
const
AGridDesc
&
a_grid_desc
,
...
@@ -38,7 +39,7 @@ void kernel_gemm_avx_mxn(const FloatA* __restrict__ p_a_grid,
...
@@ -38,7 +39,7 @@ void kernel_gemm_avx_mxn(const FloatA* __restrict__ p_a_grid,
const
BElementwiseOperation
&
b_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
)
const
CElementwiseOperation
&
c_element_op
)
{
{
G
ridwise
G
emm
::
Run
(
p_a_grid
,
g
ridwise
_g
emm
.
Run
(
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
a_grid_desc
,
a_grid_desc
,
...
@@ -58,14 +59,10 @@ template <typename FloatA,
...
@@ -58,14 +59,10 @@ template <typename FloatA,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
ck
::
index_t
MPerBlock
,
// block means data are designed to fit in cache (L1/L2/L3)
ck
::
index_t
NPerBlock
,
ck
::
index_t
KPerBlock
,
typename
ThreadwiseGemm_Dispatch
,
typename
ThreadwiseGemm_Dispatch
,
typename
AThreadwiseCopy
,
typename
AThreadwiseCopy
,
typename
BThreadwiseCopy
,
typename
BThreadwiseCopy
,
typename
CThreadwiseCopy
,
typename
CThreadwiseCopy
,
typename
BlockMNKAccessOrder
,
// how we accss gemm MNK to better fit in cache
typename
ThreadMNAccessOrder
,
// how we acces gemm MN to utilize micro kernel
typename
ThreadMNAccessOrder
,
// how we acces gemm MN to utilize micro kernel
bool
UseALocalBuffer
,
bool
UseALocalBuffer
,
bool
UseBLocalBuffer
,
bool
UseBLocalBuffer
,
...
@@ -75,12 +72,19 @@ template <typename FloatA,
...
@@ -75,12 +72,19 @@ template <typename FloatA,
>
>
struct
GridwiseGemmAvx2_MxN
struct
GridwiseGemmAvx2_MxN
{
{
ck
::
tensor_operation
::
cpu
::
device
::
DeviceConvFwdDynamicTunable
dynamic_tunable
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
// static constexpr auto Avx2RegisterVector = 8; // 8 floats
// static constexpr auto Avx2RegisterVector = 8; // 8 floats
static
constexpr
index_t
MemAlignmentByte
=
32
;
// 256bit
static
constexpr
index_t
MemAlignmentByte
=
32
;
// 256bit
GridwiseGemmAvx2_MxN
(
const
ck
::
tensor_operation
::
cpu
::
device
::
DeviceConvFwdDynamicTunable
dynamic_tunable_
)
:
dynamic_tunable
(
dynamic_tunable_
)
{
}
static
auto
GetABlockDescriptor
(
const
ck
::
index_t
m_per_blk
,
static
auto
GetABlockDescriptor
(
const
ck
::
index_t
m_per_blk
,
const
ck
::
index_t
k_per_blk
,
const
ck
::
index_t
k_per_blk
,
const
AGridDesc
&
a_grid_desc
)
const
AGridDesc
&
a_grid_desc
)
...
@@ -238,16 +242,21 @@ struct GridwiseGemmAvx2_MxN
...
@@ -238,16 +242,21 @@ struct GridwiseGemmAvx2_MxN
return
ck
::
make_multi_index
(
i_m
,
i_n
);
return
ck
::
make_multi_index
(
i_m
,
i_n
);
}
}
static
constexpr
bool
CheckValidity
(
const
AGridDesc
&
a_grid_desc
,
bool
CheckValidity
(
const
AGridDesc
&
a_grid_desc
,
const
BGridDesc
&
b_grid_desc
,
const
BGridDesc
&
b_grid_desc
,
const
CGridDesc
&
c_grid_desc
)
const
CGridDesc
&
c_grid_desc
)
{
{
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
bool
is_valid
=
true
;
bool
is_valid
=
true
;
const
auto
GemmN
=
c_grid_desc
.
GetLength
(
I1
);
const
auto
GemmN
=
c_grid_desc
.
GetLength
(
I1
);
if
constexpr
(
UseCLocalBuffer
)
if
constexpr
(
UseCLocalBuffer
)
{
{
if
(
std
::
is_same
<
BlockMNKAccessOrder
,
ck
::
Sequence
<
0
,
2
,
1
>>::
value
&&
NPerBlock
<
GemmN
)
// if(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value &&
// dynamic_tunable.gemm_n_per_block < GemmN)
if
(
dynamic_tunable
.
loop_over_spec
==
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardBlockLoopOverSpecialization_t
::
LoopOver_MKN
&&
dynamic_tunable
.
n_per_block
<
GemmN
)
is_valid
&=
false
;
is_valid
&=
false
;
}
}
else
else
...
@@ -259,19 +268,19 @@ struct GridwiseGemmAvx2_MxN
...
@@ -259,19 +268,19 @@ struct GridwiseGemmAvx2_MxN
return
is_valid
;
return
is_valid
;
}
}
static
void
Run
(
const
FloatA
*
__restrict__
p_a_grid
,
void
Run
(
const
FloatA
*
__restrict__
p_a_grid
,
const
FloatB
*
__restrict__
p_b_grid
,
const
FloatB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
AGridDesc
&
a_grid_desc
,
const
AGridDesc
&
a_grid_desc
,
const
BGridDesc
&
b_grid_desc
,
const
BGridDesc
&
b_grid_desc
,
const
CGridDesc
&
c_grid_desc
,
const
CGridDesc
&
c_grid_desc
,
const
AElementwiseOperation
&
a_element_op
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
)
const
CElementwiseOperation
&
c_element_op
)
const
{
{
ck
::
index_t
m_per_block
=
MPerB
lock
;
ck
::
index_t
m_per_block
=
dynamic_tunable
.
m_per_b
lock
;
ck
::
index_t
n_per_block
=
NPerB
lock
;
ck
::
index_t
n_per_block
=
dynamic_tunable
.
n_per_b
lock
;
ck
::
index_t
k_per_block
=
KPerB
lock
;
ck
::
index_t
k_per_block
=
dynamic_tunable
.
k_per_b
lock
;
const
auto
GemmM
=
c_grid_desc
.
GetLength
(
I0
);
const
auto
GemmM
=
c_grid_desc
.
GetLength
(
I0
);
const
auto
GemmN
=
c_grid_desc
.
GetLength
(
I1
);
const
auto
GemmN
=
c_grid_desc
.
GetLength
(
I1
);
...
@@ -297,7 +306,6 @@ struct GridwiseGemmAvx2_MxN
...
@@ -297,7 +306,6 @@ struct GridwiseGemmAvx2_MxN
decltype
(
GetABlockDescriptor
(
m_per_block
,
k_per_block
,
a_grid_desc
)),
// ABlockDesc,
decltype
(
GetABlockDescriptor
(
m_per_block
,
k_per_block
,
a_grid_desc
)),
// ABlockDesc,
decltype
(
GetBBlockDescriptor
(
k_per_block
,
n_per_block
,
b_grid_desc
)),
// BBlockDesc,
decltype
(
GetBBlockDescriptor
(
k_per_block
,
n_per_block
,
b_grid_desc
)),
// BBlockDesc,
decltype
(
GetCBlockDescriptor
(
m_per_block
,
n_per_block
,
c_grid_desc
)),
// CBlockDesc,
decltype
(
GetCBlockDescriptor
(
m_per_block
,
n_per_block
,
c_grid_desc
)),
// CBlockDesc,
KPerBlock
,
// KPerBlock,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
ThreadMNAccessOrder
>
{};
// ThreadMNAccessOrder // how we acces
ThreadMNAccessOrder
>
{};
// ThreadMNAccessOrder // how we acces
// gemm MN to utilize micro kernel>{};
// gemm MN to utilize micro kernel>{};
...
@@ -323,7 +331,9 @@ struct GridwiseGemmAvx2_MxN
...
@@ -323,7 +331,9 @@ struct GridwiseGemmAvx2_MxN
// TODO: openmp aware ordering
// TODO: openmp aware ordering
//
//
if
constexpr
(
std
::
is_same
<
BlockMNKAccessOrder
,
ck
::
Sequence
<
0
,
1
,
2
>>::
value
)
if
(
dynamic_tunable
.
loop_over_spec
==
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardBlockLoopOverSpecialization_t
::
LoopOver_MNK
)
{
{
auto
a_move_k_step
=
GetAIndex
(
0
,
k_per_block
);
auto
a_move_k_step
=
GetAIndex
(
0
,
k_per_block
);
auto
b_move_k_step
=
GetBIndex
(
k_per_block
,
0
);
auto
b_move_k_step
=
GetBIndex
(
k_per_block
,
0
);
...
@@ -467,7 +477,9 @@ struct GridwiseGemmAvx2_MxN
...
@@ -467,7 +477,9 @@ struct GridwiseGemmAvx2_MxN
}
}
}
}
}
}
else
if
constexpr
(
std
::
is_same
<
BlockMNKAccessOrder
,
ck
::
Sequence
<
0
,
2
,
1
>>::
value
)
else
if
(
dynamic_tunable
.
loop_over_spec
==
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardBlockLoopOverSpecialization_t
::
LoopOver_MKN
)
{
{
auto
a_move_k_step
=
GetAIndex
(
0
,
k_per_block
);
auto
a_move_k_step
=
GetAIndex
(
0
,
k_per_block
);
auto
b_move_k_step
=
GetBIndex
(
0
,
n_per_block
);
auto
b_move_k_step
=
GetBIndex
(
0
,
n_per_block
);
...
...
include/ck/tensor_operation/cpu/grid/gridwise_gemm_bias_activation_add_avx2.hpp
View file @
9cefc261
...
@@ -32,7 +32,8 @@ template <typename GridwiseGemm,
...
@@ -32,7 +32,8 @@ template <typename GridwiseGemm,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
typename
CElementwiseOperation
>
void
kernel_gemm_bias_activation_add_avx_mxn
(
const
FloatA
*
__restrict__
p_a_grid
,
void
kernel_gemm_bias_activation_add_avx_mxn
(
const
GridwiseGemm
&
gridwise_gemm
,
const
FloatA
*
__restrict__
p_a_grid
,
const
FloatB
*
__restrict__
p_b_grid
,
const
FloatB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
FloatC0
*
__restrict__
p_c0_grid
,
const
FloatC0
*
__restrict__
p_c0_grid
,
...
@@ -46,7 +47,7 @@ void kernel_gemm_bias_activation_add_avx_mxn(const FloatA* __restrict__ p_a_grid
...
@@ -46,7 +47,7 @@ void kernel_gemm_bias_activation_add_avx_mxn(const FloatA* __restrict__ p_a_grid
const
BElementwiseOperation
&
b_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
)
const
CElementwiseOperation
&
c_element_op
)
{
{
G
ridwise
G
emm
::
Run
(
p_a_grid
,
g
ridwise
_g
emm
.
Run
(
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
p_c0_grid
,
p_c0_grid
,
...
@@ -74,14 +75,10 @@ template <typename FloatA,
...
@@ -74,14 +75,10 @@ template <typename FloatA,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
ck
::
index_t
MPerBlock
,
// block means data are designed to fit in cache (L1/L2/L3)
ck
::
index_t
NPerBlock
,
ck
::
index_t
KPerBlock
,
typename
ThreadwiseGemm_Dispatch
,
typename
ThreadwiseGemm_Dispatch
,
typename
AThreadwiseCopy
,
typename
AThreadwiseCopy
,
typename
BThreadwiseCopy
,
typename
BThreadwiseCopy
,
typename
CThreadwiseCopy
,
typename
CThreadwiseCopy
,
typename
BlockMNKAccessOrder
,
// how we accss gemm MNK to better fit in cache
typename
ThreadMNAccessOrder
,
// how we acces gemm MN to utilize micro kernel
typename
ThreadMNAccessOrder
,
// how we acces gemm MN to utilize micro kernel
bool
UseALocalBuffer
,
bool
UseALocalBuffer
,
bool
UseBLocalBuffer
,
bool
UseBLocalBuffer
,
...
@@ -91,12 +88,19 @@ template <typename FloatA,
...
@@ -91,12 +88,19 @@ template <typename FloatA,
>
>
struct
GridwiseGemmBiasActivationAddAvx2_MxN
struct
GridwiseGemmBiasActivationAddAvx2_MxN
{
{
ck
::
tensor_operation
::
cpu
::
device
::
DeviceConvFwdDynamicTunable
dynamic_tunable
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
// static constexpr auto Avx2RegisterVector = 8; // 8 floats
// static constexpr auto Avx2RegisterVector = 8; // 8 floats
static
constexpr
index_t
MemAlignmentByte
=
32
;
// 256bit
static
constexpr
index_t
MemAlignmentByte
=
32
;
// 256bit
GridwiseGemmBiasActivationAddAvx2_MxN
(
const
ck
::
tensor_operation
::
cpu
::
device
::
DeviceConvFwdDynamicTunable
dynamic_tunable_
)
:
dynamic_tunable
(
dynamic_tunable_
)
{
}
static
auto
GetABlockDescriptor
(
const
ck
::
index_t
m_per_blk
,
static
auto
GetABlockDescriptor
(
const
ck
::
index_t
m_per_blk
,
const
ck
::
index_t
k_per_blk
,
const
ck
::
index_t
k_per_blk
,
const
AGridDesc
&
a_grid_desc
)
const
AGridDesc
&
a_grid_desc
)
...
@@ -254,16 +258,20 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
...
@@ -254,16 +258,20 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
return
ck
::
make_multi_index
(
i_m
,
i_n
);
return
ck
::
make_multi_index
(
i_m
,
i_n
);
}
}
static
constexpr
bool
CheckValidity
(
const
AGridDesc
&
a_grid_desc
,
bool
CheckValidity
(
const
AGridDesc
&
a_grid_desc
,
const
BGridDesc
&
b_grid_desc
,
const
BGridDesc
&
b_grid_desc
,
const
CGridDesc
&
c_grid_desc
)
const
CGridDesc
&
c_grid_desc
)
{
{
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
bool
is_valid
=
true
;
bool
is_valid
=
true
;
const
auto
GemmN
=
c_grid_desc
.
GetLength
(
I1
);
const
auto
GemmN
=
c_grid_desc
.
GetLength
(
I1
);
if
constexpr
(
UseCLocalBuffer
)
if
constexpr
(
UseCLocalBuffer
)
{
{
if
(
std
::
is_same
<
BlockMNKAccessOrder
,
ck
::
Sequence
<
0
,
2
,
1
>>::
value
&&
NPerBlock
<
GemmN
)
if
(
dynamic_tunable
.
loop_over_spec
==
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardBlockLoopOverSpecialization_t
::
LoopOver_MKN
&&
dynamic_tunable
.
n_per_block
<
GemmN
)
is_valid
&=
false
;
is_valid
&=
false
;
}
}
else
else
...
@@ -275,23 +283,23 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
...
@@ -275,23 +283,23 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
return
is_valid
;
return
is_valid
;
}
}
static
void
Run
(
const
FloatA
*
__restrict__
p_a_grid
,
void
Run
(
const
FloatA
*
__restrict__
p_a_grid
,
const
FloatB
*
__restrict__
p_b_grid
,
const
FloatB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
FloatC0
*
__restrict__
p_c0_grid
,
const
FloatC0
*
__restrict__
p_c0_grid
,
const
FloatC1
*
__restrict__
p_c1_grid
,
const
FloatC1
*
__restrict__
p_c1_grid
,
const
AGridDesc
&
a_grid_desc
,
const
AGridDesc
&
a_grid_desc
,
const
BGridDesc
&
b_grid_desc
,
const
BGridDesc
&
b_grid_desc
,
const
CGridDesc
&
c_grid_desc
,
const
CGridDesc
&
c_grid_desc
,
const
C0GridDesc
&
c0_grid_desc
,
const
C0GridDesc
&
c0_grid_desc
,
const
C1GridDesc
&
c1_grid_desc
,
const
C1GridDesc
&
c1_grid_desc
,
const
AElementwiseOperation
&
a_element_op
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
)
const
CElementwiseOperation
&
c_element_op
)
const
{
{
ck
::
index_t
m_per_block
=
MPerB
lock
;
ck
::
index_t
m_per_block
=
dynamic_tunable
.
m_per_b
lock
;
ck
::
index_t
n_per_block
=
NPerB
lock
;
ck
::
index_t
n_per_block
=
dynamic_tunable
.
n_per_b
lock
;
ck
::
index_t
k_per_block
=
KPerB
lock
;
ck
::
index_t
k_per_block
=
dynamic_tunable
.
k_per_b
lock
;
const
auto
GemmM
=
c_grid_desc
.
GetLength
(
I0
);
const
auto
GemmM
=
c_grid_desc
.
GetLength
(
I0
);
const
auto
GemmN
=
c_grid_desc
.
GetLength
(
I1
);
const
auto
GemmN
=
c_grid_desc
.
GetLength
(
I1
);
...
@@ -323,7 +331,6 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
...
@@ -323,7 +331,6 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
decltype
(
GetABlockDescriptor
(
m_per_block
,
k_per_block
,
a_grid_desc
)),
// ABlockDesc,
decltype
(
GetABlockDescriptor
(
m_per_block
,
k_per_block
,
a_grid_desc
)),
// ABlockDesc,
decltype
(
GetBBlockDescriptor
(
k_per_block
,
n_per_block
,
b_grid_desc
)),
// BBlockDesc,
decltype
(
GetBBlockDescriptor
(
k_per_block
,
n_per_block
,
b_grid_desc
)),
// BBlockDesc,
decltype
(
GetCBlockDescriptor
(
m_per_block
,
n_per_block
,
c_grid_desc
)),
// CBlockDesc,
decltype
(
GetCBlockDescriptor
(
m_per_block
,
n_per_block
,
c_grid_desc
)),
// CBlockDesc,
KPerBlock
,
// KPerBlock,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
ThreadMNAccessOrder
>
{};
// ThreadMNAccessOrder // how we acces
ThreadMNAccessOrder
>
{};
// ThreadMNAccessOrder // how we acces
// gemm MN to utilize micro kernel>{};
// gemm MN to utilize micro kernel>{};
...
@@ -349,7 +356,10 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
...
@@ -349,7 +356,10 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
// TODO: openmp aware ordering
// TODO: openmp aware ordering
//
//
if
constexpr
(
std
::
is_same
<
BlockMNKAccessOrder
,
ck
::
Sequence
<
0
,
1
,
2
>>::
value
)
if
(
dynamic_tunable
.
loop_over_spec
==
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardBlockLoopOverSpecialization_t
::
LoopOver_MNK
)
{
{
auto
a_move_k_step
=
GetAIndex
(
0
,
k_per_block
);
auto
a_move_k_step
=
GetAIndex
(
0
,
k_per_block
);
auto
b_move_k_step
=
GetBIndex
(
k_per_block
,
0
);
auto
b_move_k_step
=
GetBIndex
(
k_per_block
,
0
);
...
@@ -505,7 +515,9 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
...
@@ -505,7 +515,9 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
}
}
}
}
}
}
else
if
constexpr
(
std
::
is_same
<
BlockMNKAccessOrder
,
ck
::
Sequence
<
0
,
2
,
1
>>::
value
)
else
if
(
dynamic_tunable
.
loop_over_spec
==
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardBlockLoopOverSpecialization_t
::
LoopOver_MKN
)
{
{
auto
a_move_k_step
=
GetAIndex
(
0
,
k_per_block
);
auto
a_move_k_step
=
GetAIndex
(
0
,
k_per_block
);
auto
b_move_k_step
=
GetBIndex
(
0
,
n_per_block
);
auto
b_move_k_step
=
GetBIndex
(
0
,
n_per_block
);
...
...
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp
View file @
9cefc261
This diff is collapsed.
Click to expand it.
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_instance.cpp
View file @
9cefc261
This diff is collapsed.
Click to expand it.
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_yxck_nhwk_instance.cpp
View file @
9cefc261
This diff is collapsed.
Click to expand it.
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_bias_activation_add_avx2_nhwc_kyxc_nhwk_instance.cpp
View file @
9cefc261
#include <stdlib.h>
#include <stdlib.h>
#include <utility>
#include "convolution_forward_specialization_cpu.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "config.hpp"
#include "config.hpp"
#include "device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk.hpp"
#include "device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk.hpp"
...
@@ -41,83 +42,38 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
...
@@ -41,83 +42,38 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
// clang-format off
// clang-format off
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, true , c_local_buf, bias_along_m>
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
// clang-format on
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN})
using
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_instances
=
std
::
tuple
<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
false
,
false
)
>
;
// clang-format on
// use this in single thread, but gemm_n is not multiple of 8
using
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_local_c_instances
=
std
::
tuple
<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
>
;
// clang-format on
// use this in multi thread environment (need local C buffer to avoid cache coherence, although some
// time no local c is better...)
using
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_mt_instances
=
std
::
tuple
<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
24
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
32
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
40
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
48
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
48
,
48
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
56
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
16
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
16
,
256
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
32
,
256
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
64
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
64
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
>
;
// clang-format on
// clang-format on
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk
(
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
{
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_instances
{});
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
false
,
false
)
// clang-format on
));
}
}
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_local_c
(
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_local_c
(
...
@@ -125,14 +81,57 @@ void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_local_c(
...
@@ -125,14 +81,57 @@ void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_local_c(
{
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
instances
,
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_local_c_instances
{});
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
// clang-format on
));
}
}
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_mt
(
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_mt
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
{
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_mt_instances
{});
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
24
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
32
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
40
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
48
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
48
,
48
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
56
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
16
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
16
,
256
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
32
,
256
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
64
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
64
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
// clang-format on
));
}
}
}
// namespace device_conv2d_fwd_bias_activation_add_avx2_instance
}
// namespace device_conv2d_fwd_bias_activation_add_avx2_instance
...
...
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_bias_activation_add_avx2_nhwc_kyxck8_nhwk_instance.cpp
View file @
9cefc261
#include <stdlib.h>
#include <stdlib.h>
#include <utility>
#include "config.hpp"
#include "config.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk.hpp"
#include "device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk.hpp"
...
@@ -41,87 +42,38 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
...
@@ -41,87 +42,38 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
// clang-format off
// clang-format off
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float
, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC ,
LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block
, m_per_thread, n_per_thread, true
, true
, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC ,
2
, m_per_thread, n_per_thread, true,
true,
c_local_buf, bias_along_m>
({m_per_block, n_per_block, k_per_block, LoopOver_MNK})
, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float
, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC ,
LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block
, m_per_thread, n_per_thread, true
, true
, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC ,
2
, m_per_thread, n_per_thread, true,
true,
c_local_buf, bias_along_m>
({m_per_block, n_per_block, k_per_block, LoopOver_MNK})
, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float
, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop,
LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block
, m_per_thread, n_per_thread, true
, true
, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop,
2
, m_per_thread, n_per_thread, true,
true,
c_local_buf, bias_along_m>
({m_per_block, n_per_block, k_per_block, LoopOver_MNK})
, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float
, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC ,
LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block
, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC ,
2
, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>
({m_per_block, n_per_block, k_per_block, LoopOver_MNK})
, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float
, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop,
LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block
, m_per_thread, n_per_thread, true
, false, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop,
2
, m_per_thread, n_per_thread, true,
false, c_local_buf, bias_along_m>
({m_per_block, n_per_block, k_per_block, LoopOver_MNK})
, \
\
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , false, c_local_buf, bias_along_m>
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN})
// clang-format on
using
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_instances
=
std
::
tuple
<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
false
,
false
)
>
;
// clang-format on
// use this in single thread, but gemm_n is not multiple of 8
using
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_local_c_instances
=
std
::
tuple
<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
>
;
// clang-format on
// use this in multi thread environment (need local C buffer to avoid cache coherence, although some
// time no local c is better...)
using
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_mt_instances
=
std
::
tuple
<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
24
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
32
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
40
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
48
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
48
,
48
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
56
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
16
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
16
,
256
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
32
,
256
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
64
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
64
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
>
;
// clang-format on
// clang-format on
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk
(
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
{
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_instances
{});
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
false
,
false
)
// clang-format on
));
}
}
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_local_c
(
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_local_c
(
...
@@ -129,14 +81,57 @@ void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_local_c(
...
@@ -129,14 +81,57 @@ void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_local_c(
{
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
instances
,
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_local_c_instances
{});
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
// clang-format on
));
}
}
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_mt
(
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_mt
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
{
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_mt_instances
{});
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
24
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
32
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
40
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
48
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
48
,
48
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
56
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
16
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
16
,
256
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
32
,
256
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
64
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
64
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
// clang-format on
));
}
}
}
// namespace device_conv2d_fwd_bias_activation_add_avx2_instance
}
// namespace device_conv2d_fwd_bias_activation_add_avx2_instance
...
...
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_bias_activation_add_avx2_nhwc_yxck_nhwk_instance.cpp
View file @
9cefc261
#include <stdlib.h>
#include <stdlib.h>
#include <utility>
#include "config.hpp"
#include "config.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "device_convnd_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk.hpp"
#include "device_convnd_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk.hpp"
...
@@ -40,86 +41,38 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
...
@@ -40,86 +41,38 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
// clang-format off
// clang-format off
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
GemmKLoopOverC
,
LoopOver_MNK
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
,
bias_along_m
>
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
GemmKLoopOverC
,
2
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
,
bias_along_m
>
({
m_per_block
,
n_per_block
,
k_per_block
,
LoopOver_MNK
})
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
GemmKLoopOverC
,
LoopOver_MNK
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
,
bias_along_m
>
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
GemmKLoopOverC
,
2
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
,
bias_along_m
>
({
m_per_block
,
n_per_block
,
k_per_block
,
LoopOver_MNK
})
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
DefaultGemmKLoop
,
LoopOver_MNK
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
,
bias_along_m
>
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
DefaultGemmKLoop
,
2
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
,
bias_along_m
>
({
m_per_block
,
n_per_block
,
k_per_block
,
LoopOver_MNK
})
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
GemmKLoopOverC
,
LoopOver_MNK
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
false
,
false
,
c_local_buf
,
bias_along_m
>
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
GemmKLoopOverC
,
2
,
m_per_thread
,
n_per_thread
,
false
,
false
,
c_local_buf
,
bias_along_m
>
({
m_per_block
,
n_per_block
,
k_per_block
,
LoopOver_MNK
})
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
DefaultGemmKLoop
,
LoopOver_MNK
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
true
,
false
,
c_local_buf
,
bias_along_m
>
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
DefaultGemmKLoop
,
2
,
m_per_thread
,
n_per_thread
,
true
,
false
,
c_local_buf
,
bias_along_m
>
({
m_per_block
,
n_per_block
,
k_per_block
,
LoopOver_MNK
})
,
\
\
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
GemmKLoopOverC
,
LoopOver_MKN
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
,
bias_along_m
>
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
GemmKLoopOverC
,
2
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
,
bias_along_m
>
({
m_per_block
,
n_per_block
,
k_per_block
,
LoopOver_MKN
}),
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
GemmKLoopOverC
,
LoopOver_MKN
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
,
bias_along_m
>
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
GemmKLoopOverC
,
2
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
,
bias_along_m
>
({
m_per_block
,
n_per_block
,
k_per_block
,
LoopOver_MKN
}),
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
DefaultGemmKLoop
,
LoopOver_MKN
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
,
bias_along_m
>
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
DefaultGemmKLoop
,
2
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
,
bias_along_m
>
({
m_per_block
,
n_per_block
,
k_per_block
,
LoopOver_MKN
}),
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
GemmKLoopOverC
,
LoopOver_MKN
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
false
,
false
,
c_local_buf
,
bias_along_m
>
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
GemmKLoopOverC
,
2
,
m_per_thread
,
n_per_thread
,
false
,
false
,
c_local_buf
,
bias_along_m
>
({
m_per_block
,
n_per_block
,
k_per_block
,
LoopOver_MKN
}),
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
DefaultGemmKLoop
,
LoopOver_MKN
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
true
,
false
,
c_local_buf
,
bias_along_m
>
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
DefaultGemmKLoop
,
2
,
m_per_thread
,
n_per_thread
,
true
,
false
,
c_local_buf
,
bias_along_m
>
({
m_per_block
,
n_per_block
,
k_per_block
,
LoopOver_MKN
})
// clang-format on
using
device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_instances
=
std
::
tuple
<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
false
,
false
)
>
;
// clang-format on
// use this in single thread, but gemm_n is not multiple of 8
using
device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_local_c_instances
=
std
::
tuple
<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
>
;
// clang-format on
// use this in multi thread environment (need local C buffer to avoid cache coherence, although some
// time no local c is better...)
using
device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_mt_instances
=
std
::
tuple
<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
24
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
32
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
40
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
48
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
48
,
48
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
56
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
16
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
16
,
256
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
32
,
256
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
64
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
64
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
>
;
// clang-format on
// clang-format on
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk
(
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
{
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_instances
{});
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
false
,
false
)
// clang-format on
));
}
}
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c
(
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c
(
...
@@ -127,14 +80,57 @@ void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c(
...
@@ -127,14 +80,57 @@ void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c(
{
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
instances
,
device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_local_c_instances
{});
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
// clang-format on
));
}
}
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_mt
(
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_mt
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
{
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_mt_instances
{});
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
24
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
32
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
40
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
48
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
48
,
48
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
56
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
16
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
16
,
256
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
32
,
256
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
64
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
64
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
// clang-format on
));
}
}
}
// namespace device_conv2d_fwd_bias_activation_add_avx2_instance
}
// namespace device_conv2d_fwd_bias_activation_add_avx2_instance
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment