Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
9cefc261
"include/ck/utility/get_id.hpp" did not exist on "d6d9a8e4cee89feef6758f825cfea1588fec16da"
Commit
9cefc261
authored
Jun 13, 2022
by
carlushuang
Browse files
refactor device instance to use less template, more dynamic tunable
parent
6dfb4e78
Changes
18
Show whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
1006 additions
and
1016 deletions
+1006
-1016
example/cpu_01_conv2d_fwd/cpu_conv2d_fwd.cpp
example/cpu_01_conv2d_fwd/cpu_conv2d_fwd.cpp
+1
-1
example/cpu_02_conv2d_fwd_bias_relu_add/cpu_conv2d_fwd_bias_relu_add.cpp
...conv2d_fwd_bias_relu_add/cpu_conv2d_fwd_bias_relu_add.cpp
+1
-1
include/ck/tensor_operation/cpu/block/blockwise_gemm_avx2.hpp
...ude/ck/tensor_operation/cpu/block/blockwise_gemm_avx2.hpp
+5
-4
include/ck/tensor_operation/cpu/device/device_conv_fwd_cpu.hpp
...de/ck/tensor_operation/cpu/device/device_conv_fwd_cpu.hpp
+19
-0
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp
...tion/cpu/device/device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp
+61
-63
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxck8_nhwk.hpp
...on/cpu/device/device_convnd_fwd_avx2_nhwc_kyxck8_nhwk.hpp
+28
-30
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_yxck_nhwk.hpp
...tion/cpu/device/device_convnd_fwd_avx2_nhwc_yxck_nhwk.hpp
+24
-28
include/ck/tensor_operation/cpu/device/device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk.hpp
...ce_convnd_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk.hpp
+55
-63
include/ck/tensor_operation/cpu/device/device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk.hpp
..._convnd_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk.hpp
+55
-63
include/ck/tensor_operation/cpu/device/device_convnd_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk.hpp
...ce_convnd_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk.hpp
+55
-60
include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp
include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp
+37
-25
include/ck/tensor_operation/cpu/grid/gridwise_gemm_bias_activation_add_avx2.hpp
...ation/cpu/grid/gridwise_gemm_bias_activation_add_avx2.hpp
+41
-29
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp
...2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp
+134
-137
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_instance.cpp
..._fwd/device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_instance.cpp
+137
-144
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_yxck_nhwk_instance.cpp
...2d_fwd/device_conv2d_fwd_avx2_nhwc_yxck_nhwk_instance.cpp
+139
-144
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_bias_activation_add_avx2_nhwc_kyxc_nhwk_instance.cpp
...nv2d_bias_activation_add_avx2_nhwc_kyxc_nhwk_instance.cpp
+72
-73
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_bias_activation_add_avx2_nhwc_kyxck8_nhwk_instance.cpp
...2d_bias_activation_add_avx2_nhwc_kyxck8_nhwk_instance.cpp
+71
-76
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_bias_activation_add_avx2_nhwc_yxck_nhwk_instance.cpp
...nv2d_bias_activation_add_avx2_nhwc_yxck_nhwk_instance.cpp
+71
-75
No files found.
example/cpu_01_conv2d_fwd/cpu_conv2d_fwd.cpp
View file @
9cefc261
...
...
@@ -21,7 +21,7 @@
#define TEST_LAYOUT_NHWC_KYXC_NHWK 0
#define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1
#define TEST_LAYOUT_NHWC_YXCK_NHWK 2
#define TEST_LAYOUT TEST_LAYOUT_NHWC_YXCK_NHWK
#define TEST_LAYOUT TEST_LAYOUT_NHWC_
K
YXCK
8
_NHWK
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
...
...
example/cpu_02_conv2d_fwd_bias_relu_add/cpu_conv2d_fwd_bias_relu_add.cpp
View file @
9cefc261
...
...
@@ -17,7 +17,7 @@
#define TEST_LAYOUT_NHWC_KYXC_NHWK 0
#define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1
#define TEST_LAYOUT_NHWC_YXCK_NHWK 2
#define TEST_LAYOUT TEST_LAYOUT_NHWC_YXCK_NHWK
#define TEST_LAYOUT TEST_LAYOUT_NHWC_
K
YXCK
8
_NHWK
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
...
...
include/ck/tensor_operation/cpu/block/blockwise_gemm_avx2.hpp
View file @
9cefc261
...
...
@@ -18,8 +18,6 @@ template <typename FloatA,
typename
BBlockDesc
,
typename
CDesc
,
ck
::
index_t
KPerBlock
,
typename
ThreadwiseGemm_Dispatch
,
typename
ThreadMNAccessOrder
// how we access gemm MN to utilize micro kernel
>
...
...
@@ -83,8 +81,11 @@ struct BlockwiseGemmAvx2_MxN
else
{
// N/8 * K * 8
return
b_block_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}]
*
b_block_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
2
>
{}];
// return b_block_desc.GetTransforms()[Number<BBlockDesc::GetNumOfTransform() -
// 1>{}].GetUpperLengths()[Number<1>{}] *
// b_block_desc.GetTransforms()[Number<BBlockDesc::GetNumOfTransform() -
// 1>{}].GetUpperLengths()[Number<2>{}];
return
b_block_desc
.
GetLength
(
Number
<
1
>
{})
*
b_block_desc
.
GetLength
(
Number
<
2
>
{});
}
}
...
...
include/ck/tensor_operation/cpu/device/device_conv_fwd_cpu.hpp
View file @
9cefc261
...
...
@@ -3,6 +3,7 @@
#include <iostream>
#include "device_base_cpu.hpp"
#include "convolution_forward_specialization_cpu.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -77,6 +78,24 @@ using DeviceConvFwdBiasActivationAddPtr =
WeiElementwiseOperation
,
OutElementwiseOperation
>>
;
// Bundle of tunables that are passed to the CPU conv-forward device ops at
// run time (through their constructors) instead of being template arguments.
struct DeviceConvFwdDynamicTunable
{
    // Per-block tile sizes along M, N and K.
    ck::index_t m_per_block;
    ck::index_t n_per_block;
    ck::index_t k_per_block;

    // The following settings are not members of this struct; they are kept
    // here, commented out, exactly as listed in the original declaration.
    // ck::index_t m_per_thread;
    // ck::index_t n_per_thread;
    // bool use_a_local_buffer;
    // bool use_b_local_buffer;
    // bool use_c_local_buffer;
    // ConvolutionForwardSpecialization_t forward_spec;
    // ConvolutionForwardGemmKSpecialization_t gemm_k_spec;

    // Block loop-over specialization selected at run time.
    ConvolutionForwardBlockLoopOverSpecialization_t loop_over_spec;
};
}
// namespace device
}
// namespace cpu
}
// namespace tensor_operation
...
...
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp
View file @
9cefc261
...
...
@@ -30,11 +30,7 @@ template <typename InDataType,
typename
OutElementwiseOperation
,
ConvolutionForwardSpecialization_t
ConvForwardSpecialization
,
ConvolutionForwardGemmKSpecialization_t
GemmKSpecialization
,
ConvolutionForwardBlockLoopOverSpecialization_t
BlockLoopOverSpecialization
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
MPerBlock
,
// block means data are designed to fit in cache (L1/L2/L3)
ck
::
index_t
NPerBlock
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
MPerThread
,
ck
::
index_t
NPerThread
,
bool
UseALocalBuffer
,
...
...
@@ -65,17 +61,12 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
static
constexpr
bool
NonTemporalStore
=
false
;
static
constexpr
auto
GetBlockMNKAccessOrder
()
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
(
const
DeviceConvFwdDynamicTunable
&
dtune
)
:
gridwise_gemm
(
dtune
)
{
if
constexpr
(
BlockLoopOverSpecialization
==
DefaultBlockLoopOver
||
BlockLoopOverSpecialization
==
LoopOver_MNK
)
return
ck
::
Sequence
<
0
,
1
,
2
>
{};
else
if
constexpr
(
BlockLoopOverSpecialization
==
LoopOver_MKN
)
return
ck
::
Sequence
<
0
,
2
,
1
>
{};
}
using
BlockMNKAccessOrder
=
decltype
(
GetBlockMNKAccessOrder
());
static
constexpr
auto
GetThreadwiseGemm_Dispatch
()
{
if
constexpr
(
MPerThread
==
4
&&
NPerThread
==
24
)
...
...
@@ -106,45 +97,6 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
using
ThreadwiseGemm_Dispatch
=
decltype
(
GetThreadwiseGemm_Dispatch
());
// Descriptor used for the A (input) block.
// Without a local A buffer the grid descriptor type is used as-is; with one,
// a packed (MPerBlock x KPerBlock) descriptor is built.
static constexpr auto GetInputBlockDescriptor()
{
    if constexpr(!UseALocalBuffer)
    {
        return AGridDesc{};
    }
    else
    {
        return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
    }
}
static
constexpr
auto
GetWeightBlockDescriptor
()
{
if
constexpr
(
UseBLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
math
::
integer_divide_ceil
(
NPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
KPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
));
}
else
{
return
BGridDesc
{};
}
}
// Descriptor used for the C (output) block.
// Without a local C buffer the grid descriptor type is used as-is; with one,
// a packed (MPerBlock x NPerBlock) descriptor is built.
static constexpr auto GetOutputBlockDescriptor()
{
    if constexpr(!UseCLocalBuffer)
    {
        return CGridDesc{};
    }
    else
    {
        return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
    }
}
static
auto
GetWeightTensorDescriptor
(
ck
::
index_t
gemm_k
,
ck
::
index_t
gemm_n
)
{
ck
::
index_t
gemm_n_padded
=
...
...
@@ -576,6 +528,48 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
using
BGridDesc
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I1
])
>
;
using
CGridDesc
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I2
])
>
;
// Descriptor used for the A (input) block.
// Without a local A buffer the grid descriptor type is used as-is; with one,
// a packed 2-d descriptor is returned whose lengths are both 0.
// NOTE(review): the zero lengths look like placeholders, presumably because
// only the descriptor's type matters here - confirm at the use site.
static constexpr auto GetInputBlockDescriptor()
{
    if constexpr(!UseALocalBuffer)
    {
        return AGridDesc{};
    }
    else
    {
        // return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
        return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
    }
}
// Descriptor used for the B (weight) block.
// Without a local B buffer the grid descriptor type is used as-is; with one,
// a packed 3-d descriptor is returned whose lengths are all 0.
// NOTE(review): the zero lengths look like placeholders, presumably because
// only the descriptor's type matters here - confirm at the use site.
static constexpr auto GetWeightBlockDescriptor()
{
    if constexpr(!UseBLocalBuffer)
    {
        return BGridDesc{};
    }
    else
    {
        // return make_naive_tensor_descriptor_packed(make_tuple(
        //    math::integer_divide_ceil(NPerBlock,
        //    ThreadwiseGemm_Dispatch::MatrixBMinVectorSize), KPerBlock,
        //    ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
        return make_naive_tensor_descriptor_packed(make_tuple(0, 0, 0));
    }
}
// Descriptor used for the C (output) block.
// Without a local C buffer the grid descriptor type is used as-is; with one,
// a packed 2-d descriptor is returned whose lengths are both 0.
// NOTE(review): the zero lengths look like placeholders, presumably because
// only the descriptor's type matters here - confirm at the use site.
static constexpr auto GetOutputBlockDescriptor()
{
    if constexpr(!UseCLocalBuffer)
    {
        return CGridDesc{};
    }
    else
    {
        // return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
        return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
    }
}
// static constexpr bool UseCLocalBuffer = false;
using
AThreadwiseCopy
=
...
...
@@ -620,20 +614,18 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
AElementwiseOperation
,
// AElementwiseOperation,
BElementwiseOperation
,
// BElementwiseOperation,
CElementwiseOperation
,
// CElementwiseOperation,
MPerBlock
,
// MPerBlock,
NPerBlock
,
// NPerBlock,
KPerBlock
,
// KPerBlock,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
AThreadwiseCopy
,
// AThreadwiseCopy
BThreadwiseCopy
,
// BThreadwiseCopy
CThreadwiseCopy
,
// CThreadwiseCopy
BlockMNKAccessOrder
,
// BlockMNKAccessOrder,
ck
::
Sequence
<
0
,
1
>
,
// ThreadMNAccessOrder
UseALocalBuffer
,
// UseALocalBuffer
UseBLocalBuffer
,
// UseBLocalBuffer
UseCLocalBuffer
// UseCLocalBuffer
>
;
GridwiseGemm
gridwise_gemm
;
// Argument
struct
Argument
:
public
BaseArgument
{
...
...
@@ -711,11 +703,15 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
{
using
Argument
=
DeviceOp
::
Argument
;
GridwiseGemm
gridwise_gemm
;
Invoker
(
const
GridwiseGemm
&
gridwise_gemm_
)
:
gridwise_gemm
(
gridwise_gemm_
)
{}
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{},
int
nrepeat
=
1
)
{
if
(
!
G
ridwise
G
emm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
))
if
(
!
g
ridwise
_g
emm
.
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemmAvx2_MxN has invalid setting"
);
}
...
...
@@ -738,6 +734,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
if
(
nrepeat
!=
1
)
ave_time
=
launch_and_time_cpu_kernel
(
kernel
,
nrepeat
,
gridwise_gemm
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
...
...
@@ -753,6 +750,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
memset
(
arg
.
p_c_grid_
,
0
,
arg
.
c_grid_desc_
.
GetElementSpaceSize
());
launch_cpu_kernel
(
kernel
,
gridwise_gemm
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
...
...
@@ -780,7 +778,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
return
true
;
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
...
...
@@ -811,7 +809,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
ConvForwardSpecialization
!=
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
{
if
(
!
(
arg
.
Conv_C_
%
KPerB
lock
==
0
))
if
(
!
(
arg
.
Conv_C_
%
gridwise_gemm
.
dynamic_tunable
.
k_per_b
lock
==
0
))
return
false
;
}
...
...
@@ -825,7 +823,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
}
// Gridwise GEMM size
return
G
ridwise
G
emm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
return
g
ridwise
_g
emm
.
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
...
...
@@ -868,7 +866,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
out_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
auto
MakeInvoker
()
{
return
Invoker
{
gridwise_gemm
};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in_grid
,
...
...
@@ -908,7 +906,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
return
std
::
make_unique
<
Invoker
>
(
Invoker
{
gridwise_gemm
});
}
std
::
string
GetTypeString
()
const
override
...
...
@@ -925,8 +923,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<<
"DFwdAvx2_NHWC_KYXC"
<<
"_FS"
<<
static_cast
<
int
>
(
ConvForwardSpecialization
)
<<
"_KS"
<<
static_cast
<
int
>
(
GemmKSpecialization
)
<<
"_BS"
<<
static_cast
<
int
>
(
BlockLoopO
ver
S
pec
ialization
)
<<
"_BT"
<<
MPerBlock
<<
"x"
<<
NPerBlock
<<
"x"
<<
KPerB
lock
<<
"_BS"
<<
static_cast
<
int
>
(
gridwise_gemm
.
dynamic_tunable
.
loop_o
ver
_s
pec
)
<<
"_BT"
<<
gridwise_gemm
.
dynamic_tunable
.
m_per_block
<<
"x"
<<
gridwise_gemm
.
dynamic_tunable
.
n_per_block
<<
"x"
<<
gridwise_gemm
.
dynamic_tunable
.
k_per_b
lock
<<
"_TT"
<<
MPerThread
<<
"x"
<<
NPerThread
<<
"_A"
<<
string_local_buffer
(
UseALocalBuffer
)
<<
"_B"
<<
string_local_buffer
(
UseBLocalBuffer
)
...
...
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxck8_nhwk.hpp
View file @
9cefc261
...
...
@@ -30,11 +30,8 @@ template <typename InDataType,
typename
OutElementwiseOperation
,
ConvolutionForwardSpecialization_t
ConvForwardSpecialization
,
ConvolutionForwardGemmKSpecialization_t
GemmKSpecialization
,
ConvolutionForwardBlockLoopOverSpecialization_t
BlockLoopOverSpecialization
,
//
ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
MPerBlock
,
// block means data are designed to fit in cache (L1/L2/L3)
ck
::
index_t
NPerBlock
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
MPerThread
,
ck
::
index_t
NPerThread
,
bool
UseALocalBuffer
,
...
...
@@ -65,17 +62,12 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
static
constexpr
bool
NonTemporalStore
=
false
;
static
constexpr
auto
GetBlockMNKAccessOrder
()
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
(
const
DeviceConvFwdDynamicTunable
&
dtune
)
:
gridwise_gemm
(
dtune
)
{
if
constexpr
(
BlockLoopOverSpecialization
==
DefaultBlockLoopOver
||
BlockLoopOverSpecialization
==
LoopOver_MNK
)
return
ck
::
Sequence
<
0
,
1
,
2
>
{};
else
if
constexpr
(
BlockLoopOverSpecialization
==
LoopOver_MKN
)
return
ck
::
Sequence
<
0
,
2
,
1
>
{};
}
using
BlockMNKAccessOrder
=
decltype
(
GetBlockMNKAccessOrder
());
static
constexpr
auto
GetThreadwiseGemm_Dispatch
()
{
if
constexpr
(
MPerThread
==
4
&&
NPerThread
==
24
)
...
...
@@ -518,7 +510,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
{
if
constexpr
(
UseALocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
KPerBlock
));
// return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
0
,
0
));
}
else
{
...
...
@@ -530,10 +523,11 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
{
if
constexpr
(
UseBLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
math
::
integer_divide_ceil
(
NPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
KPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
));
// return make_naive_tensor_descriptor_packed(make_tuple(
// math::integer_divide_ceil(NPerBlock,
// ThreadwiseGemm_Dispatch::MatrixBMinVectorSize), KPerBlock,
// ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
0
,
0
,
0
));
}
else
{
...
...
@@ -545,7 +539,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
{
if
constexpr
(
UseCLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
NPerBlock
));
// return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
0
,
0
));
}
else
{
...
...
@@ -597,20 +592,18 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
AElementwiseOperation
,
// AElementwiseOperation,
BElementwiseOperation
,
// BElementwiseOperation,
CElementwiseOperation
,
// CElementwiseOperation,
MPerBlock
,
// MPerBlock,
NPerBlock
,
// NPerBlock,
KPerBlock
,
// KPerBlock,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
AThreadwiseCopy
,
// AThreadwiseCopy
BThreadwiseCopy
,
// BThreadwiseCopy
CThreadwiseCopy
,
// CThreadwiseCopy
BlockMNKAccessOrder
,
// BlockMNKAccessOrder,
ck
::
Sequence
<
0
,
1
>
,
// ThreadMNAccessOrder
UseALocalBuffer
,
// UseALocalBuffer
UseBLocalBuffer
,
// UseBLocalBuffer
UseCLocalBuffer
// UseCLocalBuffer
>
;
GridwiseGemm
gridwise_gemm
;
// Argument
struct
Argument
:
public
BaseArgument
{
...
...
@@ -687,12 +680,15 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceOp
::
Argument
;
GridwiseGemm
gridwise_gemm
;
Invoker
(
const
GridwiseGemm
&
gridwise_gemm_
)
:
gridwise_gemm
(
gridwise_gemm_
)
{}
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{},
int
nrepeat
=
1
)
{
if
(
!
G
ridwise
G
emm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
))
if
(
!
g
ridwise
_g
emm
.
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemmAvx2_MxN has invalid setting"
);
}
...
...
@@ -715,6 +711,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
if
(
nrepeat
!=
1
)
ave_time
=
launch_and_time_cpu_kernel
(
kernel
,
nrepeat
,
gridwise_gemm
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
...
...
@@ -730,6 +727,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
memset
(
arg
.
p_c_grid_
,
0
,
arg
.
c_grid_desc_
.
GetElementSpaceSize
());
launch_cpu_kernel
(
kernel
,
gridwise_gemm
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
...
...
@@ -757,7 +755,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
return
true
;
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
...
...
@@ -788,7 +786,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
ConvForwardSpecialization
!=
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
{
if
(
!
(
arg
.
Conv_C_
%
KPerB
lock
==
0
))
if
(
!
(
arg
.
Conv_C_
%
gridwise_gemm
.
dynamic_tunable
.
k_per_b
lock
==
0
))
return
false
;
}
...
...
@@ -805,7 +803,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
}
// Gridwise GEMM size
return
G
ridwise
G
emm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
return
g
ridwise
_g
emm
.
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
...
...
@@ -848,7 +846,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
out_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
auto
MakeInvoker
()
{
return
Invoker
{
gridwise_gemm
};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in_grid
,
...
...
@@ -888,7 +886,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
return
std
::
make_unique
<
Invoker
>
(
Invoker
{
gridwise_gemm
});
}
std
::
string
GetTypeString
()
const
override
...
...
@@ -905,8 +903,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
<<
"DFwdAvx2_NHWC_KYXCK8"
<<
"_FS"
<<
static_cast
<
int
>
(
ConvForwardSpecialization
)
<<
"_KS"
<<
static_cast
<
int
>
(
GemmKSpecialization
)
<<
"_BS"
<<
static_cast
<
int
>
(
BlockLoopO
ver
S
pec
ialization
)
<<
"_BT"
<<
MPerBlock
<<
"x"
<<
NPerBlock
<<
"x"
<<
KPerB
lock
<<
"_BS"
<<
static_cast
<
int
>
(
gridwise_gemm
.
dynamic_tunable
.
loop_o
ver
_s
pec
)
<<
"_BT"
<<
gridwise_gemm
.
dynamic_tunable
.
m_per_block
<<
"x"
<<
gridwise_gemm
.
dynamic_tunable
.
n_per_block
<<
"x"
<<
gridwise_gemm
.
dynamic_tunable
.
k_per_b
lock
<<
"_TT"
<<
MPerThread
<<
"x"
<<
NPerThread
<<
"_A"
<<
string_local_buffer
(
UseALocalBuffer
)
<<
"_B"
<<
string_local_buffer
(
UseBLocalBuffer
)
...
...
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_yxck_nhwk.hpp
View file @
9cefc261
...
...
@@ -29,11 +29,7 @@ template <typename InDataType,
typename
OutElementwiseOperation
,
ConvolutionForwardSpecialization_t
ConvForwardSpecialization
,
ConvolutionForwardGemmKSpecialization_t
GemmKSpecialization
,
ConvolutionForwardBlockLoopOverSpecialization_t
BlockLoopOverSpecialization
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
MPerBlock
,
// block means data are designed to fit in cache (L1/L2/L3)
ck
::
index_t
NPerBlock
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
MPerThread
,
ck
::
index_t
NPerThread
,
bool
UseALocalBuffer
,
...
...
@@ -64,17 +60,12 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
static
constexpr
bool
NonTemporalStore
=
false
;
static
constexpr
auto
GetBlockMNKAccessOrder
()
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
(
const
DeviceConvFwdDynamicTunable
&
dtune
)
:
gridwise_gemm
(
dtune
)
{
if
constexpr
(
BlockLoopOverSpecialization
==
DefaultBlockLoopOver
||
BlockLoopOverSpecialization
==
LoopOver_MNK
)
return
ck
::
Sequence
<
0
,
1
,
2
>
{};
else
if
constexpr
(
BlockLoopOverSpecialization
==
LoopOver_MKN
)
return
ck
::
Sequence
<
0
,
2
,
1
>
{};
}
using
BlockMNKAccessOrder
=
decltype
(
GetBlockMNKAccessOrder
());
static
constexpr
auto
GetThreadwiseGemm_Dispatch
()
{
if
constexpr
(
MPerThread
==
4
&&
NPerThread
==
24
)
...
...
@@ -514,7 +505,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
{
if
constexpr
(
UseALocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
KPerBlock
));
// return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
0
,
0
));
}
else
{
...
...
@@ -526,7 +518,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
{
if
constexpr
(
UseBLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
KPerBlock
,
NPerBlock
));
// return make_naive_tensor_descriptor_packed(make_tuple(KPerBlock, NPerBlock));
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
0
,
0
));
}
else
{
...
...
@@ -538,7 +531,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
{
if
constexpr
(
UseCLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
NPerBlock
));
// return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
0
,
0
));
}
else
{
...
...
@@ -590,20 +584,18 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
AElementwiseOperation
,
// AElementwiseOperation,
BElementwiseOperation
,
// BElementwiseOperation,
CElementwiseOperation
,
// CElementwiseOperation,
MPerBlock
,
// MPerBlock,
NPerBlock
,
// NPerBlock,
KPerBlock
,
// KPerBlock,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
AThreadwiseCopy
,
// AThreadwiseCopy
BThreadwiseCopy
,
// BThreadwiseCopy
CThreadwiseCopy
,
// CThreadwiseCopy
BlockMNKAccessOrder
,
// BlockMNKAccessOrder,
ck
::
Sequence
<
0
,
1
>
,
// ThreadMNAccessOrder
UseALocalBuffer
,
// UseALocalBuffer
UseBLocalBuffer
,
// UseBLocalBuffer
UseCLocalBuffer
// UseCLocalBuffer
>
;
GridwiseGemm
gridwise_gemm
;
// Argument
struct
Argument
:
public
BaseArgument
{
...
...
@@ -680,12 +672,15 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceOp
::
Argument
;
GridwiseGemm
gridwise_gemm
;
Invoker
(
const
GridwiseGemm
&
gridwise_gemm_
)
:
gridwise_gemm
(
gridwise_gemm_
)
{}
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{},
int
nrepeat
=
1
)
{
if
(
!
G
ridwise
G
emm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
))
if
(
!
g
ridwise
_g
emm
.
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemmAvx2_MxN has invalid setting"
);
}
...
...
@@ -708,6 +703,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
if
(
nrepeat
!=
1
)
ave_time
=
launch_and_time_cpu_kernel
(
kernel
,
nrepeat
,
gridwise_gemm
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
...
...
@@ -723,6 +719,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
memset
(
arg
.
p_c_grid_
,
0
,
arg
.
c_grid_desc_
.
GetElementSpaceSize
());
launch_cpu_kernel
(
kernel
,
gridwise_gemm
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
...
...
@@ -750,7 +747,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
return
true
;
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
...
...
@@ -781,7 +778,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
ConvForwardSpecialization
!=
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
{
if
(
!
(
arg
.
Conv_C_
%
KPerB
lock
==
0
))
if
(
!
(
arg
.
Conv_C_
%
gridwise_gemm
.
dynamic_tunable
.
k_per_b
lock
==
0
))
return
false
;
}
...
...
@@ -801,7 +798,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
}
// Gridwise GEMM size
return
G
ridwise
G
emm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
return
g
ridwise
_g
emm
.
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
...
...
@@ -844,7 +841,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
out_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
auto
MakeInvoker
()
{
return
Invoker
{
gridwise_gemm
};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in_grid
,
...
...
@@ -884,7 +881,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
return
std
::
make_unique
<
Invoker
>
(
Invoker
{
gridwise_gemm
});
}
std
::
string
GetTypeString
()
const
override
...
...
@@ -901,9 +898,8 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<<
"DFwdAvx2_NHWC_YXCK"
<<
"_FS"
<<
static_cast
<
int
>
(
ConvForwardSpecialization
)
<<
"_KS"
<<
static_cast
<
int
>
(
GemmKSpecialization
)
<<
"_BS"
<<
static_cast
<
int
>
(
BlockLoopOverSpecialization
)
<<
"_BT"
<<
MPerBlock
<<
"x"
<<
NPerBlock
<<
"x"
<<
KPerBlock
<<
"_TT"
<<
MPerThread
<<
"x"
<<
NPerThread
<<
"_BS"
<<
static_cast
<
int
>
(
gridwise_gemm
.
dynamic_tunable
.
loop_over_spec
)
<<
"_BT"
<<
gridwise_gemm
.
dynamic_tunable
.
m_per_block
<<
"x"
<<
gridwise_gemm
.
dynamic_tunable
.
n_per_block
<<
"x"
<<
gridwise_gemm
.
dynamic_tunable
.
k_per_block
<<
"_A"
<<
string_local_buffer
(
UseALocalBuffer
)
<<
"_B"
<<
string_local_buffer
(
UseBLocalBuffer
)
<<
"_C"
<<
string_local_buffer
(
UseCLocalBuffer
)
...
...
include/ck/tensor_operation/cpu/device/device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk.hpp
View file @
9cefc261
...
...
@@ -32,11 +32,7 @@ template <typename InDataType,
typename
OutElementwiseOperation
,
ConvolutionForwardSpecialization_t
ConvForwardSpecialization
,
ConvolutionForwardGemmKSpecialization_t
GemmKSpecialization
,
ConvolutionForwardBlockLoopOverSpecialization_t
BlockLoopOverSpecialization
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
MPerBlock
,
// block means data are designed to fit in cache (L1/L2/L3)
ck
::
index_t
NPerBlock
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
MPerThread
,
ck
::
index_t
NPerThread
,
bool
UseALocalBuffer
,
...
...
@@ -73,17 +69,12 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
static
constexpr
bool
NonTemporalStore
=
false
;
static
constexpr
auto
GetBlockMNKAccessOrder
()
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
(
const
DeviceConvFwdDynamicTunable
&
dtune
)
:
gridwise_gemm
(
dtune
)
{
if
constexpr
(
BlockLoopOverSpecialization
==
DefaultBlockLoopOver
||
BlockLoopOverSpecialization
==
LoopOver_MNK
)
return
ck
::
Sequence
<
0
,
1
,
2
>
{};
else
if
constexpr
(
BlockLoopOverSpecialization
==
LoopOver_MKN
)
return
ck
::
Sequence
<
0
,
2
,
1
>
{};
}
using
BlockMNKAccessOrder
=
decltype
(
GetBlockMNKAccessOrder
());
static
constexpr
auto
GetThreadwiseGemm_Dispatch
()
{
if
constexpr
(
MPerThread
==
4
&&
NPerThread
==
24
)
...
...
@@ -114,45 +105,6 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
using
ThreadwiseGemm_Dispatch
=
decltype
(
GetThreadwiseGemm_Dispatch
());
// Descriptor used for the A (input) block.
// Without a local A buffer the grid descriptor type is used as-is; with one,
// a packed (MPerBlock x KPerBlock) descriptor is built.
static constexpr auto GetInputBlockDescriptor()
{
    if constexpr(!UseALocalBuffer)
    {
        return AGridDesc{};
    }
    else
    {
        return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
    }
}
static
constexpr
auto
GetWeightBlockDescriptor
()
{
if
constexpr
(
UseBLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
math
::
integer_divide_ceil
(
NPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
KPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
));
}
else
{
return
BGridDesc
{};
}
}
// Descriptor used for the C (output) block.
// Without a local C buffer the grid descriptor type is used as-is; with one,
// a packed (MPerBlock x NPerBlock) descriptor is built.
static constexpr auto GetOutputBlockDescriptor()
{
    if constexpr(!UseCLocalBuffer)
    {
        return CGridDesc{};
    }
    else
    {
        return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
    }
}
static
auto
GetWeightTensorDescriptor
(
ck
::
index_t
gemm_k
,
ck
::
index_t
gemm_n
)
{
ck
::
index_t
gemm_n_padded
=
...
...
@@ -598,6 +550,42 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
using
C0GridDesc
=
remove_cvref_t
<
decltype
(
MakeBiasTensorDescriptor
(
1
,
1
))
>
;
using
C1GridDesc
=
CGridDesc
;
// Descriptor used for the A (input) block.
// Without a local A buffer the grid descriptor type is used as-is; with one,
// a packed 2-d descriptor is returned whose lengths are both 0.
// NOTE(review): the zero lengths look like placeholders, presumably because
// only the descriptor's type matters here - confirm at the use site.
static constexpr auto GetInputBlockDescriptor()
{
    if constexpr(!UseALocalBuffer)
    {
        return AGridDesc{};
    }
    else
    {
        return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
    }
}
// Descriptor used for the B (weight) block.
// Without a local B buffer the grid descriptor type is used as-is; with one,
// a packed 3-d descriptor is returned whose lengths are all 0.
// NOTE(review): the zero lengths look like placeholders, presumably because
// only the descriptor's type matters here - confirm at the use site.
static constexpr auto GetWeightBlockDescriptor()
{
    if constexpr(!UseBLocalBuffer)
    {
        return BGridDesc{};
    }
    else
    {
        return make_naive_tensor_descriptor_packed(make_tuple(0, 0, 0));
    }
}
// Descriptor used for the C (output) block.
// Without a local C buffer the grid descriptor type is used as-is; with one,
// a packed 2-d descriptor is returned whose lengths are both 0.
// NOTE(review): the zero lengths look like placeholders, presumably because
// only the descriptor's type matters here - confirm at the use site.
static constexpr auto GetOutputBlockDescriptor()
{
    if constexpr(!UseCLocalBuffer)
    {
        return CGridDesc{};
    }
    else
    {
        return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
    }
}
// static constexpr bool UseCLocalBuffer = false;
using
AThreadwiseCopy
=
...
...
@@ -650,20 +638,18 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
AElementwiseOperation
,
// AElementwiseOperation,
BElementwiseOperation
,
// BElementwiseOperation,
CElementwiseOperation
,
// CElementwiseOperation,
MPerBlock
,
// MPerBlock,
NPerBlock
,
// NPerBlock,
KPerBlock
,
// KPerBlock,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
AThreadwiseCopy
,
// AThreadwiseCopy
BThreadwiseCopy
,
// BThreadwiseCopy
CThreadwiseCopy
,
// CThreadwiseCopy
BlockMNKAccessOrder
,
// BlockMNKAccessOrder,
ck
::
Sequence
<
0
,
1
>
,
// ThreadMNAccessOrder
UseALocalBuffer
,
// UseALocalBuffer
UseBLocalBuffer
,
// UseBLocalBuffer
UseCLocalBuffer
// UseCLocalBuffer
>
;
GridwiseGemm
gridwise_gemm
;
// Argument
struct
Argument
:
public
BaseArgument
{
...
...
@@ -755,11 +741,15 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
{
using
Argument
=
DeviceOp
::
Argument
;
GridwiseGemm
gridwise_gemm
;
// Builds an invoker that keeps its own copy of the given gridwise GEMM
// object (the `gridwise_gemm` member is held by value).
Invoker(const GridwiseGemm& gridwise_gemm_)
    : gridwise_gemm(gridwise_gemm_)
{
}
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{},
int
nrepeat
=
1
)
{
if
(
!
G
ridwise
G
emm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
))
if
(
!
g
ridwise
_g
emm
.
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemmAvx2_MxN has invalid setting"
);
}
...
...
@@ -787,6 +777,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
if
(
nrepeat
!=
1
)
ave_time
=
launch_and_time_cpu_kernel
(
kernel
,
nrepeat
,
gridwise_gemm
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
...
...
@@ -806,6 +797,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
// Clear the output buffer before launching the kernel.
// memset() takes a size in BYTES; GetElementSpaceSize() presumably reports a
// number of ELEMENTS (TODO confirm), so scale by the element size -- otherwise
// only part of the buffer is zeroed for multi-byte element types.
memset(arg.p_c_grid_,
       0,
       arg.c_grid_desc_.GetElementSpaceSize() * sizeof(*arg.p_c_grid_));
launch_cpu_kernel
(
kernel
,
gridwise_gemm
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
...
...
@@ -837,7 +829,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
return
true
;
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
...
...
@@ -868,7 +860,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
ConvForwardSpecialization
!=
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
{
if
(
!
(
arg
.
Conv_C_
%
KPerB
lock
==
0
))
if
(
!
(
arg
.
Conv_C_
%
gridwise_gemm
.
dynamic_tunable
.
k_per_b
lock
==
0
))
return
false
;
}
...
...
@@ -882,7 +874,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
}
// Gridwise GEMM size
return
G
ridwise
G
emm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
return
g
ridwise
_g
emm
.
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
...
...
@@ -929,7 +921,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
out_element_op
};
}
// Creates a default-constructed Invoker and returns it by value.
static auto MakeInvoker()
{
    return Invoker();
}
// Creates an Invoker constructed from this device op's `gridwise_gemm`
// member and returns it by value.
auto MakeInvoker()
{
    return Invoker(gridwise_gemm);
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in_grid
,
...
...
@@ -973,7 +965,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
return
std
::
make_unique
<
Invoker
>
(
Invoker
{
gridwise_gemm
});
}
std
::
string
GetTypeString
()
const
override
...
...
@@ -990,8 +982,8 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
<<
"DFwd_BAA_Avx2_NHWC_KYXC"
<<
"_FS"
<<
static_cast
<
int
>
(
ConvForwardSpecialization
)
<<
"_KS"
<<
static_cast
<
int
>
(
GemmKSpecialization
)
<<
"_BS"
<<
static_cast
<
int
>
(
BlockLoopO
ver
S
pec
ialization
)
<<
"_BT"
<<
MPerBlock
<<
"x"
<<
NPerBlock
<<
"x"
<<
KPerB
lock
<<
"_BS"
<<
static_cast
<
int
>
(
gridwise_gemm
.
dynamic_tunable
.
loop_o
ver
_s
pec
)
<<
"_BT"
<<
gridwise_gemm
.
dynamic_tunable
.
m_per_block
<<
"x"
<<
gridwise_gemm
.
dynamic_tunable
.
n_per_block
<<
"x"
<<
gridwise_gemm
.
dynamic_tunable
.
k_per_b
lock
<<
"_TT"
<<
MPerThread
<<
"x"
<<
NPerThread
<<
"_A"
<<
string_local_buffer
(
UseALocalBuffer
)
<<
"_B"
<<
string_local_buffer
(
UseBLocalBuffer
)
...
...
include/ck/tensor_operation/cpu/device/device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk.hpp
View file @
9cefc261
...
...
@@ -32,11 +32,7 @@ template <typename InDataType,
typename
OutElementwiseOperation
,
ConvolutionForwardSpecialization_t
ConvForwardSpecialization
,
ConvolutionForwardGemmKSpecialization_t
GemmKSpecialization
,
ConvolutionForwardBlockLoopOverSpecialization_t
BlockLoopOverSpecialization
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
MPerBlock
,
// block means data are designed to fit in cache (L1/L2/L3)
ck
::
index_t
NPerBlock
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
MPerThread
,
ck
::
index_t
NPerThread
,
bool
UseALocalBuffer
,
...
...
@@ -73,17 +69,12 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
static
constexpr
bool
NonTemporalStore
=
false
;
static
constexpr
auto
GetBlockMNKAccessOrder
()
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
(
const
DeviceConvFwdDynamicTunable
&
dtune
)
:
gridwise_gemm
(
dtune
)
{
if
constexpr
(
BlockLoopOverSpecialization
==
DefaultBlockLoopOver
||
BlockLoopOverSpecialization
==
LoopOver_MNK
)
return
ck
::
Sequence
<
0
,
1
,
2
>
{};
else
if
constexpr
(
BlockLoopOverSpecialization
==
LoopOver_MKN
)
return
ck
::
Sequence
<
0
,
2
,
1
>
{};
}
using
BlockMNKAccessOrder
=
decltype
(
GetBlockMNKAccessOrder
());
static
constexpr
auto
GetThreadwiseGemm_Dispatch
()
{
if
constexpr
(
MPerThread
==
4
&&
NPerThread
==
24
)
...
...
@@ -114,45 +105,6 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
using
ThreadwiseGemm_Dispatch
=
decltype
(
GetThreadwiseGemm_Dispatch
());
// Returns the descriptor used for the A (input) block.
// With UseALocalBuffer set, this is the descriptor produced by
// make_naive_tensor_descriptor_packed over (MPerBlock, KPerBlock);
// otherwise a default-constructed AGridDesc is returned.
static constexpr auto GetInputBlockDescriptor()
{
    if constexpr(!UseALocalBuffer)
    {
        return AGridDesc{};
    }
    else
    {
        return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
    }
}
// Returns the descriptor used for the B (weight) block.
// With UseBLocalBuffer set, the descriptor is packed over three lengths:
// ceil(NPerBlock / MatrixBMinVectorSize), KPerBlock, MatrixBMinVectorSize.
// Otherwise a default-constructed BGridDesc is returned.
static constexpr auto GetWeightBlockDescriptor()
{
    if constexpr(!UseBLocalBuffer)
    {
        return BGridDesc{};
    }
    else
    {
        return make_naive_tensor_descriptor_packed(make_tuple(
            math::integer_divide_ceil(NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
            KPerBlock,
            ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
    }
}
// Returns the descriptor used for the C (output) block.
// With UseCLocalBuffer set, this is the descriptor produced by
// make_naive_tensor_descriptor_packed over (MPerBlock, NPerBlock);
// otherwise a default-constructed CGridDesc is returned.
static constexpr auto GetOutputBlockDescriptor()
{
    if constexpr(!UseCLocalBuffer)
    {
        return CGridDesc{};
    }
    else
    {
        return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
    }
}
static
auto
GetWeightTensorDescriptor
(
ck
::
index_t
gemm_k
,
ck
::
index_t
gemm_n
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
gemm_n
/
8
,
gemm_k
,
8
));
...
...
@@ -575,6 +527,42 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
using
C0GridDesc
=
remove_cvref_t
<
decltype
(
MakeBiasTensorDescriptor
(
1
,
1
))
>
;
using
C1GridDesc
=
CGridDesc
;
// Returns the descriptor used for the A (input) block.
// With UseALocalBuffer set, a packed two-length descriptor is built from
// placeholder lengths (0, 0); otherwise a default-constructed AGridDesc
// is returned.
// NOTE(review): the zero lengths suggest only the resulting type is consumed
// (e.g. via decltype) -- confirm against the callers.
static constexpr auto GetInputBlockDescriptor()
{
    if constexpr(!UseALocalBuffer)
    {
        return AGridDesc{};
    }
    else
    {
        return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
    }
}
// Returns the descriptor used for the B (weight) block.
// With UseBLocalBuffer set, a packed three-length descriptor is built from
// placeholder lengths (0, 0, 0); otherwise a default-constructed BGridDesc
// is returned.
// NOTE(review): the zero lengths suggest only the resulting type is consumed
// (e.g. via decltype) -- confirm against the callers.
static constexpr auto GetWeightBlockDescriptor()
{
    if constexpr(!UseBLocalBuffer)
    {
        return BGridDesc{};
    }
    else
    {
        return make_naive_tensor_descriptor_packed(make_tuple(0, 0, 0));
    }
}
// Returns the descriptor used for the C (output) block.
// With UseCLocalBuffer set, a packed two-length descriptor is built from
// placeholder lengths (0, 0); otherwise a default-constructed CGridDesc
// is returned.
// NOTE(review): the zero lengths suggest only the resulting type is consumed
// (e.g. via decltype) -- confirm against the callers.
static constexpr auto GetOutputBlockDescriptor()
{
    if constexpr(!UseCLocalBuffer)
    {
        return CGridDesc{};
    }
    else
    {
        return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
    }
}
// static constexpr bool UseCLocalBuffer = false;
using
AThreadwiseCopy
=
...
...
@@ -627,20 +615,18 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
AElementwiseOperation
,
// AElementwiseOperation,
BElementwiseOperation
,
// BElementwiseOperation,
CElementwiseOperation
,
// CElementwiseOperation,
MPerBlock
,
// MPerBlock,
NPerBlock
,
// NPerBlock,
KPerBlock
,
// KPerBlock,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
AThreadwiseCopy
,
// AThreadwiseCopy
BThreadwiseCopy
,
// BThreadwiseCopy
CThreadwiseCopy
,
// CThreadwiseCopy
BlockMNKAccessOrder
,
// BlockMNKAccessOrder,
ck
::
Sequence
<
0
,
1
>
,
// ThreadMNAccessOrder
UseALocalBuffer
,
// UseALocalBuffer
UseBLocalBuffer
,
// UseBLocalBuffer
UseCLocalBuffer
// UseCLocalBuffer
>
;
GridwiseGemm
gridwise_gemm
;
// Argument
struct
Argument
:
public
BaseArgument
{
...
...
@@ -732,11 +718,15 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
{
using
Argument
=
DeviceOp
::
Argument
;
GridwiseGemm
gridwise_gemm
;
// Builds an invoker that keeps its own copy of the given gridwise GEMM
// object (the `gridwise_gemm` member is held by value).
Invoker(const GridwiseGemm& gridwise_gemm_)
    : gridwise_gemm(gridwise_gemm_)
{
}
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{},
int
nrepeat
=
1
)
{
if
(
!
G
ridwise
G
emm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
))
if
(
!
g
ridwise
_g
emm
.
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemmAvx2_MxN has invalid setting"
);
}
...
...
@@ -764,6 +754,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
if
(
nrepeat
!=
1
)
ave_time
=
launch_and_time_cpu_kernel
(
kernel
,
nrepeat
,
gridwise_gemm
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
...
...
@@ -783,6 +774,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
// Clear the output buffer before launching the kernel.
// memset() takes a size in BYTES; GetElementSpaceSize() presumably reports a
// number of ELEMENTS (TODO confirm), so scale by the element size -- otherwise
// only part of the buffer is zeroed for multi-byte element types.
memset(arg.p_c_grid_,
       0,
       arg.c_grid_desc_.GetElementSpaceSize() * sizeof(*arg.p_c_grid_));
launch_cpu_kernel
(
kernel
,
gridwise_gemm
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
...
...
@@ -814,7 +806,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
return
true
;
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
...
...
@@ -845,7 +837,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
ConvForwardSpecialization
!=
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
{
if
(
!
(
arg
.
Conv_C_
%
KPerB
lock
==
0
))
if
(
!
(
arg
.
Conv_C_
%
gridwise_gemm
.
dynamic_tunable
.
k_per_b
lock
==
0
))
return
false
;
}
...
...
@@ -862,7 +854,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
}
// Gridwise GEMM size
return
G
ridwise
G
emm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
return
g
ridwise
_g
emm
.
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
...
...
@@ -909,7 +901,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
out_element_op
};
}
// Creates a default-constructed Invoker and returns it by value.
static auto MakeInvoker()
{
    return Invoker();
}
// Creates an Invoker constructed from this device op's `gridwise_gemm`
// member and returns it by value.
auto MakeInvoker()
{
    return Invoker(gridwise_gemm);
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in_grid
,
...
...
@@ -953,7 +945,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
return
std
::
make_unique
<
Invoker
>
(
Invoker
{
gridwise_gemm
});
}
std
::
string
GetTypeString
()
const
override
...
...
@@ -970,8 +962,8 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
<<
"DFwd_BAA_Avx2_NHWC_KYXCK8"
<<
"_FS"
<<
static_cast
<
int
>
(
ConvForwardSpecialization
)
<<
"_KS"
<<
static_cast
<
int
>
(
GemmKSpecialization
)
<<
"_BS"
<<
static_cast
<
int
>
(
BlockLoopO
ver
S
pec
ialization
)
<<
"_BT"
<<
MPerBlock
<<
"x"
<<
NPerBlock
<<
"x"
<<
KPerB
lock
<<
"_BS"
<<
static_cast
<
int
>
(
gridwise_gemm
.
dynamic_tunable
.
loop_o
ver
_s
pec
)
<<
"_BT"
<<
gridwise_gemm
.
dynamic_tunable
.
m_per_block
<<
"x"
<<
gridwise_gemm
.
dynamic_tunable
.
n_per_block
<<
"x"
<<
gridwise_gemm
.
dynamic_tunable
.
k_per_b
lock
<<
"_TT"
<<
MPerThread
<<
"x"
<<
NPerThread
<<
"_A"
<<
string_local_buffer
(
UseALocalBuffer
)
<<
"_B"
<<
string_local_buffer
(
UseBLocalBuffer
)
...
...
include/ck/tensor_operation/cpu/device/device_convnd_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk.hpp
View file @
9cefc261
...
...
@@ -31,11 +31,7 @@ template <typename InDataType,
typename
OutElementwiseOperation
,
ConvolutionForwardSpecialization_t
ConvForwardSpecialization
,
ConvolutionForwardGemmKSpecialization_t
GemmKSpecialization
,
ConvolutionForwardBlockLoopOverSpecialization_t
BlockLoopOverSpecialization
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
MPerBlock
,
// block means data are designed to fit in cache (L1/L2/L3)
ck
::
index_t
NPerBlock
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
MPerThread
,
ck
::
index_t
NPerThread
,
bool
UseALocalBuffer
,
...
...
@@ -72,17 +68,12 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
static
constexpr
bool
NonTemporalStore
=
false
;
static
constexpr
auto
GetBlockMNKAccessOrder
()
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
(
const
DeviceConvFwdDynamicTunable
&
dtune
)
:
gridwise_gemm
(
dtune
)
{
if
constexpr
(
BlockLoopOverSpecialization
==
DefaultBlockLoopOver
||
BlockLoopOverSpecialization
==
LoopOver_MNK
)
return
ck
::
Sequence
<
0
,
1
,
2
>
{};
else
if
constexpr
(
BlockLoopOverSpecialization
==
LoopOver_MKN
)
return
ck
::
Sequence
<
0
,
2
,
1
>
{};
}
using
BlockMNKAccessOrder
=
decltype
(
GetBlockMNKAccessOrder
());
static
constexpr
auto
GetThreadwiseGemm_Dispatch
()
{
if
constexpr
(
MPerThread
==
4
&&
NPerThread
==
24
)
...
...
@@ -111,42 +102,6 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
using
ThreadwiseGemm_Dispatch
=
decltype
(
GetThreadwiseGemm_Dispatch
());
// Returns the descriptor used for the A (input) block.
// With UseALocalBuffer set, this is the descriptor produced by
// make_naive_tensor_descriptor_packed over (MPerBlock, KPerBlock);
// otherwise a default-constructed AGridDesc is returned.
static constexpr auto GetInputBlockDescriptor()
{
    if constexpr(!UseALocalBuffer)
    {
        return AGridDesc{};
    }
    else
    {
        return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
    }
}
// Returns the descriptor used for the B (weight) block.
// With UseBLocalBuffer set, this is the descriptor produced by
// make_naive_tensor_descriptor_packed over (KPerBlock, NPerBlock);
// otherwise a default-constructed BGridDesc is returned.
static constexpr auto GetWeightBlockDescriptor()
{
    if constexpr(!UseBLocalBuffer)
    {
        return BGridDesc{};
    }
    else
    {
        return make_naive_tensor_descriptor_packed(make_tuple(KPerBlock, NPerBlock));
    }
}
// Returns the descriptor used for the C (output) block.
// With UseCLocalBuffer set, this is the descriptor produced by
// make_naive_tensor_descriptor_packed over (MPerBlock, NPerBlock);
// otherwise a default-constructed CGridDesc is returned.
static constexpr auto GetOutputBlockDescriptor()
{
    if constexpr(!UseCLocalBuffer)
    {
        return CGridDesc{};
    }
    else
    {
        return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
    }
}
static
auto
GetWeightTensorDescriptor
(
ck
::
index_t
gemm_k
,
ck
::
index_t
gemm_n
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
gemm_k
,
gemm_n
));
...
...
@@ -568,6 +523,42 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
using
C0GridDesc
=
remove_cvref_t
<
decltype
(
MakeBiasTensorDescriptor
(
1
,
1
))
>
;
using
C1GridDesc
=
CGridDesc
;
// Returns the descriptor used for the A (input) block.
// With UseALocalBuffer set, a packed two-length descriptor is built from
// placeholder lengths (0, 0); otherwise a default-constructed AGridDesc
// is returned.
// NOTE(review): the zero lengths suggest only the resulting type is consumed
// (e.g. via decltype) -- confirm against the callers.
static constexpr auto GetInputBlockDescriptor()
{
    if constexpr(!UseALocalBuffer)
    {
        return AGridDesc{};
    }
    else
    {
        return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
    }
}
// Returns the descriptor used for the B (weight) block.
// With UseBLocalBuffer set, a packed two-length descriptor is built from
// placeholder lengths (0, 0); otherwise a default-constructed BGridDesc
// is returned.
// NOTE(review): the zero lengths suggest only the resulting type is consumed
// (e.g. via decltype) -- confirm against the callers.
static constexpr auto GetWeightBlockDescriptor()
{
    if constexpr(!UseBLocalBuffer)
    {
        return BGridDesc{};
    }
    else
    {
        return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
    }
}
// Returns the descriptor used for the C (output) block.
// With UseCLocalBuffer set, a packed two-length descriptor is built from
// placeholder lengths (0, 0); otherwise a default-constructed CGridDesc
// is returned.
// NOTE(review): the zero lengths suggest only the resulting type is consumed
// (e.g. via decltype) -- confirm against the callers.
static constexpr auto GetOutputBlockDescriptor()
{
    if constexpr(!UseCLocalBuffer)
    {
        return CGridDesc{};
    }
    else
    {
        return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
    }
}
// static constexpr bool UseCLocalBuffer = false;
using
AThreadwiseCopy
=
...
...
@@ -620,20 +611,18 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
AElementwiseOperation
,
// AElementwiseOperation,
BElementwiseOperation
,
// BElementwiseOperation,
CElementwiseOperation
,
// CElementwiseOperation,
MPerBlock
,
// MPerBlock,
NPerBlock
,
// NPerBlock,
KPerBlock
,
// KPerBlock,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
AThreadwiseCopy
,
// AThreadwiseCopy
BThreadwiseCopy
,
// BThreadwiseCopy
CThreadwiseCopy
,
// CThreadwiseCopy
BlockMNKAccessOrder
,
// BlockMNKAccessOrder,
ck
::
Sequence
<
0
,
1
>
,
// ThreadMNAccessOrder
UseALocalBuffer
,
// UseALocalBuffer
UseBLocalBuffer
,
// UseBLocalBuffer
UseCLocalBuffer
// UseCLocalBuffer
>
;
GridwiseGemm
gridwise_gemm
;
// Argument
struct
Argument
:
public
BaseArgument
{
...
...
@@ -725,11 +714,15 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
{
using
Argument
=
DeviceOp
::
Argument
;
GridwiseGemm
gridwise_gemm
;
// Builds an invoker that keeps its own copy of the given gridwise GEMM
// object (the `gridwise_gemm` member is held by value).
Invoker(const GridwiseGemm& gridwise_gemm_)
    : gridwise_gemm(gridwise_gemm_)
{
}
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{},
int
nrepeat
=
1
)
{
if
(
!
G
ridwise
G
emm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
))
if
(
!
g
ridwise
_g
emm
.
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemmAvx2_MxN has invalid setting"
);
}
...
...
@@ -757,6 +750,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
if
(
nrepeat
!=
1
)
ave_time
=
launch_and_time_cpu_kernel
(
kernel
,
nrepeat
,
gridwise_gemm
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
...
...
@@ -776,6 +770,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
// Clear the output buffer before launching the kernel.
// memset() takes a size in BYTES; GetElementSpaceSize() presumably reports a
// number of ELEMENTS (TODO confirm), so scale by the element size -- otherwise
// only part of the buffer is zeroed for multi-byte element types.
memset(arg.p_c_grid_,
       0,
       arg.c_grid_desc_.GetElementSpaceSize() * sizeof(*arg.p_c_grid_));
launch_cpu_kernel
(
kernel
,
gridwise_gemm
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
...
...
@@ -807,7 +802,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
return
true
;
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
...
...
@@ -838,7 +833,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
ConvForwardSpecialization
!=
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
{
if
(
!
(
arg
.
Conv_C_
%
KPerB
lock
==
0
))
if
(
!
(
arg
.
Conv_C_
%
gridwise_gemm
.
dynamic_tunable
.
k_per_b
lock
==
0
))
return
false
;
}
...
...
@@ -858,7 +853,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
}
// Gridwise GEMM size
return
G
ridwise
G
emm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
return
g
ridwise
_g
emm
.
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
...
...
@@ -905,7 +900,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
out_element_op
};
}
// Creates a default-constructed Invoker and returns it by value.
static auto MakeInvoker()
{
    return Invoker();
}
// Creates an Invoker constructed from this device op's `gridwise_gemm`
// member and returns it by value.
auto MakeInvoker()
{
    return Invoker(gridwise_gemm);
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in_grid
,
...
...
@@ -949,7 +944,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
return
std
::
make_unique
<
Invoker
>
(
Invoker
{
gridwise_gemm
});
}
std
::
string
GetTypeString
()
const
override
...
...
@@ -966,8 +961,8 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Outpu
<<
"DFwd_BAA_Avx2_NHWC_YXCK"
<<
"_FS"
<<
static_cast
<
int
>
(
ConvForwardSpecialization
)
<<
"_KS"
<<
static_cast
<
int
>
(
GemmKSpecialization
)
<<
"_BS"
<<
static_cast
<
int
>
(
BlockLoopO
ver
S
pec
ialization
)
<<
"_BT"
<<
MPerBlock
<<
"x"
<<
NPerBlock
<<
"x"
<<
KPerB
lock
<<
"_BS"
<<
static_cast
<
int
>
(
gridwise_gemm
.
dynamic_tunable
.
loop_o
ver
_s
pec
)
<<
"_BT"
<<
gridwise_gemm
.
dynamic_tunable
.
m_per_block
<<
"x"
<<
gridwise_gemm
.
dynamic_tunable
.
n_per_block
<<
"x"
<<
gridwise_gemm
.
dynamic_tunable
.
k_per_b
lock
<<
"_TT"
<<
MPerThread
<<
"x"
<<
NPerThread
<<
"_A"
<<
string_local_buffer
(
UseALocalBuffer
)
<<
"_B"
<<
string_local_buffer
(
UseBLocalBuffer
)
...
...
include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp
View file @
9cefc261
...
...
@@ -28,7 +28,8 @@ template <typename GridwiseGemm,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
void
kernel_gemm_avx_mxn
(
const
FloatA
*
__restrict__
p_a_grid
,
void
kernel_gemm_avx_mxn
(
const
GridwiseGemm
&
gridwise_gemm
,
const
FloatA
*
__restrict__
p_a_grid
,
const
FloatB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
AGridDesc
&
a_grid_desc
,
...
...
@@ -38,7 +39,7 @@ void kernel_gemm_avx_mxn(const FloatA* __restrict__ p_a_grid,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
)
{
G
ridwise
G
emm
::
Run
(
p_a_grid
,
g
ridwise
_g
emm
.
Run
(
p_a_grid
,
p_b_grid
,
p_c_grid
,
a_grid_desc
,
...
...
@@ -58,14 +59,10 @@ template <typename FloatA,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
ck
::
index_t
MPerBlock
,
// block means data are designed to fit in cache (L1/L2/L3)
ck
::
index_t
NPerBlock
,
ck
::
index_t
KPerBlock
,
typename
ThreadwiseGemm_Dispatch
,
typename
AThreadwiseCopy
,
typename
BThreadwiseCopy
,
typename
CThreadwiseCopy
,
typename
BlockMNKAccessOrder
,
// how we access gemm MNK to better fit in cache
typename
ThreadMNAccessOrder
,
// how we access gemm MN to utilize micro kernel
bool
UseALocalBuffer
,
bool
UseBLocalBuffer
,
...
...
@@ -75,12 +72,19 @@ template <typename FloatA,
>
struct
GridwiseGemmAvx2_MxN
{
ck
::
tensor_operation
::
cpu
::
device
::
DeviceConvFwdDynamicTunable
dynamic_tunable
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
// static constexpr auto Avx2RegisterVector = 8; // 8 floats
static
constexpr
index_t
MemAlignmentByte
=
32
;
// 256bit
// Constructs the gridwise GEMM from a runtime tunable and stores a copy of it
// in `dynamic_tunable`. Its m_per_block / n_per_block / k_per_block and
// loop_over_spec fields are read later by CheckValidity() and Run().
// The tunable is taken by const reference (not by const value) to avoid an
// extra copy and to match the device-op constructors that forward it.
GridwiseGemmAvx2_MxN(
    const ck::tensor_operation::cpu::device::DeviceConvFwdDynamicTunable& dynamic_tunable_)
    : dynamic_tunable(dynamic_tunable_)
{
}
static
auto
GetABlockDescriptor
(
const
ck
::
index_t
m_per_blk
,
const
ck
::
index_t
k_per_blk
,
const
AGridDesc
&
a_grid_desc
)
...
...
@@ -238,7 +242,7 @@ struct GridwiseGemmAvx2_MxN
return
ck
::
make_multi_index
(
i_m
,
i_n
);
}
static
constexpr
bool
CheckValidity
(
const
AGridDesc
&
a_grid_desc
,
bool
CheckValidity
(
const
AGridDesc
&
a_grid_desc
,
const
BGridDesc
&
b_grid_desc
,
const
CGridDesc
&
c_grid_desc
)
{
...
...
@@ -247,7 +251,12 @@ struct GridwiseGemmAvx2_MxN
const
auto
GemmN
=
c_grid_desc
.
GetLength
(
I1
);
if
constexpr
(
UseCLocalBuffer
)
{
if
(
std
::
is_same
<
BlockMNKAccessOrder
,
ck
::
Sequence
<
0
,
2
,
1
>>::
value
&&
NPerBlock
<
GemmN
)
// if(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value &&
// dynamic_tunable.gemm_n_per_block < GemmN)
if
(
dynamic_tunable
.
loop_over_spec
==
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardBlockLoopOverSpecialization_t
::
LoopOver_MKN
&&
dynamic_tunable
.
n_per_block
<
GemmN
)
is_valid
&=
false
;
}
else
...
...
@@ -259,7 +268,7 @@ struct GridwiseGemmAvx2_MxN
return
is_valid
;
}
static
void
Run
(
const
FloatA
*
__restrict__
p_a_grid
,
void
Run
(
const
FloatA
*
__restrict__
p_a_grid
,
const
FloatB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
AGridDesc
&
a_grid_desc
,
...
...
@@ -267,11 +276,11 @@ struct GridwiseGemmAvx2_MxN
const
CGridDesc
&
c_grid_desc
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
)
const
CElementwiseOperation
&
c_element_op
)
const
{
ck
::
index_t
m_per_block
=
MPerB
lock
;
ck
::
index_t
n_per_block
=
NPerB
lock
;
ck
::
index_t
k_per_block
=
KPerB
lock
;
ck
::
index_t
m_per_block
=
dynamic_tunable
.
m_per_b
lock
;
ck
::
index_t
n_per_block
=
dynamic_tunable
.
n_per_b
lock
;
ck
::
index_t
k_per_block
=
dynamic_tunable
.
k_per_b
lock
;
const
auto
GemmM
=
c_grid_desc
.
GetLength
(
I0
);
const
auto
GemmN
=
c_grid_desc
.
GetLength
(
I1
);
...
...
@@ -297,7 +306,6 @@ struct GridwiseGemmAvx2_MxN
decltype
(
GetABlockDescriptor
(
m_per_block
,
k_per_block
,
a_grid_desc
)),
// ABlockDesc,
decltype
(
GetBBlockDescriptor
(
k_per_block
,
n_per_block
,
b_grid_desc
)),
// BBlockDesc,
decltype
(
GetCBlockDescriptor
(
m_per_block
,
n_per_block
,
c_grid_desc
)),
// CBlockDesc,
KPerBlock
,
// KPerBlock,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
ThreadMNAccessOrder
>
{};
// ThreadMNAccessOrder  // how we access
// gemm MN to utilize micro kernel>{};
...
...
@@ -323,7 +331,9 @@ struct GridwiseGemmAvx2_MxN
// TODO: openmp aware ordering
//
if
constexpr
(
std
::
is_same
<
BlockMNKAccessOrder
,
ck
::
Sequence
<
0
,
1
,
2
>>::
value
)
if
(
dynamic_tunable
.
loop_over_spec
==
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardBlockLoopOverSpecialization_t
::
LoopOver_MNK
)
{
auto
a_move_k_step
=
GetAIndex
(
0
,
k_per_block
);
auto
b_move_k_step
=
GetBIndex
(
k_per_block
,
0
);
...
...
@@ -467,7 +477,9 @@ struct GridwiseGemmAvx2_MxN
}
}
}
else
if
constexpr
(
std
::
is_same
<
BlockMNKAccessOrder
,
ck
::
Sequence
<
0
,
2
,
1
>>::
value
)
else
if
(
dynamic_tunable
.
loop_over_spec
==
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardBlockLoopOverSpecialization_t
::
LoopOver_MKN
)
{
auto
a_move_k_step
=
GetAIndex
(
0
,
k_per_block
);
auto
b_move_k_step
=
GetBIndex
(
0
,
n_per_block
);
...
...
include/ck/tensor_operation/cpu/grid/gridwise_gemm_bias_activation_add_avx2.hpp
View file @
9cefc261
...
...
@@ -32,7 +32,8 @@ template <typename GridwiseGemm,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
void
kernel_gemm_bias_activation_add_avx_mxn
(
const
FloatA
*
__restrict__
p_a_grid
,
void
kernel_gemm_bias_activation_add_avx_mxn
(
const
GridwiseGemm
&
gridwise_gemm
,
const
FloatA
*
__restrict__
p_a_grid
,
const
FloatB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
FloatC0
*
__restrict__
p_c0_grid
,
...
...
@@ -46,7 +47,7 @@ void kernel_gemm_bias_activation_add_avx_mxn(const FloatA* __restrict__ p_a_grid
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
)
{
G
ridwise
G
emm
::
Run
(
p_a_grid
,
g
ridwise
_g
emm
.
Run
(
p_a_grid
,
p_b_grid
,
p_c_grid
,
p_c0_grid
,
...
...
@@ -74,14 +75,10 @@ template <typename FloatA,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
ck
::
index_t
MPerBlock
,
// block means data are designed to fit in cache (L1/L2/L3)
ck
::
index_t
NPerBlock
,
ck
::
index_t
KPerBlock
,
typename
ThreadwiseGemm_Dispatch
,
typename
AThreadwiseCopy
,
typename
BThreadwiseCopy
,
typename
CThreadwiseCopy
,
typename
BlockMNKAccessOrder
,
// how we access gemm MNK to better fit in cache
typename
ThreadMNAccessOrder
,
// how we access gemm MN to utilize micro kernel
bool
UseALocalBuffer
,
bool
UseBLocalBuffer
,
...
...
@@ -91,12 +88,19 @@ template <typename FloatA,
>
struct
GridwiseGemmBiasActivationAddAvx2_MxN
{
ck
::
tensor_operation
::
cpu
::
device
::
DeviceConvFwdDynamicTunable
dynamic_tunable
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
// static constexpr auto Avx2RegisterVector = 8; // 8 floats
static
constexpr
index_t
MemAlignmentByte
=
32
;
// 256bit
// Constructs the gridwise GEMM from a runtime tunable and stores a copy of it
// in `dynamic_tunable`. Its m_per_block / n_per_block / k_per_block and
// loop_over_spec fields are read later by CheckValidity() and Run().
// The tunable is taken by const reference (not by const value) to avoid an
// extra copy and to match the device-op constructors that forward it.
GridwiseGemmBiasActivationAddAvx2_MxN(
    const ck::tensor_operation::cpu::device::DeviceConvFwdDynamicTunable& dynamic_tunable_)
    : dynamic_tunable(dynamic_tunable_)
{
}
static
auto
GetABlockDescriptor
(
const
ck
::
index_t
m_per_blk
,
const
ck
::
index_t
k_per_blk
,
const
AGridDesc
&
a_grid_desc
)
...
...
@@ -254,7 +258,7 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
return
ck
::
make_multi_index
(
i_m
,
i_n
);
}
static
constexpr
bool
CheckValidity
(
const
AGridDesc
&
a_grid_desc
,
bool
CheckValidity
(
const
AGridDesc
&
a_grid_desc
,
const
BGridDesc
&
b_grid_desc
,
const
CGridDesc
&
c_grid_desc
)
{
...
...
@@ -263,7 +267,11 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
const
auto
GemmN
=
c_grid_desc
.
GetLength
(
I1
);
if
constexpr
(
UseCLocalBuffer
)
{
if
(
std
::
is_same
<
BlockMNKAccessOrder
,
ck
::
Sequence
<
0
,
2
,
1
>>::
value
&&
NPerBlock
<
GemmN
)
if
(
dynamic_tunable
.
loop_over_spec
==
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardBlockLoopOverSpecialization_t
::
LoopOver_MKN
&&
dynamic_tunable
.
n_per_block
<
GemmN
)
is_valid
&=
false
;
}
else
...
...
@@ -275,7 +283,7 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
return
is_valid
;
}
static
void
Run
(
const
FloatA
*
__restrict__
p_a_grid
,
void
Run
(
const
FloatA
*
__restrict__
p_a_grid
,
const
FloatB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
FloatC0
*
__restrict__
p_c0_grid
,
...
...
@@ -287,11 +295,11 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
const
C1GridDesc
&
c1_grid_desc
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
)
const
CElementwiseOperation
&
c_element_op
)
const
{
ck
::
index_t
m_per_block
=
MPerB
lock
;
ck
::
index_t
n_per_block
=
NPerB
lock
;
ck
::
index_t
k_per_block
=
KPerB
lock
;
ck
::
index_t
m_per_block
=
dynamic_tunable
.
m_per_b
lock
;
ck
::
index_t
n_per_block
=
dynamic_tunable
.
n_per_b
lock
;
ck
::
index_t
k_per_block
=
dynamic_tunable
.
k_per_b
lock
;
const
auto
GemmM
=
c_grid_desc
.
GetLength
(
I0
);
const
auto
GemmN
=
c_grid_desc
.
GetLength
(
I1
);
...
...
@@ -323,7 +331,6 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
decltype
(
GetABlockDescriptor
(
m_per_block
,
k_per_block
,
a_grid_desc
)),
// ABlockDesc,
decltype
(
GetBBlockDescriptor
(
k_per_block
,
n_per_block
,
b_grid_desc
)),
// BBlockDesc,
decltype
(
GetCBlockDescriptor
(
m_per_block
,
n_per_block
,
c_grid_desc
)),
// CBlockDesc,
KPerBlock
,
// KPerBlock,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
ThreadMNAccessOrder
>
{};
// ThreadMNAccessOrder  // how we access
// gemm MN to utilize micro kernel>{};
...
...
@@ -349,7 +356,10 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
// TODO: openmp aware ordering
//
if
constexpr
(
std
::
is_same
<
BlockMNKAccessOrder
,
ck
::
Sequence
<
0
,
1
,
2
>>::
value
)
if
(
dynamic_tunable
.
loop_over_spec
==
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardBlockLoopOverSpecialization_t
::
LoopOver_MNK
)
{
auto
a_move_k_step
=
GetAIndex
(
0
,
k_per_block
);
auto
b_move_k_step
=
GetBIndex
(
k_per_block
,
0
);
...
...
@@ -505,7 +515,9 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
}
}
}
else
if
constexpr
(
std
::
is_same
<
BlockMNKAccessOrder
,
ck
::
Sequence
<
0
,
2
,
1
>>::
value
)
else
if
(
dynamic_tunable
.
loop_over_spec
==
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardBlockLoopOverSpecialization_t
::
LoopOver_MKN
)
{
auto
a_move_k_step
=
GetAIndex
(
0
,
k_per_block
);
auto
b_move_k_step
=
GetBIndex
(
0
,
n_per_block
);
...
...
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp
View file @
9cefc261
#include <stdlib.h>
#include <utility>
#include "convolution_forward_specialization_cpu.hpp"
#include "config.hpp"
#include "device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp"
...
...
@@ -48,18 +49,24 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
// clang-format off
#define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf) \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, true, c_local_buf>, \
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, true, c_local_buf>
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf>({m_per_block, n_per_block, k_per_block, LoopOver_MKN})
// clang-format on
using
device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances
=
std
::
tuple
<
void
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
PT
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
256
,
128
,
64
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
256
,
128
,
128
,
6
,
16
,
false
),
...
...
@@ -70,11 +77,17 @@ using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances = std::tuple<
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
768
,
320
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
896
,
352
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
1024
,
416
,
128
,
6
,
16
,
false
)
>
;
// clang-format on
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
1024
,
416
,
128
,
6
,
16
,
false
)
// clang-format on
));
}
// use this in a single-threaded run where gemm_n is not a multiple of 8
using
device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_instances
=
std
::
tuple
<
void
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
PT
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
256
,
128
,
64
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
256
,
128
,
128
,
6
,
16
,
true
),
...
...
@@ -85,12 +98,17 @@ using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_instances = std::tuple<
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
768
,
320
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
896
,
352
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
1024
,
416
,
128
,
6
,
16
,
true
)
>
;
// clang-format on
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
1024
,
416
,
128
,
6
,
16
,
true
)
// clang-format on
));
}
// use this in a multi-threaded environment (a local C buffer is needed to avoid cache-coherence overhead,
// although sometimes no local C buffer is better...)
using
device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_instances
=
std
::
tuple
<
void
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
PT
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
24
,
24
,
256
,
4
,
24
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
32
,
24
,
256
,
4
,
24
,
false
),
...
...
@@ -117,10 +135,17 @@ using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_instances = std::tuple<
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
768
,
320
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
896
,
352
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
1024
,
416
,
128
,
6
,
16
,
true
)
>
;
// clang-format on
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
1024
,
416
,
128
,
6
,
16
,
true
)
// clang-format on
));
}
using
device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_relu_instances
=
std
::
tuple
<
void
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
Relu
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
256
,
128
,
64
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
256
,
128
,
128
,
6
,
16
,
false
),
...
...
@@ -131,11 +156,17 @@ using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_relu_instances = std::tuple<
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
768
,
320
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
896
,
352
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
1024
,
416
,
128
,
6
,
16
,
false
)
>
;
// clang-format on
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
1024
,
416
,
128
,
6
,
16
,
false
)
// clang-format on
));
}
// use this in a single-threaded run where gemm_n is not a multiple of 8
using
device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_relu_instances
=
std
::
tuple
<
void
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
Relu
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
256
,
128
,
64
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
256
,
128
,
128
,
6
,
16
,
true
),
...
...
@@ -146,12 +177,17 @@ using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_relu_instances = std::tu
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
768
,
320
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
896
,
352
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
1024
,
416
,
128
,
6
,
16
,
true
)
>
;
// clang-format on
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
1024
,
416
,
128
,
6
,
16
,
true
)
// clang-format on
));
}
// use this in a multi-threaded environment (a local C buffer is needed to avoid cache-coherence overhead,
// although sometimes no local C buffer is better...)
using
device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_relu_instances
=
std
::
tuple
<
void
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
Relu
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
24
,
24
,
256
,
4
,
24
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
32
,
24
,
256
,
4
,
24
,
false
),
...
...
@@ -178,48 +214,9 @@ using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_relu_instances = std::tuple<
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
768
,
320
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
896
,
352
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
1024
,
416
,
128
,
6
,
16
,
true
)
>
;
// clang-format on
void
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
PT
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances
{});
}
void
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
PT
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_instances
{});
}
void
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
PT
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_instances
{});
}
void
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
Relu
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_relu_instances
{});
}
void
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
Relu
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_relu_instances
{});
}
void
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
Relu
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_relu_instances
{});
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
1024
,
416
,
128
,
6
,
16
,
true
)
// clang-format on
));
}
}
// namespace device_conv2d_fwd_avx2_instance
...
...
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_instance.cpp
View file @
9cefc261
#include <stdlib.h>
#include <utility>
#include "config.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "device_convnd_fwd_avx2_nhwc_kyxck8_nhwk.hpp"
...
...
@@ -41,20 +42,25 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
// clang-format off
#define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf) \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC ,
LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block
, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC ,
LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block
, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop,
LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block
, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC ,
LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block
, m_per_thread, n_per_thread, false, false, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop,
LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block
, m_per_thread, n_per_thread, true, false, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC ,
2
, m_per_thread, n_per_thread, true, true, c_local_buf>
({m_per_block, n_per_block, k_per_block, LoopOver_MNK})
, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC ,
2
, m_per_thread, n_per_thread, true, true, c_local_buf>
({m_per_block, n_per_block, k_per_block, LoopOver_MNK})
, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop,
2
, m_per_thread, n_per_thread, true, true, c_local_buf>
({m_per_block, n_per_block, k_per_block, LoopOver_MNK})
, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC ,
2
, m_per_thread, n_per_thread, false, false, c_local_buf>
({m_per_block, n_per_block, k_per_block, LoopOver_MNK})
, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop,
2
, m_per_thread, n_per_thread, true, false, c_local_buf>
({m_per_block, n_per_block, k_per_block, LoopOver_MNK})
, \
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC ,
LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block
, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC ,
LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block
, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop,
LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block
, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC ,
LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block
, m_per_thread, n_per_thread, false, false, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop,
LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block
, m_per_thread, n_per_thread, true, false, c_local_buf>
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC ,
2
, m_per_thread, n_per_thread, true, true, c_local_buf>
({m_per_block, n_per_block, k_per_block, LoopOver_MKN})
, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC ,
2
, m_per_thread, n_per_thread, true, true, c_local_buf>
({m_per_block, n_per_block, k_per_block, LoopOver_MKN})
, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop,
2
, m_per_thread, n_per_thread, true, true, c_local_buf>
({m_per_block, n_per_block, k_per_block, LoopOver_MKN})
, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC ,
2
, m_per_thread, n_per_thread, false, false, c_local_buf>
({m_per_block, n_per_block, k_per_block, LoopOver_MKN})
, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop,
2
, m_per_thread, n_per_thread, true, false, c_local_buf>
({m_per_block, n_per_block, k_per_block, LoopOver_MKN})
// clang-format on
using
device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_instances
=
std
::
tuple
<
void
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
PT
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
256
,
128
,
64
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
256
,
128
,
128
,
6
,
16
,
false
),
...
...
@@ -65,11 +71,17 @@ using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_instances = std::tuple<
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
768
,
320
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
896
,
352
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
1024
,
416
,
128
,
6
,
16
,
false
)
>
;
// clang-format on
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
1024
,
416
,
128
,
6
,
16
,
false
)
// clang-format on
));
}
// use this in a single-threaded run where gemm_n is not a multiple of 8
using
device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_instances
=
std
::
tuple
<
void
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
PT
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
256
,
128
,
64
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
256
,
128
,
128
,
6
,
16
,
true
),
...
...
@@ -80,12 +92,17 @@ using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_instances = std::tuple
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
768
,
320
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
896
,
352
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
1024
,
416
,
128
,
6
,
16
,
true
)
>
;
// clang-format on
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
1024
,
416
,
128
,
6
,
16
,
true
)
// clang-format on
));
}
// use this in a multi-threaded environment (a local C buffer is needed to avoid cache-coherence overhead,
// although sometimes no local C buffer is better...)
using
device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_instances
=
std
::
tuple
<
void
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
PT
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
24
,
24
,
256
,
4
,
24
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
32
,
24
,
256
,
4
,
24
,
false
),
...
...
@@ -104,7 +121,6 @@ using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_instances = std::tuple<
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
120
,
32
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
120
,
64
,
128
,
6
,
16
,
false
),
// DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
256
,
128
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
128
,
256
,
128
,
6
,
16
,
true
),
...
...
@@ -113,10 +129,17 @@ using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_instances = std::tuple<
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
768
,
320
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
896
,
352
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
1024
,
416
,
128
,
6
,
16
,
true
)
>
;
// clang-format on
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
PT
,
1024
,
416
,
128
,
6
,
16
,
true
)
// clang-format on
));
}
using
device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_relu_instances
=
std
::
tuple
<
void
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_relu
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
Relu
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
256
,
128
,
64
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
256
,
128
,
128
,
6
,
16
,
false
),
...
...
@@ -127,11 +150,17 @@ using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_relu_instances = std::tuple<
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
768
,
320
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
896
,
352
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
1024
,
416
,
128
,
6
,
16
,
false
)
>
;
// clang-format on
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
1024
,
416
,
128
,
6
,
16
,
false
)
// clang-format on
));
}
// use this in a single-threaded run where gemm_n is not a multiple of 8
using
device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_relu_instances
=
std
::
tuple
<
void
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c_relu
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
Relu
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
256
,
128
,
64
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
256
,
128
,
128
,
6
,
16
,
true
),
...
...
@@ -142,12 +171,17 @@ using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_relu_instances = std::
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
768
,
320
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
896
,
352
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
1024
,
416
,
128
,
6
,
16
,
true
)
>
;
// clang-format on
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
1024
,
416
,
128
,
6
,
16
,
true
)
// clang-format on
));
}
// use this in a multi-threaded environment (a local C buffer is needed to avoid cache-coherence overhead,
// although sometimes no local C buffer is better...)
using
device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_relu_instances
=
std
::
tuple
<
void
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt_relu
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
Relu
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
24
,
24
,
256
,
4
,
24
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
32
,
24
,
256
,
4
,
24
,
false
),
...
...
@@ -170,7 +204,6 @@ using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_relu_instances = std::tuple
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
120
,
32
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
120
,
64
,
128
,
6
,
16
,
false
),
// DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
256
,
128
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
128
,
256
,
128
,
6
,
16
,
true
),
...
...
@@ -179,49 +212,9 @@ using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_relu_instances = std::tuple
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
768
,
320
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
896
,
352
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
1024
,
416
,
128
,
6
,
16
,
true
)
>
;
// clang-format on
void
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
PT
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_instances
{});
}
void
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
PT
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_instances
{});
}
void
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
PT
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_instances
{});
}
void
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_relu
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
Relu
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_relu_instances
{});
}
void
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_local_c_relu
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
Relu
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_relu_instances
{});
}
void
add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_mt_relu
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
Relu
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_relu_instances
{});
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
Relu
,
1024
,
416
,
128
,
6
,
16
,
true
)
// clang-format on
));
}
}
// namespace device_conv2d_fwd_avx2_instance
...
...
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_yxck_nhwk_instance.cpp
View file @
9cefc261
#include <stdlib.h>
#include <utility>
#include "config.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "device_convnd_fwd_avx2_nhwc_yxck_nhwk.hpp"
...
...
@@ -40,20 +41,24 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
// clang-format off
#define DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf) \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
GemmKLoopOverC
,
LoopOver_MNK
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
>
,
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
GemmKLoopOverC
,
LoopOver_MNK
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
>
,
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
DefaultGemmKLoop
,
LoopOver_MNK
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
>
,
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
GemmKLoopOverC
,
LoopOver_MNK
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
false
,
false
,
c_local_buf
>
,
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
DefaultGemmKLoop
,
LoopOver_MNK
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
true
,
false
,
c_local_buf
>
,
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
GemmKLoopOverC
,
2
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
>
({
m_per_block
,
n_per_block
,
k_per_block
,
LoopOver_MNK
})
,
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
GemmKLoopOverC
,
2
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
>
({
m_per_block
,
n_per_block
,
k_per_block
,
LoopOver_MNK
})
,
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
DefaultGemmKLoop
,
2
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
>
({
m_per_block
,
n_per_block
,
k_per_block
,
LoopOver_MNK
})
,
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
GemmKLoopOverC
,
2
,
m_per_thread
,
n_per_thread
,
false
,
false
,
c_local_buf
>
({
m_per_block
,
n_per_block
,
k_per_block
,
LoopOver_MNK
})
,
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
DefaultGemmKLoop
,
2
,
m_per_thread
,
n_per_thread
,
true
,
false
,
c_local_buf
>
({
m_per_block
,
n_per_block
,
k_per_block
,
LoopOver_MNK
})
,
\
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
GemmKLoopOverC
,
LoopOver_MKN
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
>
,
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
GemmKLoopOverC
,
LoopOver_MKN
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
>
,
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
DefaultGemmKLoop
,
LoopOver_MKN
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
>
,
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
GemmKLoopOverC
,
LoopOver_MKN
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
false
,
false
,
c_local_buf
>
,
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
DefaultGemmKLoop
,
LoopOver_MKN
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
true
,
false
,
c_local_buf
>
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
GemmKLoopOverC
,
2
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
>
({
m_per_block
,
n_per_block
,
k_per_block
,
LoopOver_MKN
})
,
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
GemmKLoopOverC
,
2
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
>
({
m_per_block
,
n_per_block
,
k_per_block
,
LoopOver_MKN
})
,
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
DefaultGemmKLoop
,
2
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
>
({
m_per_block
,
n_per_block
,
k_per_block
,
LoopOver_MKN
})
,
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
GemmKLoopOverC
,
2
,
m_per_thread
,
n_per_thread
,
false
,
false
,
c_local_buf
>
({
m_per_block
,
n_per_block
,
k_per_block
,
LoopOver_MKN
})
,
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
DefaultGemmKLoop
,
2
,
m_per_thread
,
n_per_thread
,
true
,
false
,
c_local_buf
>
({
m_per_block
,
n_per_block
,
k_per_block
,
LoopOver_MKN
})
// clang-format on
using
device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_instances
=
std
::
tuple
<
void
add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
PT
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
256
,
128
,
64
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
256
,
128
,
128
,
6
,
16
,
false
),
...
...
@@ -64,11 +69,17 @@ using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_instances = std::tuple<
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
768
,
320
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
896
,
352
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
1024
,
416
,
128
,
6
,
16
,
false
)
>
;
// clang-format on
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
1024
,
416
,
128
,
6
,
16
,
false
)
// clang-format on
));
}
// Use these instances when running single-threaded and gemm_n is not a multiple of 8.
using
device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_instances
=
std
::
tuple
<
void
add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
PT
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
256
,
128
,
64
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
256
,
128
,
128
,
6
,
16
,
true
),
...
...
@@ -79,12 +90,17 @@ using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_instances = std::tuple<
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
768
,
320
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
896
,
352
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
1024
,
416
,
128
,
6
,
16
,
true
)
>
;
// clang-format on
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
1024
,
416
,
128
,
6
,
16
,
true
)
// clang-format on
));
}
// Use these instances in a multi-threaded environment (a local C buffer is needed to avoid
// cache-coherence issues, although sometimes running without a local C buffer is better).
using
device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_instances
=
std
::
tuple
<
void
add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
PT
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
24
,
24
,
256
,
4
,
24
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
32
,
24
,
256
,
4
,
24
,
false
),
...
...
@@ -112,10 +128,17 @@ using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_instances = std::tuple<
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
768
,
320
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
896
,
352
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
1024
,
416
,
128
,
6
,
16
,
true
)
>
;
// clang-format on
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
PT
,
1024
,
416
,
128
,
6
,
16
,
true
)
// clang-format on
));
}
using
device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_relu_instances
=
std
::
tuple
<
void
add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_relu
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
Relu
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Relu
,
256
,
128
,
64
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Relu
,
256
,
128
,
128
,
6
,
16
,
false
),
...
...
@@ -126,11 +149,17 @@ using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_relu_instances = std::tuple<
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Relu
,
768
,
320
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Relu
,
896
,
352
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Relu
,
1024
,
416
,
128
,
6
,
16
,
false
)
>
;
// clang-format on
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Relu
,
1024
,
416
,
128
,
6
,
16
,
false
)
// clang-format on
));
}
// Use these instances when running single-threaded and gemm_n is not a multiple of 8.
using
device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_relu_instances
=
std
::
tuple
<
void
add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c_relu
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
Relu
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Relu
,
256
,
128
,
64
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Relu
,
256
,
128
,
128
,
6
,
16
,
true
),
...
...
@@ -141,12 +170,17 @@ using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_relu_instances = std::tu
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Relu
,
768
,
320
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Relu
,
896
,
352
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Relu
,
1024
,
416
,
128
,
6
,
16
,
true
)
>
;
// clang-format on
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Relu
,
1024
,
416
,
128
,
6
,
16
,
true
)
// clang-format on
));
}
// Use these instances in a multi-threaded environment (a local C buffer is needed to avoid
// cache-coherence issues, although sometimes running without a local C buffer is better).
using
device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_relu_instances
=
std
::
tuple
<
void
add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt_relu
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
Relu
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Relu
,
24
,
24
,
256
,
4
,
24
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Relu
,
32
,
24
,
256
,
4
,
24
,
false
),
...
...
@@ -178,48 +212,9 @@ using device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_relu_instances = std::tuple<
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Relu
,
768
,
320
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Relu
,
896
,
352
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Relu
,
1024
,
416
,
128
,
6
,
16
,
true
)
>
;
// clang-format on
// Forwards `instances` to ck::tensor_operation::device::add_device_operation_instances
// together with an empty-brace-initialized
// `device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_instances` object.
// NOTE(review): presumably the helper appends one operation per entry of that instance
// list to `instances` -- confirm against the helper's definition.
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk(
    std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
    using instance_list_t = device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_instances;
    ck::tensor_operation::device::add_device_operation_instances(instances, instance_list_t{});
}
// Forwards `instances` to ck::tensor_operation::device::add_device_operation_instances
// together with an empty-brace-initialized
// `device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_instances` object.
// NOTE(review): presumably the helper appends one operation per entry of that instance
// list to `instances` -- confirm against the helper's definition.
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c(
    std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
    using instance_list_t = device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_instances;
    ck::tensor_operation::device::add_device_operation_instances(instances, instance_list_t{});
}
// Forwards `instances` to ck::tensor_operation::device::add_device_operation_instances
// together with an empty-brace-initialized
// `device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_instances` object.
// NOTE(review): presumably the helper appends one operation per entry of that instance
// list to `instances` -- confirm against the helper's definition.
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt(
    std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
    using instance_list_t = device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_instances;
    ck::tensor_operation::device::add_device_operation_instances(instances, instance_list_t{});
}
// Forwards `instances` (whose third element-op parameter is `Relu`) to
// ck::tensor_operation::device::add_device_operation_instances together with an
// empty-brace-initialized `device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_relu_instances` object.
// NOTE(review): presumably the helper appends one operation per entry of that instance
// list to `instances` -- confirm against the helper's definition.
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_relu(
    std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
    using instance_list_t = device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_relu_instances;
    ck::tensor_operation::device::add_device_operation_instances(instances, instance_list_t{});
}
// Forwards `instances` (whose third element-op parameter is `Relu`) to
// ck::tensor_operation::device::add_device_operation_instances together with an
// empty-brace-initialized
// `device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_relu_instances` object.
// NOTE(review): presumably the helper appends one operation per entry of that instance
// list to `instances` -- confirm against the helper's definition.
void add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_local_c_relu(
    std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
    using instance_list_t = device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_local_c_relu_instances;
    ck::tensor_operation::device::add_device_operation_instances(instances, instance_list_t{});
}
void
add_device_conv2d_fwd_avx2_nhwc_yxck_nhwk_mt_relu
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
Relu
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_avx2_nhwc_yxck_nhwk_f32_mt_relu_instances
{});
DEVICE_CONV2D_FWD_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
Relu
,
1024
,
416
,
128
,
6
,
16
,
true
)
// clang-format on
));
}
}
// namespace device_conv2d_fwd_avx2_instance
...
...
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_bias_activation_add_avx2_nhwc_kyxc_nhwk_instance.cpp
View file @
9cefc261
#include <stdlib.h>
#include <utility>
#include "convolution_forward_specialization_cpu.hpp"
#include "config.hpp"
#include "device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk.hpp"
...
...
@@ -41,18 +42,25 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
// clang-format off
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, true , c_local_buf, bias_along_m>, \
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, true , c_local_buf, bias_along_m>
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MNK}), \
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN})
// clang-format on
using
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_instances
=
std
::
tuple
<
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
false
,
false
),
...
...
@@ -63,11 +71,17 @@ using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_instances =
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
false
,
false
)
>
;
// clang-format on
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
false
,
false
)
// clang-format on
));
}
// Use these instances when running single-threaded and gemm_n is not a multiple of 8.
using
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_local_c_instances
=
std
::
tuple
<
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
false
),
...
...
@@ -78,12 +92,17 @@ using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_local_c_inst
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
>
;
// clang-format on
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
// clang-format on
));
}
// Use these instances in a multi-threaded environment (a local C buffer is needed to avoid
// cache-coherence issues, although sometimes running without a local C buffer is better).
using
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_mt_instances
=
std
::
tuple
<
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_mt
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
24
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
32
,
24
,
256
,
4
,
24
,
false
,
false
),
...
...
@@ -110,29 +129,9 @@ using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_mt_instances
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
>
;
// clang-format on
// Forwards `instances` (a vector of DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>) to
// ck::tensor_operation::device::add_device_operation_instances together with an
// empty-brace-initialized
// `device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_instances` object.
// NOTE(review): presumably the helper appends one operation per entry of that instance
// list to `instances` -- confirm against the helper's definition.
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk(
    std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
    using instance_list_t =
        device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_instances;
    ck::tensor_operation::device::add_device_operation_instances(instances, instance_list_t{});
}
// Forwards `instances` (a vector of DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>) to
// ck::tensor_operation::device::add_device_operation_instances together with an
// empty-brace-initialized
// `device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_local_c_instances` object.
// NOTE(review): presumably the helper appends one operation per entry of that instance
// list to `instances` -- confirm against the helper's definition.
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_local_c(
    std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
    using instance_list_t =
        device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_local_c_instances;
    ck::tensor_operation::device::add_device_operation_instances(instances, instance_list_t{});
}
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_mt
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_mt_instances
{});
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
// clang-format on
));
}
}
// namespace device_conv2d_fwd_bias_activation_add_avx2_instance
...
...
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_bias_activation_add_avx2_nhwc_kyxck8_nhwk_instance.cpp
View file @
9cefc261
#include <stdlib.h>
#include <utility>
#include "config.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk.hpp"
...
...
@@ -41,21 +42,25 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
// clang-format off
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float
, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC ,
LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block
, m_per_thread, n_per_thread, true
, true
, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float
, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC ,
LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block
, m_per_thread, n_per_thread, true
, true
, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float
, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop,
LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block
, m_per_thread, n_per_thread, true
, true
, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float
, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC ,
LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block
, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float
, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop,
LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block
, m_per_thread, n_per_thread, true
, false, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC ,
2
, m_per_thread, n_per_thread, true,
true,
c_local_buf, bias_along_m>
({m_per_block, n_per_block, k_per_block, LoopOver_MNK})
, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC ,
2
, m_per_thread, n_per_thread, true,
true,
c_local_buf, bias_along_m>
({m_per_block, n_per_block, k_per_block, LoopOver_MNK})
, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop,
2
, m_per_thread, n_per_thread, true,
true,
c_local_buf, bias_along_m>
({m_per_block, n_per_block, k_per_block, LoopOver_MNK})
, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC ,
2
, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>
({m_per_block, n_per_block, k_per_block, LoopOver_MNK})
, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop,
2
, m_per_thread, n_per_thread, true,
false, c_local_buf, bias_along_m>
({m_per_block, n_per_block, k_per_block, LoopOver_MNK})
, \
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , false, c_local_buf, bias_along_m>
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, true, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , 2, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN}), \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, 2, m_per_thread, n_per_thread, true, false, c_local_buf, bias_along_m>({m_per_block, n_per_block, k_per_block, LoopOver_MKN})
// clang-format on
using
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_instances
=
std
::
tuple
<
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
false
,
false
),
...
...
@@ -66,12 +71,17 @@ using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_instances
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
false
,
false
)
>
;
// clang-format on
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
false
,
false
)
// clang-format on
));
}
// use this in single thread, but gemm_n is not multiple of 8
using
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_local_c_instances
=
std
::
tuple
<
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
false
),
...
...
@@ -82,12 +92,17 @@ using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_local_c_in
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
>
;
// clang-format on
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
// clang-format on
));
}
// use this in multi thread environment (need local C buffer to avoid cache coherence, although some
// time no local c is better...)
using
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_mt_instances
=
std
::
tuple
<
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_mt
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
24
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
32
,
24
,
256
,
4
,
24
,
false
,
false
),
...
...
@@ -114,29 +129,9 @@ using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_mt_instanc
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
>
;
// clang-format on
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_instances
{});
}
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_local_c_instances
{});
}
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_mt
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_mt_instances
{});
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
// clang-format on
));
}
}
// namespace device_conv2d_fwd_bias_activation_add_avx2_instance
...
...
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_bias_activation_add_avx2_nhwc_yxck_nhwk_instance.cpp
View file @
9cefc261
#include <stdlib.h>
#include <utility>
#include "config.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "device_convnd_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk.hpp"
...
...
@@ -40,21 +41,25 @@ static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver
// clang-format off
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
GemmKLoopOverC
,
LoopOver_MNK
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
,
bias_along_m
>
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
GemmKLoopOverC
,
LoopOver_MNK
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
,
bias_along_m
>
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
DefaultGemmKLoop
,
LoopOver_MNK
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
,
bias_along_m
>
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
GemmKLoopOverC
,
LoopOver_MNK
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
false
,
false
,
c_local_buf
,
bias_along_m
>
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
DefaultGemmKLoop
,
LoopOver_MNK
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
true
,
false
,
c_local_buf
,
bias_along_m
>
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
GemmKLoopOverC
,
2
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
,
bias_along_m
>
({
m_per_block
,
n_per_block
,
k_per_block
,
LoopOver_MNK
})
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
GemmKLoopOverC
,
2
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
,
bias_along_m
>
({
m_per_block
,
n_per_block
,
k_per_block
,
LoopOver_MNK
})
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
DefaultGemmKLoop
,
2
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
,
bias_along_m
>
({
m_per_block
,
n_per_block
,
k_per_block
,
LoopOver_MNK
})
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
GemmKLoopOverC
,
2
,
m_per_thread
,
n_per_thread
,
false
,
false
,
c_local_buf
,
bias_along_m
>
({
m_per_block
,
n_per_block
,
k_per_block
,
LoopOver_MNK
})
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
DefaultGemmKLoop
,
2
,
m_per_thread
,
n_per_thread
,
true
,
false
,
c_local_buf
,
bias_along_m
>
({
m_per_block
,
n_per_block
,
k_per_block
,
LoopOver_MNK
})
,
\
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
GemmKLoopOverC
,
LoopOver_MKN
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
,
bias_along_m
>
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
GemmKLoopOverC
,
LoopOver_MKN
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
,
bias_along_m
>
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
DefaultGemmKLoop
,
LoopOver_MKN
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
,
bias_along_m
>
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
GemmKLoopOverC
,
LoopOver_MKN
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
false
,
false
,
c_local_buf
,
bias_along_m
>
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
DefaultGemmKLoop
,
LoopOver_MKN
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
true
,
false
,
c_local_buf
,
bias_along_m
>
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
GemmKLoopOverC
,
2
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
,
bias_along_m
>
({
m_per_block
,
n_per_block
,
k_per_block
,
LoopOver_MKN
}),
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
GemmKLoopOverC
,
2
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
,
bias_along_m
>
({
m_per_block
,
n_per_block
,
k_per_block
,
LoopOver_MKN
}),
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
DefaultGemmKLoop
,
2
,
m_per_thread
,
n_per_thread
,
true
,
true
,
c_local_buf
,
bias_along_m
>
({
m_per_block
,
n_per_block
,
k_per_block
,
LoopOver_MKN
}),
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
GemmKLoopOverC
,
2
,
m_per_thread
,
n_per_thread
,
false
,
false
,
c_local_buf
,
bias_along_m
>
({
m_per_block
,
n_per_block
,
k_per_block
,
LoopOver_MKN
}),
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_Y_X_C_K_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
DefaultGemmKLoop
,
2
,
m_per_thread
,
n_per_thread
,
true
,
false
,
c_local_buf
,
bias_along_m
>
({
m_per_block
,
n_per_block
,
k_per_block
,
LoopOver_MKN
})
// clang-format on
using
device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_instances
=
std
::
tuple
<
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
false
,
false
),
...
...
@@ -65,11 +70,17 @@ using device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_instances =
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
false
,
false
)
>
;
// clang-format on
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
false
,
false
)
// clang-format on
));
}
// use this in single thread, but gemm_n is not multiple of 8
using
device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_local_c_instances
=
std
::
tuple
<
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
false
),
...
...
@@ -80,12 +91,17 @@ using device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_local_c_inst
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
>
;
// clang-format on
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
// clang-format on
));
}
// use this in multi thread environment (need local C buffer to avoid cache coherence, although some
// time no local c is better...)
using
device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_mt_instances
=
std
::
tuple
<
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_mt
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
std
::
make_tuple
(
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
24
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
32
,
24
,
256
,
4
,
24
,
false
,
false
),
...
...
@@ -112,29 +128,9 @@ using device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_mt_instances
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
>
;
// clang-format on
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_instances
{});
}
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_local_c
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_local_c_instances
{});
}
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_mt
(
std
::
vector
<
DeviceConvFwdBiasActivationAddPtr
<
PT
,
PT
,
AddReluAdd
>>&
instances
)
{
ck
::
tensor_operation
::
device
::
add_device_operation_instances
(
instances
,
device_conv2d_fwd_bias_activation_add_avx2_nhwc_yxck_nhwk_f32_mt_instances
{});
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_YXCK_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
// clang-format on
));
}
}
// namespace device_conv2d_fwd_bias_activation_add_avx2_instance
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment