Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
71254ddd
Commit
71254ddd
authored
Jun 05, 2022
by
carlushuang
Browse files
optimize multi-thread case by support not using LocalA/LocalB
parent
dc536427
Changes
12
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
3241 additions
and
2949 deletions
+3241
-2949
include/ck/tensor_operation/cpu/block/blockwise_gemm_avx2.hpp
...ude/ck/tensor_operation/cpu/block/blockwise_gemm_avx2.hpp
+15
-3
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp
...tion/cpu/device/device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp
+41
-9
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxck8_nhwk.hpp
...on/cpu/device/device_convnd_fwd_avx2_nhwc_kyxck8_nhwk.hpp
+931
-899
include/ck/tensor_operation/cpu/device/device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk.hpp
...ce_convnd_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk.hpp
+1016
-984
include/ck/tensor_operation/cpu/device/device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk.hpp
..._convnd_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk.hpp
+41
-9
include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp
include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp
+97
-55
include/ck/tensor_operation/cpu/grid/gridwise_gemm_bias_activation_add_avx2.hpp
...ation/cpu/grid/gridwise_gemm_bias_activation_add_avx2.hpp
+698
-655
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp
.../threadwise_tensor_slice_transfer_avx2_specialization.hpp
+9
-5
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp
...2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp
+94
-77
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_instance.cpp
..._fwd/device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_instance.cpp
+102
-77
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_bias_activation_add_avx2_nhwc_kyxc_nhwk_instance.cpp
...nv2d_bias_activation_add_avx2_nhwc_kyxc_nhwk_instance.cpp
+142
-133
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_bias_activation_add_avx2_nhwc_kyxck8_nhwk_instance.cpp
...2d_bias_activation_add_avx2_nhwc_kyxck8_nhwk_instance.cpp
+55
-43
No files found.
include/ck/tensor_operation/cpu/block/blockwise_gemm_avx2.hpp
View file @
71254ddd
...
...
@@ -41,6 +41,10 @@ struct BlockwiseGemmAvx2_MxN
using
IndexB
=
MultiIndex
<
nDimB
>
;
using
IndexC
=
MultiIndex
<
nDimC
>
;
using
ASliceLengths
=
MultiIndex
<
nDimA
>
;
using
BSliceLengths
=
MultiIndex
<
nDimB
>
;
using
CSliceLengths
=
MultiIndex
<
nDimC
>
;
using
ACoord
=
decltype
(
make_tensor_coordinate
(
ABlockDesc
{},
IndexA
{}));
using
BCoord
=
decltype
(
make_tensor_coordinate
(
BBlockDesc
{},
IndexB
{}));
using
CCoord
=
decltype
(
make_tensor_coordinate
(
CDesc
{},
IndexC
{}));
...
...
@@ -89,6 +93,7 @@ struct BlockwiseGemmAvx2_MxN
return
c_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}];
}
#if 0
static ck::index_t GetMPerBlock(const ABlockDesc& a_block_desc)
{
if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
...
...
@@ -134,6 +139,7 @@ struct BlockwiseGemmAvx2_MxN
b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
}
}
#endif
static
ck
::
index_t
GetABlockStartOffset
(
const
ABlockDesc
&
a_block_desc
,
const
index_t
i_m
,
const
index_t
)
...
...
@@ -175,14 +181,17 @@ struct BlockwiseGemmAvx2_MxN
static
void
Run
(
const
ABlockDesc
&
a_block_desc
,
const
ABlockBuffer
&
a_block_buf
,
const
IndexA
&
/* a_origin */
,
const
ASliceLengths
&
a_slice_length
,
const
BBlockDesc
&
b_block_desc
,
const
BBlockBuffer
&
b_block_buf
,
const
IndexB
&
/* b_origin */
,
const
BSliceLengths
&
b_slice_length
,
const
CDesc
&
c_desc
,
CBuffer
&
c_buf
,
const
IndexC
&
/* c_origin */
,
const
CSliceLengths
&
c_slice_length
,
bool
is_accumulate_c
=
true
)
{
...
...
@@ -192,9 +201,9 @@ struct BlockwiseGemmAvx2_MxN
// printf("lda:%d, ldb:%d, ldc:%d\n", lda, ldb, ldc);
const
auto
k_per_block
=
GetKPerBlock
(
a_block_desc
)
;
const
auto
m_per_block
=
GetMPerBlock
(
a_block_desc
)
;
const
auto
n_per_block
=
GetNPerBlock
(
b_block_desc
)
;
const
auto
k_per_block
=
a_slice_length
[
Number
<
1
>
{}]
;
const
auto
m_per_block
=
c_slice_length
[
Number
<
0
>
{}]
;
const
auto
n_per_block
=
c_slice_length
[
Number
<
1
>
{}]
;
const
auto
m_per_thread
=
ThreadwiseGemm_Dispatch
::
ThreadMaxMr
;
const
auto
n_per_thread
=
ThreadwiseGemm_Dispatch
::
ThreadMaxNr
;
...
...
@@ -206,6 +215,9 @@ struct BlockwiseGemmAvx2_MxN
param
.
alpha
=
1.0
f
;
// TODO
param
.
accmulate_c
=
is_accumulate_c
?
1
:
0
;
// printf("xxx lda:%u, ldb:%u, ldc:%u, mpb:%u, npb:%u, kpb:%u\n", lda, ldb, ldc,
// m_per_block, n_per_block, k_per_block);
if
constexpr
(
std
::
is_same
<
ThreadMNAccessOrder
,
ck
::
Sequence
<
0
,
1
>>::
value
)
{
for
(
ck
::
index_t
i_m
=
0
;
i_m
<
m_per_block
;
i_m
+=
m_per_thread
)
...
...
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp
View file @
71254ddd
...
...
@@ -107,22 +107,43 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
using
ThreadwiseGemm_Dispatch
=
decltype
(
GetThreadwiseGemm_Dispatch
());
static
constexpr
auto
GetInputBlockDescriptor
()
{
if
constexpr
(
UseALocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
KPerBlock
));
}
else
{
return
AGridDesc
{};
}
}
static
constexpr
auto
GetWeightBlockDescriptor
()
{
if
constexpr
(
UseBLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
math
::
integer_divide_ceil
(
NPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
KPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
));
}
else
{
return
BGridDesc
{};
}
}
static
constexpr
auto
GetOutputBlockDescriptor
()
{
if
constexpr
(
UseCLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
NPerBlock
));
}
else
{
return
CGridDesc
{};
}
}
static
auto
GetWeightTensorDescriptor
(
ck
::
index_t
gemm_k
,
ck
::
index_t
gemm_n
)
{
...
...
@@ -564,7 +585,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
AGridDesc
,
decltype
(
GetInputBlockDescriptor
()),
InElementwiseOperation
,
false
,
!
UseALocalBuffer
,
ConvForwardSpecialization
,
GemmKSpecialization
>
;
...
...
@@ -575,7 +596,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
BGridDesc
,
decltype
(
GetWeightBlockDescriptor
()),
WeiElementwiseOperation
,
false
,
!
UseBLocalBuffer
,
ConvForwardSpecialization
,
GemmKSpecialization
>
;
...
...
@@ -786,12 +807,23 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
}
if
constexpr
(
GemmKSpecialization
==
ConvolutionForwardGemmKSpecialization_t
::
NHWC_GemmKLoopOverC
)
ConvolutionForwardGemmKSpecialization_t
::
NHWC_GemmKLoopOverC
&&
ConvForwardSpecialization
!=
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
{
if
(
!
(
arg
.
Conv_C_
%
KPerBlock
==
0
))
return
false
;
}
if
constexpr
((
!
UseALocalBuffer
||
!
UseBLocalBuffer
)
&&
ConvForwardSpecialization
!=
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
{
// TODO: We can support this in the future, as long as figure out how to express tensor
// transform
return
false
;
}
// Gridwise GEMM size
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
}
...
...
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxck8_nhwk.hpp
View file @
71254ddd
...
...
@@ -106,24 +106,6 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
using
ThreadwiseGemm_Dispatch
=
decltype
(
GetThreadwiseGemm_Dispatch
());
static
constexpr
auto
GetInputBlockDescriptor
()
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
KPerBlock
));
}
static
constexpr
auto
GetWeightBlockDescriptor
()
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
math
::
integer_divide_ceil
(
NPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
KPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
));
}
static
constexpr
auto
GetOutputBlockDescriptor
()
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
NPerBlock
));
}
static
auto
GetWeightTensorDescriptor
(
ck
::
index_t
gemm_k
,
ck
::
index_t
gemm_n
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
gemm_n
/
8
,
gemm_k
,
8
));
...
...
@@ -532,6 +514,45 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
using
BGridDesc
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I1
])
>
;
using
CGridDesc
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I2
])
>
;
static
constexpr
auto
GetInputBlockDescriptor
()
{
if
constexpr
(
UseALocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
KPerBlock
));
}
else
{
return
AGridDesc
{};
}
}
static
constexpr
auto
GetWeightBlockDescriptor
()
{
if
constexpr
(
UseBLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
math
::
integer_divide_ceil
(
NPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
KPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
));
}
else
{
return
BGridDesc
{};
}
}
static
constexpr
auto
GetOutputBlockDescriptor
()
{
if
constexpr
(
UseCLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
NPerBlock
));
}
else
{
return
CGridDesc
{};
}
}
// static constexpr bool UseCLocalBuffer = false;
using
AThreadwiseCopy
=
...
...
@@ -541,7 +562,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
AGridDesc
,
decltype
(
GetInputBlockDescriptor
()),
InElementwiseOperation
,
false
,
!
UseALocalBuffer
,
ConvForwardSpecialization
,
GemmKSpecialization
>
;
...
...
@@ -552,7 +573,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
BGridDesc
,
decltype
(
GetWeightBlockDescriptor
()),
WeiElementwiseOperation
,
false
,
!
UseBLocalBuffer
,
ConvForwardSpecialization
,
GemmKSpecialization
>
;
...
...
@@ -763,7 +784,9 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
}
if
constexpr
(
GemmKSpecialization
==
ConvolutionForwardGemmKSpecialization_t
::
NHWC_GemmKLoopOverC
)
ConvolutionForwardGemmKSpecialization_t
::
NHWC_GemmKLoopOverC
&&
ConvForwardSpecialization
!=
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
{
if
(
!
(
arg
.
Conv_C_
%
KPerBlock
==
0
))
return
false
;
...
...
@@ -772,6 +795,15 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
if
(
!
(
arg
.
Conv_K_
%
8
==
0
))
return
false
;
if
constexpr
(
!
UseALocalBuffer
&&
ConvForwardSpecialization
!=
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
{
// TODO: We can support this in the future, as long as figure out how to express tensor
// transform
return
false
;
}
// Gridwise GEMM size
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
}
...
...
include/ck/tensor_operation/cpu/device/device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk.hpp
View file @
71254ddd
...
...
@@ -115,22 +115,43 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
using
ThreadwiseGemm_Dispatch
=
decltype
(
GetThreadwiseGemm_Dispatch
());
static
constexpr
auto
GetInputBlockDescriptor
()
{
if
constexpr
(
UseALocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
KPerBlock
));
}
else
{
return
AGridDesc
{};
}
}
static
constexpr
auto
GetWeightBlockDescriptor
()
{
if
constexpr
(
UseBLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
math
::
integer_divide_ceil
(
NPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
KPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
));
}
else
{
return
BGridDesc
{};
}
}
static
constexpr
auto
GetOutputBlockDescriptor
()
{
if
constexpr
(
UseCLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
NPerBlock
));
}
else
{
return
CGridDesc
{};
}
}
static
auto
GetWeightTensorDescriptor
(
ck
::
index_t
gemm_k
,
ck
::
index_t
gemm_n
)
{
...
...
@@ -586,7 +607,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
AGridDesc
,
decltype
(
GetInputBlockDescriptor
()),
InElementwiseOperation
,
false
,
!
UseALocalBuffer
,
ConvForwardSpecialization
,
GemmKSpecialization
>
;
...
...
@@ -597,7 +618,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
BGridDesc
,
decltype
(
GetWeightBlockDescriptor
()),
WeiElementwiseOperation
,
false
,
!
UseBLocalBuffer
,
ConvForwardSpecialization
,
GemmKSpecialization
>
;
...
...
@@ -843,12 +864,23 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Outpu
}
if
constexpr
(
GemmKSpecialization
==
ConvolutionForwardGemmKSpecialization_t
::
NHWC_GemmKLoopOverC
)
ConvolutionForwardGemmKSpecialization_t
::
NHWC_GemmKLoopOverC
&&
ConvForwardSpecialization
!=
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
{
if
(
!
(
arg
.
Conv_C_
%
KPerBlock
==
0
))
return
false
;
}
if
constexpr
((
!
UseALocalBuffer
||
!
UseBLocalBuffer
)
&&
ConvForwardSpecialization
!=
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
{
// TODO: We can support this in the future, as long as figure out how to express tensor
// transform
return
false
;
}
// Gridwise GEMM size
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
}
...
...
include/ck/tensor_operation/cpu/device/device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk.hpp
View file @
71254ddd
...
...
@@ -115,22 +115,43 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
using
ThreadwiseGemm_Dispatch
=
decltype
(
GetThreadwiseGemm_Dispatch
());
static
constexpr
auto
GetInputBlockDescriptor
()
{
if
constexpr
(
UseALocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
KPerBlock
));
}
else
{
return
AGridDesc
{};
}
}
static
constexpr
auto
GetWeightBlockDescriptor
()
{
if
constexpr
(
UseBLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
math
::
integer_divide_ceil
(
NPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
KPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
));
}
else
{
return
BGridDesc
{};
}
}
static
constexpr
auto
GetOutputBlockDescriptor
()
{
if
constexpr
(
UseCLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
NPerBlock
));
}
else
{
return
CGridDesc
{};
}
}
static
auto
GetWeightTensorDescriptor
(
ck
::
index_t
gemm_k
,
ck
::
index_t
gemm_n
)
{
...
...
@@ -563,7 +584,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
AGridDesc
,
decltype
(
GetInputBlockDescriptor
()),
InElementwiseOperation
,
false
,
!
UseALocalBuffer
,
ConvForwardSpecialization
,
GemmKSpecialization
>
;
...
...
@@ -574,7 +595,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
BGridDesc
,
decltype
(
GetWeightBlockDescriptor
()),
WeiElementwiseOperation
,
false
,
!
UseBLocalBuffer
,
ConvForwardSpecialization
,
GemmKSpecialization
>
;
...
...
@@ -820,7 +841,9 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
}
if
constexpr
(
GemmKSpecialization
==
ConvolutionForwardGemmKSpecialization_t
::
NHWC_GemmKLoopOverC
)
ConvolutionForwardGemmKSpecialization_t
::
NHWC_GemmKLoopOverC
&&
ConvForwardSpecialization
!=
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
{
if
(
!
(
arg
.
Conv_C_
%
KPerBlock
==
0
))
return
false
;
...
...
@@ -829,6 +852,15 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
if
(
!
(
arg
.
Conv_K_
%
8
==
0
))
return
false
;
if
constexpr
(
!
UseALocalBuffer
&&
ConvForwardSpecialization
!=
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
{
// TODO: We can support this in the future, as long as figure out how to express tensor
// transform
return
false
;
}
// Gridwise GEMM size
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
}
...
...
include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp
View file @
71254ddd
...
...
@@ -80,7 +80,11 @@ struct GridwiseGemmAvx2_MxN
// static constexpr auto Avx2RegisterVector = 8; // 8 floats
static
constexpr
index_t
MemAlignmentByte
=
32
;
// 256bit
static
auto
GetABlockDescriptor
(
const
ck
::
index_t
m_per_blk
,
const
ck
::
index_t
k_per_blk
)
static
auto
GetABlockDescriptor
(
const
ck
::
index_t
m_per_blk
,
const
ck
::
index_t
k_per_blk
,
const
AGridDesc
&
a_grid_desc
)
{
if
constexpr
(
UseALocalBuffer
)
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixALayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
...
...
@@ -100,8 +104,17 @@ struct GridwiseGemmAvx2_MxN
return
a_block_desc_k_m
;
}
}
else
{
return
a_grid_desc
;
}
}
static
auto
GetBBlockDescriptor
(
const
ck
::
index_t
k_per_blk
,
const
ck
::
index_t
n_per_blk
)
static
auto
GetBBlockDescriptor
(
const
ck
::
index_t
k_per_blk
,
const
ck
::
index_t
n_per_blk
,
const
BGridDesc
&
b_grid_desc
)
{
if
constexpr
(
UseBLocalBuffer
)
{
// n_per_blk should be 8x
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixBLayout
,
...
...
@@ -115,13 +128,19 @@ struct GridwiseGemmAvx2_MxN
else
{
// B : N/8, K, N8
auto
b_block_desc_n0_k_n1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
math
::
integer_divide_ceil
(
n_per_blk
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
auto
b_block_desc_n0_k_n1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
math
::
integer_divide_ceil
(
n_per_blk
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
k_per_blk
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
));
return
b_block_desc_n0_k_n1
;
}
}
else
{
return
b_grid_desc
;
}
}
static
auto
GetCBlockDescriptor
(
const
ck
::
index_t
m_per_blk
,
const
ck
::
index_t
n_per_blk
,
...
...
@@ -262,10 +281,10 @@ struct GridwiseGemmAvx2_MxN
constexpr
auto
b_block_copy_dim
=
BGridDesc
::
GetNumOfDimension
();
auto
a_grid_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
reinterpre
t_cast
<
const
FloatA
*>
(
p_a_grid
),
a_grid_desc
.
GetElementSpaceSize
());
cons
t_cast
<
FloatA
*>
(
p_a_grid
),
a_grid_desc
.
GetElementSpaceSize
());
auto
b_grid_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
reinterpre
t_cast
<
const
FloatB
*>
(
p_b_grid
),
b_grid_desc
.
GetElementSpaceSize
());
cons
t_cast
<
FloatB
*>
(
p_b_grid
),
b_grid_desc
.
GetElementSpaceSize
());
auto
c_grid_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
reinterpret_cast
<
FloatC
*>
(
p_c_grid
),
c_grid_desc
.
GetElementSpaceSize
());
...
...
@@ -274,8 +293,8 @@ struct GridwiseGemmAvx2_MxN
FloatA
,
// FloatA,
FloatB
,
// FloatB,
FloatC
,
// FloatC,
decltype
(
GetABlockDescriptor
(
m_per_block
,
k_per_block
)),
// ABlockDesc,
decltype
(
GetBBlockDescriptor
(
k_per_block
,
n_per_block
)),
// BBlockDesc,
decltype
(
GetABlockDescriptor
(
m_per_block
,
k_per_block
,
a_grid_desc
)),
// ABlockDesc,
decltype
(
GetBBlockDescriptor
(
k_per_block
,
n_per_block
,
b_grid_desc
)),
// BBlockDesc,
decltype
(
GetCBlockDescriptor
(
m_per_block
,
n_per_block
,
c_grid_desc
)),
// CBlockDesc,
KPerBlock
,
// KPerBlock,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
...
...
@@ -320,14 +339,14 @@ struct GridwiseGemmAvx2_MxN
auto
a_threadwise_copy
=
AThreadwiseCopy
(
a_grid_desc
,
ck
::
make_zero_multi_index
<
a_block_copy_dim
>
(),
GetABlockDescriptor
(
m_per_block
,
k_per_block
),
GetABlockDescriptor
(
m_per_block
,
k_per_block
,
a_grid_desc
),
ck
::
make_zero_multi_index
<
a_block_copy_dim
>
(),
AElementwiseOperation
{});
auto
b_threadwise_copy
=
BThreadwiseCopy
(
b_grid_desc
,
ck
::
make_zero_multi_index
<
b_block_copy_dim
>
(),
GetBBlockDescriptor
(
k_per_block
,
n_per_block
),
GetBBlockDescriptor
(
k_per_block
,
n_per_block
,
b_grid_desc
),
ck
::
make_zero_multi_index
<
b_block_copy_dim
>
(),
BElementwiseOperation
{});
...
...
@@ -338,21 +357,27 @@ struct GridwiseGemmAvx2_MxN
ck
::
make_zero_multi_index
<
2
>
(),
CElementwiseOperation
{});
DeviceAlignedMemCPU
a_block_mem
(
m_per_block
*
k_per_block
*
sizeof
(
FloatA
),
DeviceAlignedMemCPU
a_block_mem
(
UseALocalBuffer
?
m_per_block
*
k_per_block
*
sizeof
(
FloatA
)
:
0
,
MemAlignmentByte
);
DeviceAlignedMemCPU
b_block_mem
(
k_per_block
*
n_per_block
*
sizeof
(
FloatB
),
DeviceAlignedMemCPU
b_block_mem
(
UseBLocalBuffer
?
k_per_block
*
n_per_block
*
sizeof
(
FloatB
)
:
0
,
MemAlignmentByte
);
DeviceAlignedMemCPU
c_block_mem
(
UseCLocalBuffer
?
(
m_per_block
*
n_per_block
*
sizeof
(
FloatC
))
:
0
,
MemAlignmentByte
);
auto
a_block_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
reinterpret_cast
<
FloatA
*>
(
a_block_mem
.
mpDeviceBuf
),
a_block_mem
.
mMemSize
/
sizeof
(
FloatA
));
UseALocalBuffer
?
reinterpret_cast
<
FloatA
*>
(
a_block_mem
.
mpDeviceBuf
)
:
const_cast
<
FloatA
*>
(
p_a_grid
),
UseALocalBuffer
?
a_block_mem
.
mMemSize
/
sizeof
(
FloatA
)
:
a_grid_desc
.
GetElementSpaceSize
());
auto
b_block_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
reinterpret_cast
<
FloatB
*>
(
b_block_mem
.
mpDeviceBuf
),
b_block_mem
.
mMemSize
/
sizeof
(
FloatB
));
UseBLocalBuffer
?
reinterpret_cast
<
FloatB
*>
(
b_block_mem
.
mpDeviceBuf
)
:
const_cast
<
FloatB
*>
(
p_b_grid
),
UseBLocalBuffer
?
b_block_mem
.
mMemSize
/
sizeof
(
FloatB
)
:
b_grid_desc
.
GetElementSpaceSize
());
auto
c_block_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
UseCLocalBuffer
?
reinterpret_cast
<
FloatC
*>
(
c_block_mem
.
mpDeviceBuf
)
...
...
@@ -395,8 +420,8 @@ struct GridwiseGemmAvx2_MxN
{
ck
::
index_t
kc_size
=
ck
::
math
::
min
(
GemmK
-
i_kc
,
k_per_block
);
auto
a_block_desc
=
GetABlockDescriptor
(
mc_size
,
kc_size
);
auto
b_block_desc
=
GetBBlockDescriptor
(
kc_size
,
nc_size
);
auto
a_block_desc
=
GetABlockDescriptor
(
mc_size
,
kc_size
,
a_grid_desc
);
auto
b_block_desc
=
GetBBlockDescriptor
(
kc_size
,
nc_size
,
b_grid_desc
);
a_threadwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
...
...
@@ -412,12 +437,17 @@ struct GridwiseGemmAvx2_MxN
blockwise_gemm
.
Run
(
a_block_desc
,
a_block_buf
,
make_zero_multi_index
<
a_block_copy_dim
>
(),
GetASliceLength
(
mc_size
,
kc_size
),
b_block_desc
,
b_block_buf
,
make_zero_multi_index
<
b_block_copy_dim
>
(),
GetBSliceLength
(
kc_size
,
nc_size
),
c_block_desc
,
c_block_buf
,
make_zero_multi_index
<
2
>
(),
GetCSliceLength
(
mc_size
,
nc_size
),
i_kc
!=
0
);
if
((
i_kc
+
k_per_block
)
<
GemmK
)
...
...
@@ -450,14 +480,14 @@ struct GridwiseGemmAvx2_MxN
auto
a_threadwise_copy
=
AThreadwiseCopy
(
a_grid_desc
,
ck
::
make_zero_multi_index
<
a_block_copy_dim
>
(),
GetABlockDescriptor
(
m_per_block
,
k_per_block
),
GetABlockDescriptor
(
m_per_block
,
k_per_block
,
a_grid_desc
),
ck
::
make_zero_multi_index
<
a_block_copy_dim
>
(),
AElementwiseOperation
{});
auto
b_threadwise_copy
=
BThreadwiseCopy
(
b_grid_desc
,
ck
::
make_zero_multi_index
<
b_block_copy_dim
>
(),
GetBBlockDescriptor
(
k_per_block
,
n_per_block
),
GetBBlockDescriptor
(
k_per_block
,
n_per_block
,
b_grid_desc
),
ck
::
make_zero_multi_index
<
b_block_copy_dim
>
(),
BElementwiseOperation
{});
...
...
@@ -468,21 +498,27 @@ struct GridwiseGemmAvx2_MxN
ck
::
make_zero_multi_index
<
2
>
(),
CElementwiseOperation
{});
DeviceAlignedMemCPU
a_block_mem
(
m_per_block
*
k_per_block
*
sizeof
(
FloatA
),
DeviceAlignedMemCPU
a_block_mem
(
UseALocalBuffer
?
m_per_block
*
k_per_block
*
sizeof
(
FloatA
)
:
0
,
MemAlignmentByte
);
DeviceAlignedMemCPU
b_block_mem
(
k_per_block
*
n_per_block
*
sizeof
(
FloatB
),
DeviceAlignedMemCPU
b_block_mem
(
UseBLocalBuffer
?
k_per_block
*
n_per_block
*
sizeof
(
FloatB
)
:
0
,
MemAlignmentByte
);
DeviceAlignedMemCPU
c_block_mem
(
UseCLocalBuffer
?
(
m_per_block
*
n_per_block
*
sizeof
(
FloatC
))
:
0
,
MemAlignmentByte
);
auto
a_block_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
reinterpret_cast
<
FloatA
*>
(
a_block_mem
.
mpDeviceBuf
),
a_block_mem
.
mMemSize
/
sizeof
(
FloatA
));
UseALocalBuffer
?
reinterpret_cast
<
FloatA
*>
(
a_block_mem
.
mpDeviceBuf
)
:
const_cast
<
FloatA
*>
(
p_a_grid
),
UseALocalBuffer
?
a_block_mem
.
mMemSize
/
sizeof
(
FloatA
)
:
a_grid_desc
.
GetElementSpaceSize
());
auto
b_block_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
reinterpret_cast
<
FloatB
*>
(
b_block_mem
.
mpDeviceBuf
),
b_block_mem
.
mMemSize
/
sizeof
(
FloatB
));
UseBLocalBuffer
?
reinterpret_cast
<
FloatB
*>
(
b_block_mem
.
mpDeviceBuf
)
:
const_cast
<
FloatB
*>
(
p_b_grid
),
UseBLocalBuffer
?
b_block_mem
.
mMemSize
/
sizeof
(
FloatB
)
:
b_grid_desc
.
GetElementSpaceSize
());
auto
c_block_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
UseCLocalBuffer
?
reinterpret_cast
<
FloatC
*>
(
c_block_mem
.
mpDeviceBuf
)
...
...
@@ -503,7 +539,7 @@ struct GridwiseGemmAvx2_MxN
{
ck
::
index_t
kc_size
=
ck
::
math
::
min
(
GemmK
-
i_kc
,
k_per_block
);
auto
a_block_desc
=
GetABlockDescriptor
(
mc_size
,
kc_size
);
auto
a_block_desc
=
GetABlockDescriptor
(
mc_size
,
kc_size
,
a_grid_desc
);
a_threadwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
a_block_desc
,
...
...
@@ -519,7 +555,7 @@ struct GridwiseGemmAvx2_MxN
ck
::
math
::
min
(
GemmN
-
i_nc
,
n_per_block
);
// TODO: nc need be 8x
nc_size
=
math
::
integer_least_multiple
(
nc_size
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
);
auto
b_block_desc
=
GetBBlockDescriptor
(
kc_size
,
nc_size
);
auto
b_block_desc
=
GetBBlockDescriptor
(
kc_size
,
nc_size
,
b_grid_desc
);
b_threadwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
...
...
@@ -543,12 +579,18 @@ struct GridwiseGemmAvx2_MxN
blockwise_gemm
.
Run
(
a_block_desc
,
a_block_buf
,
make_zero_multi_index
<
a_block_copy_dim
>
(),
GetASliceLength
(
mc_size
,
kc_size
),
b_block_desc
,
b_block_buf
,
make_zero_multi_index
<
b_block_copy_dim
>
(),
GetBSliceLength
(
kc_size
,
nc_size
),
c_block_desc
,
c_block_buf
,
make_zero_multi_index
<
2
>
(),
GetCSliceLength
(
mc_size
,
nc_size
),
i_kc
!=
0
);
if
((
i_nc
+
n_per_block
)
<
GemmN
)
...
...
include/ck/tensor_operation/cpu/grid/gridwise_gemm_bias_activation_add_avx2.hpp
View file @
71254ddd
...
...
@@ -96,7 +96,11 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
// static constexpr auto Avx2RegisterVector = 8; // 8 floats
static
constexpr
index_t
MemAlignmentByte
=
32
;
// 256bit
static
auto
GetABlockDescriptor
(
const
ck
::
index_t
m_per_blk
,
const
ck
::
index_t
k_per_blk
)
static
auto
GetABlockDescriptor
(
const
ck
::
index_t
m_per_blk
,
const
ck
::
index_t
k_per_blk
,
const
AGridDesc
&
a_grid_desc
)
{
if
constexpr
(
UseALocalBuffer
)
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixALayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
...
...
@@ -116,8 +120,17 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
return
a_block_desc_k_m
;
}
}
else
{
return
a_grid_desc
;
}
}
static
auto
GetBBlockDescriptor
(
const
ck
::
index_t
k_per_blk
,
const
ck
::
index_t
n_per_blk
)
static
auto
GetBBlockDescriptor
(
const
ck
::
index_t
k_per_blk
,
const
ck
::
index_t
n_per_blk
,
const
BGridDesc
&
b_grid_desc
)
{
if
constexpr
(
UseBLocalBuffer
)
{
// n_per_blk should be 8x
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixBLayout
,
...
...
@@ -131,13 +144,19 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
else
{
// B : N/8, K, N8
auto
b_block_desc_n0_k_n1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
math
::
integer_divide_ceil
(
n_per_blk
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
auto
b_block_desc_n0_k_n1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
math
::
integer_divide_ceil
(
n_per_blk
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
k_per_blk
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
));
return
b_block_desc_n0_k_n1
;
}
}
else
{
return
b_grid_desc
;
}
}
static
auto
GetCBlockDescriptor
(
const
ck
::
index_t
m_per_blk
,
const
ck
::
index_t
n_per_blk
,
...
...
@@ -282,10 +301,10 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
constexpr
auto
b_block_copy_dim
=
BGridDesc
::
GetNumOfDimension
();
auto
a_grid_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
reinterpre
t_cast
<
const
FloatA
*>
(
p_a_grid
),
a_grid_desc
.
GetElementSpaceSize
());
cons
t_cast
<
FloatA
*>
(
p_a_grid
),
a_grid_desc
.
GetElementSpaceSize
());
auto
b_grid_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
reinterpre
t_cast
<
const
FloatB
*>
(
p_b_grid
),
b_grid_desc
.
GetElementSpaceSize
());
cons
t_cast
<
FloatB
*>
(
p_b_grid
),
b_grid_desc
.
GetElementSpaceSize
());
auto
c_grid_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
reinterpret_cast
<
FloatC
*>
(
p_c_grid
),
c_grid_desc
.
GetElementSpaceSize
());
...
...
@@ -300,8 +319,8 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
FloatA
,
// FloatA,
FloatB
,
// FloatB,
FloatC
,
// FloatC,
decltype
(
GetABlockDescriptor
(
m_per_block
,
k_per_block
)),
// ABlockDesc,
decltype
(
GetBBlockDescriptor
(
k_per_block
,
n_per_block
)),
// BBlockDesc,
decltype
(
GetABlockDescriptor
(
m_per_block
,
k_per_block
,
a_grid_desc
)),
// ABlockDesc,
decltype
(
GetBBlockDescriptor
(
k_per_block
,
n_per_block
,
b_grid_desc
)),
// BBlockDesc,
decltype
(
GetCBlockDescriptor
(
m_per_block
,
n_per_block
,
c_grid_desc
)),
// CBlockDesc,
KPerBlock
,
// KPerBlock,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
...
...
@@ -346,14 +365,14 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
auto
a_threadwise_copy
=
AThreadwiseCopy
(
a_grid_desc
,
ck
::
make_zero_multi_index
<
a_block_copy_dim
>
(),
GetABlockDescriptor
(
m_per_block
,
k_per_block
),
GetABlockDescriptor
(
m_per_block
,
k_per_block
,
a_grid_desc
),
ck
::
make_zero_multi_index
<
a_block_copy_dim
>
(),
AElementwiseOperation
{});
auto
b_threadwise_copy
=
BThreadwiseCopy
(
b_grid_desc
,
ck
::
make_zero_multi_index
<
b_block_copy_dim
>
(),
GetBBlockDescriptor
(
k_per_block
,
n_per_block
),
GetBBlockDescriptor
(
k_per_block
,
n_per_block
,
b_grid_desc
),
ck
::
make_zero_multi_index
<
b_block_copy_dim
>
(),
BElementwiseOperation
{});
...
...
@@ -364,21 +383,27 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
ck
::
make_zero_multi_index
<
2
>
(),
CElementwiseOperation
{});
DeviceAlignedMemCPU
a_block_mem
(
m_per_block
*
k_per_block
*
sizeof
(
FloatA
),
DeviceAlignedMemCPU
a_block_mem
(
UseALocalBuffer
?
m_per_block
*
k_per_block
*
sizeof
(
FloatA
)
:
0
,
MemAlignmentByte
);
DeviceAlignedMemCPU
b_block_mem
(
k_per_block
*
n_per_block
*
sizeof
(
FloatB
),
DeviceAlignedMemCPU
b_block_mem
(
UseBLocalBuffer
?
k_per_block
*
n_per_block
*
sizeof
(
FloatB
)
:
0
,
MemAlignmentByte
);
DeviceAlignedMemCPU
c_block_mem
(
UseCLocalBuffer
?
(
m_per_block
*
n_per_block
*
sizeof
(
FloatC
))
:
0
,
MemAlignmentByte
);
auto
a_block_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
reinterpret_cast
<
FloatA
*>
(
a_block_mem
.
mpDeviceBuf
),
a_block_mem
.
mMemSize
/
sizeof
(
FloatA
));
UseALocalBuffer
?
reinterpret_cast
<
FloatA
*>
(
a_block_mem
.
mpDeviceBuf
)
:
const_cast
<
FloatA
*>
(
p_a_grid
),
UseALocalBuffer
?
a_block_mem
.
mMemSize
/
sizeof
(
FloatA
)
:
a_grid_desc
.
GetElementSpaceSize
());
auto
b_block_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
reinterpret_cast
<
FloatB
*>
(
b_block_mem
.
mpDeviceBuf
),
b_block_mem
.
mMemSize
/
sizeof
(
FloatB
));
UseBLocalBuffer
?
reinterpret_cast
<
FloatB
*>
(
b_block_mem
.
mpDeviceBuf
)
:
const_cast
<
FloatB
*>
(
p_b_grid
),
UseBLocalBuffer
?
b_block_mem
.
mMemSize
/
sizeof
(
FloatB
)
:
b_grid_desc
.
GetElementSpaceSize
());
auto
c_block_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
UseCLocalBuffer
?
reinterpret_cast
<
FloatC
*>
(
c_block_mem
.
mpDeviceBuf
)
...
...
@@ -428,8 +453,8 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
{
ck
::
index_t
kc_size
=
ck
::
math
::
min
(
GemmK
-
i_kc
,
k_per_block
);
auto
a_block_desc
=
GetABlockDescriptor
(
mc_size
,
kc_size
);
auto
b_block_desc
=
GetBBlockDescriptor
(
kc_size
,
nc_size
);
auto
a_block_desc
=
GetABlockDescriptor
(
mc_size
,
kc_size
,
a_grid_desc
);
auto
b_block_desc
=
GetBBlockDescriptor
(
kc_size
,
nc_size
,
b_grid_desc
);
a_threadwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
...
...
@@ -445,12 +470,18 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
blockwise_gemm
.
Run
(
a_block_desc
,
a_block_buf
,
make_zero_multi_index
<
a_block_copy_dim
>
(),
GetASliceLength
(
mc_size
,
kc_size
),
b_block_desc
,
b_block_buf
,
make_zero_multi_index
<
b_block_copy_dim
>
(),
GetBSliceLength
(
kc_size
,
nc_size
),
c_block_desc
,
c_block_buf
,
make_zero_multi_index
<
2
>
(),
GetCSliceLength
(
mc_size
,
nc_size
),
i_kc
!=
0
);
if
((
i_kc
+
k_per_block
)
<
GemmK
)
...
...
@@ -487,14 +518,14 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
auto
a_threadwise_copy
=
AThreadwiseCopy
(
a_grid_desc
,
ck
::
make_zero_multi_index
<
a_block_copy_dim
>
(),
GetABlockDescriptor
(
m_per_block
,
k_per_block
),
GetABlockDescriptor
(
m_per_block
,
k_per_block
,
a_grid_desc
),
ck
::
make_zero_multi_index
<
a_block_copy_dim
>
(),
AElementwiseOperation
{});
auto
b_threadwise_copy
=
BThreadwiseCopy
(
b_grid_desc
,
ck
::
make_zero_multi_index
<
b_block_copy_dim
>
(),
GetBBlockDescriptor
(
k_per_block
,
n_per_block
),
GetBBlockDescriptor
(
k_per_block
,
n_per_block
,
b_grid_desc
),
ck
::
make_zero_multi_index
<
b_block_copy_dim
>
(),
BElementwiseOperation
{});
...
...
@@ -505,21 +536,27 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
ck
::
make_zero_multi_index
<
2
>
(),
CElementwiseOperation
{});
DeviceAlignedMemCPU
a_block_mem
(
m_per_block
*
k_per_block
*
sizeof
(
FloatA
),
DeviceAlignedMemCPU
a_block_mem
(
UseALocalBuffer
?
m_per_block
*
k_per_block
*
sizeof
(
FloatA
)
:
0
,
MemAlignmentByte
);
DeviceAlignedMemCPU
b_block_mem
(
k_per_block
*
n_per_block
*
sizeof
(
FloatB
),
DeviceAlignedMemCPU
b_block_mem
(
UseBLocalBuffer
?
k_per_block
*
n_per_block
*
sizeof
(
FloatB
)
:
0
,
MemAlignmentByte
);
DeviceAlignedMemCPU
c_block_mem
(
UseCLocalBuffer
?
(
m_per_block
*
n_per_block
*
sizeof
(
FloatC
))
:
0
,
MemAlignmentByte
);
auto
a_block_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
reinterpret_cast
<
FloatA
*>
(
a_block_mem
.
mpDeviceBuf
),
a_block_mem
.
mMemSize
/
sizeof
(
FloatA
));
UseALocalBuffer
?
reinterpret_cast
<
FloatA
*>
(
a_block_mem
.
mpDeviceBuf
)
:
const_cast
<
FloatA
*>
(
p_a_grid
),
UseALocalBuffer
?
a_block_mem
.
mMemSize
/
sizeof
(
FloatA
)
:
a_grid_desc
.
GetElementSpaceSize
());
auto
b_block_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
reinterpret_cast
<
FloatB
*>
(
b_block_mem
.
mpDeviceBuf
),
b_block_mem
.
mMemSize
/
sizeof
(
FloatB
));
UseBLocalBuffer
?
reinterpret_cast
<
FloatB
*>
(
b_block_mem
.
mpDeviceBuf
)
:
const_cast
<
FloatB
*>
(
p_b_grid
),
UseBLocalBuffer
?
b_block_mem
.
mMemSize
/
sizeof
(
FloatB
)
:
b_grid_desc
.
GetElementSpaceSize
());
auto
c_block_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
UseCLocalBuffer
?
reinterpret_cast
<
FloatC
*>
(
c_block_mem
.
mpDeviceBuf
)
...
...
@@ -540,7 +577,7 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
{
ck
::
index_t
kc_size
=
ck
::
math
::
min
(
GemmK
-
i_kc
,
k_per_block
);
auto
a_block_desc
=
GetABlockDescriptor
(
mc_size
,
kc_size
);
auto
a_block_desc
=
GetABlockDescriptor
(
mc_size
,
kc_size
,
a_grid_desc
);
a_threadwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
a_block_desc
,
...
...
@@ -556,7 +593,7 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
ck
::
math
::
min
(
GemmN
-
i_nc
,
n_per_block
);
// TODO: nc needs to be a multiple of 8
nc_size
=
math
::
integer_least_multiple
(
nc_size
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
);
auto
b_block_desc
=
GetBBlockDescriptor
(
kc_size
,
nc_size
);
auto
b_block_desc
=
GetBBlockDescriptor
(
kc_size
,
nc_size
,
b_grid_desc
);
b_threadwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
...
...
@@ -590,12 +627,18 @@ struct GridwiseGemmBiasActivationAddAvx2_MxN
blockwise_gemm
.
Run
(
a_block_desc
,
a_block_buf
,
make_zero_multi_index
<
a_block_copy_dim
>
(),
GetASliceLength
(
mc_size
,
kc_size
),
b_block_desc
,
b_block_buf
,
make_zero_multi_index
<
b_block_copy_dim
>
(),
GetBSliceLength
(
kc_size
,
nc_size
),
c_block_desc
,
c_block_buf
,
make_zero_multi_index
<
2
>
(),
GetCSliceLength
(
mc_size
,
nc_size
),
i_kc
!=
0
);
if
((
i_nc
+
n_per_block
)
<
GemmN
)
...
...
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp
View file @
71254ddd
...
...
@@ -519,7 +519,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
template
<
typename
SrcBuffer
,
typename
DstBuffer
,
typename
SliceLengths
>
void
RunRead
(
const
SrcDesc
&
src_desc
,
const
SrcBuffer
&
src_buf
,
SrcBuffer
&
src_buf
,
const
DstDesc
&
dst_desc
,
DstBuffer
&
dst_buf
,
const
SliceLengths
&
slice_length
)
...
...
@@ -917,14 +917,15 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXC
template
<
typename
SrcBuffer
,
typename
DstBuffer
,
typename
SliceLengths
>
void
RunRead
(
const
SrcDesc
&
,
const
SrcBuffer
&
src_buf
,
SrcBuffer
&
src_buf
,
const
DstDesc
&
dst_desc
,
DstBuffer
&
dst_buf
,
const
SliceLengths
&
slice_length
)
{
if
constexpr
(
BypassTransfer
)
{
// TODO: weight NHWC does not support this
// KYXC weight should not support this
dst_buf
.
p_data_
=
reinterpret_cast
<
float
*>
(
src_buf
.
p_data_
)
+
src_offset
;
}
else
{
...
...
@@ -1132,12 +1133,15 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8
template
<
typename
SrcBuffer
,
typename
DstBuffer
,
typename
SliceLengths
>
void
RunRead
(
const
SrcDesc
&
,
const
SrcBuffer
&
src_buf
,
SrcBuffer
&
src_buf
,
const
DstDesc
&
dst_desc
,
DstBuffer
&
dst_buf
,
const
SliceLengths
&
slice_length
)
{
if
constexpr
(
BypassTransfer
)
{}
if
constexpr
(
BypassTransfer
)
{
dst_buf
.
p_data_
=
reinterpret_cast
<
float
*>
(
src_buf
.
p_data_
)
+
src_offset
;
}
else
{
const
ck
::
index_t
n0_per_block
=
slice_length
[
Number
<
0
>
{}];
...
...
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp
View file @
71254ddd
...
...
@@ -47,121 +47,138 @@ static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver
static
constexpr
auto
LoopOver_MKN
=
ck
::
tensor_operation
::
cpu
::
device
::
LoopOver_MKN
;
// clang-format off
#define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf) \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>
// Expands to a comma-separated list of eight float
// DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<...> type names,
// suitable for use inside a type list such as std::tuple<...>.
//
// The first four entries use LoopOver_MNK and the last four use LoopOver_MKN.
// Within each group of four the entries are:
//   1. ConvFwdDefault / GemmKLoopOverC   with the two hard-coded flags (true,  true)
//   2. ConvFwd1x1S1P0 / GemmKLoopOverC   with the two hard-coded flags (true,  true)
//   3. ConvFwdDefault / DefaultGemmKLoop with the two hard-coded flags (true,  true)
//   4. ConvFwd1x1S1P0 / GemmKLoopOverC   with the two hard-coded flags (false, true)
//
// NOTE(review): the two hard-coded bools sit directly before c_local_buf and presumably
// select "use local A buffer" / "use local B buffer" on the device template; confirm
// against the template's parameter list.
// NOTE(review): there is no ConvFwd1x1S1P0 / DefaultGemmKLoop entry; presumably intentional,
// verify.
//
// Macro parameters are forwarded verbatim as template arguments: the three element-wise ops,
// the block sizes (m/n/k_per_block), the per-thread tile (m/n_per_thread) and c_local_buf.
// The expansion has no trailing comma, so consecutive invocations must be separated by ','.
#define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf) \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, true, c_local_buf>, \
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, true, c_local_buf>
// clang-format on
using
device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances
=
std
::
tuple
<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
256
,
128
,
64
,
6
,
16
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
256
,
128
,
64
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
256
,
128
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
128
,
256
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
512
,
240
,
128
,
4
,
24
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
512
,
256
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
false
)
>
;
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
768
,
320
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
896
,
352
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
1024
,
416
,
128
,
6
,
16
,
false
)
>
;
// clang-format on
// use these in the single-threaded case when gemm_n is not a multiple of 8
using
device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_instances
=
std
::
tuple
<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
256
,
128
,
64
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
256
,
128
,
64
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
256
,
128
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
128
,
256
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
512
,
240
,
128
,
4
,
24
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
512
,
256
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
true
)
>
;
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
768
,
320
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
896
,
352
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
1024
,
416
,
128
,
6
,
16
,
true
)
>
;
// clang-format on
// use these in a multi-threaded environment (a local C buffer is needed to avoid cache-coherence
// issues, although sometimes no local C buffer is better...)
using
device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_instances
=
std
::
tuple
<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
48
,
24
,
128
,
4
,
24
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
72
,
16
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
72
,
32
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
96
,
32
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
96
,
64
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
120
,
32
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
120
,
64
,
128
,
6
,
16
,
true
,
true
,
true
),
// DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
true
)
>
;
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
24
,
24
,
256
,
4
,
24
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
32
,
24
,
256
,
4
,
24
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
40
,
24
,
256
,
4
,
24
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
48
,
24
,
256
,
4
,
24
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
48
,
48
,
256
,
4
,
24
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
56
,
24
,
256
,
4
,
24
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
72
,
16
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
72
,
16
,
256
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
72
,
32
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
72
,
32
,
256
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
96
,
32
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
96
,
64
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
120
,
32
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
120
,
64
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
256
,
128
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
128
,
256
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
512
,
240
,
128
,
4
,
24
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
512
,
256
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
768
,
320
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
896
,
352
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
PT
,
1024
,
416
,
128
,
6
,
16
,
true
)
>
;
// clang-format on
using
device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_relu_instances
=
std
::
tuple
<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
256
,
128
,
64
,
6
,
16
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
256
,
128
,
64
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
256
,
128
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
128
,
256
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
512
,
240
,
128
,
4
,
24
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
512
,
256
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
false
)
>
;
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
768
,
320
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
896
,
352
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
1024
,
416
,
128
,
6
,
16
,
false
)
>
;
// clang-format on
// use these in the single-threaded case when gemm_n is not a multiple of 8
using
device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_relu_instances
=
std
::
tuple
<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
256
,
128
,
64
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
256
,
128
,
64
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
256
,
128
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
128
,
256
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
512
,
240
,
128
,
4
,
24
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
512
,
256
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
true
)
>
;
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
768
,
320
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
896
,
352
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
1024
,
416
,
128
,
6
,
16
,
true
)
>
;
// clang-format on
// use these in a multi-threaded environment (a local C buffer is needed to avoid cache-coherence
// issues, although sometimes no local C buffer is better...)
using
device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_relu_instances
=
std
::
tuple
<
// clang-format off
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
48
,
24
,
128
,
4
,
24
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
72
,
16
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
72
,
32
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
96
,
32
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
96
,
64
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
120
,
32
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
120
,
64
,
128
,
6
,
16
,
true
,
true
,
true
),
// DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
true
)
>
;
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
24
,
24
,
256
,
4
,
24
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
32
,
24
,
256
,
4
,
24
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
40
,
24
,
256
,
4
,
24
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
48
,
24
,
256
,
4
,
24
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
48
,
48
,
256
,
4
,
24
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
56
,
24
,
256
,
4
,
24
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
72
,
16
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
72
,
16
,
256
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
72
,
32
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
72
,
32
,
256
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
96
,
32
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
96
,
64
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
120
,
32
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
120
,
64
,
128
,
6
,
16
,
false
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
256
,
128
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
128
,
256
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
512
,
240
,
128
,
4
,
24
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
512
,
256
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
768
,
320
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
896
,
352
,
128
,
6
,
16
,
true
),
DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
Relu
,
1024
,
416
,
128
,
6
,
16
,
true
)
>
;
// clang-format on
void
add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk
(
std
::
vector
<
DeviceConvFwdPtr
<
PT
,
PT
,
PT
>>&
instances
)
...
...
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_instance.cpp
View file @
71254ddd
This diff is collapsed.
Click to expand it.
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_bias_activation_add_avx2_nhwc_kyxc_nhwk_instance.cpp
View file @
71254ddd
...
...
@@ -40,68 +40,77 @@ static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver
static
constexpr
auto
LoopOver_MKN
=
ck
::
tensor_operation
::
cpu
::
device
::
LoopOver_MKN
;
// clang-format off
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
GemmKLoopOverC
,
LoopOver_MNK
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
a_local_buf
,
b_local_buf
,
c_local_buf
,
bias_along_m
>
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
GemmKLoopOverC
,
LoopOver_MNK
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
a_local_buf
,
b_local_buf
,
c_local_buf
,
bias_along_m
>
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
DefaultGemmKLoop
,
LoopOver_MNK
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
a_local_buf
,
b_local_buf
,
c_local_buf
,
bias_along_m
>
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
DefaultGemmKLoop
,
LoopOver_MNK
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
a_local_buf
,
b_local_buf
,
c_local_buf
,
bias_along_m
>
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
GemmKLoopOverC
,
LoopOver_MKN
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
a_local_buf
,
b_local_buf
,
c_local_buf
,
bias_along_m
>
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
GemmKLoopOverC
,
LoopOver_MKN
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
a_local_buf
,
b_local_buf
,
c_local_buf
,
bias_along_m
>
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwdDefault
,
DefaultGemmKLoop
,
LoopOver_MKN
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
a_local_buf
,
b_local_buf
,
c_local_buf
,
bias_along_m
>
,
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
float
,
float
,
float
,
float
,
float
,
a_elem_op
,
b_elem_op
,
c_elem_op
,
ConvFwd1x1S1P0
,
DefaultGemmKLoop
,
LoopOver_MKN
,
2
,
m_per_block
,
n_per_block
,
k_per_block
,
m_per_thread
,
n_per_thread
,
a_local_buf
,
b_local_buf
,
c_local_buf
,
bias_along_m
>
// Expands to a comma-separated list of eight float
// DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<...>
// type names, suitable for use inside a type list such as std::tuple<...>.
//
// The first four entries use LoopOver_MNK and the last four use LoopOver_MKN.
// Within each group of four the entries are:
//   1. ConvFwdDefault / GemmKLoopOverC   with the two hard-coded flags (true,  true)
//   2. ConvFwd1x1S1P0 / GemmKLoopOverC   with the two hard-coded flags (true,  true)
//   3. ConvFwdDefault / DefaultGemmKLoop with the two hard-coded flags (true,  true)
//   4. ConvFwd1x1S1P0 / GemmKLoopOverC   with the two hard-coded flags (false, true)
//
// NOTE(review): the two hard-coded bools sit directly before c_local_buf and presumably
// select "use local A buffer" / "use local B buffer" on the device template; confirm
// against the template's parameter list.
// NOTE(review): there is no ConvFwd1x1S1P0 / DefaultGemmKLoop entry; presumably intentional,
// verify.
//
// Macro parameters are forwarded verbatim as template arguments: the three element-wise ops,
// the block sizes (m/n/k_per_block), the per-thread tile (m/n_per_thread), c_local_buf and
// bias_along_m. The expansion has no trailing comma, so consecutive invocations must be
// separated by ','.
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, true , c_local_buf, bias_along_m>, \
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, true , c_local_buf, bias_along_m>
// clang-format on
using
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_instances
=
std
::
tuple
<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
false
,
false
)
>
;
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
false
,
false
)
>
;
// clang-format on
// Use these instances for single-threaded runs where gemm_n is not a multiple of 8.
using
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_local_c_instances
=
std
::
tuple
<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
true
,
false
)
>
;
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
>
;
// clang-format on
// Use these instances in a multi-threaded environment (a local C buffer is needed to avoid
// cache-coherence overhead, although sometimes running without a local C buffer is faster...)
using
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_mt_instances
=
std
::
tuple
<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
48
,
24
,
128
,
4
,
24
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
16
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
32
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
32
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
64
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
32
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
64
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
// DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
true
,
false
)
>
;
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
24
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
32
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
40
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
48
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
48
,
48
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
56
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
16
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
16
,
256
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
32
,
256
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
64
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
64
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
>
;
// clang-format on
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk
(
...
...
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_bias_activation_add_avx2_nhwc_kyxck8_nhwk_instance.cpp
View file @
71254ddd
...
...
@@ -40,69 +40,81 @@ static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver
static
constexpr
auto
LoopOver_MKN
=
ck
::
tensor_operation
::
cpu
::
device
::
LoopOver_MKN
;
// clang-format off
// DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32 (12-parameter form)
//
// Expands to a comma-separated list of 8 instantiations of
// DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K.
// The list has no trailing comma, so an invocation can be used directly as (part of) a
// template argument list such as std::tuple<...>.
//
// The 8 entries are the cross product of:
//   - convolution specialization: ConvFwdDefault, ConvFwd1x1S1P0
//   - GEMM K loop type:           GemmKLoopOverC, DefaultGemmKLoop
//   - block loop order:           LoopOver_MNK, LoopOver_MKN
//
// Every entry fixes the five leading type arguments to float and passes the literal 2
// (presumably the number of spatial dimensions -- confirm against the device template).
//
// Parameters, all forwarded verbatim to every instantiation:
//   a_elem_op, b_elem_op, c_elem_op             element-wise operation types
//   m_per_block, n_per_block, k_per_block       block tile sizes
//   m_per_thread, n_per_thread                  thread tile sizes
//   a_local_buf, b_local_buf, c_local_buf       local-buffer flags for A, B and C
//   bias_along_m                                bias direction flag
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>
// DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32 (10-parameter form)
//
// Expands to a comma-separated list of 10 instantiations of
// DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K.
// The list has no trailing comma, so an invocation can be used directly as (part of) a
// template argument list such as std::tuple<...>.
//
// The two boolean template arguments that follow n_per_thread are hard-coded per entry
// instead of being macro parameters. For each block loop order (LoopOver_MNK first, then
// LoopOver_MKN) the macro emits these five variants, in this order:
//   1. ConvFwdDefault, GemmKLoopOverC,   true , true
//   2. ConvFwd1x1S1P0, GemmKLoopOverC,   true , true
//   3. ConvFwdDefault, DefaultGemmKLoop, true , true
//   4. ConvFwd1x1S1P0, GemmKLoopOverC,   false, false
//   5. ConvFwdDefault, DefaultGemmKLoop, true , false
// NOTE(review): the two hard-coded booleans presumably select the local A and local B
// buffers -- confirm against the device template's parameter list.
//
// Every entry fixes the five leading type arguments to float and passes the literal 2
// (presumably the number of spatial dimensions -- confirm against the device template).
//
// Parameters, all forwarded verbatim to every instantiation:
//   a_elem_op, b_elem_op, c_elem_op             element-wise operation types
//   m_per_block, n_per_block, k_per_block       block tile sizes
//   m_per_thread, n_per_thread                  thread tile sizes
//   c_local_buf                                 local-buffer flag for C
//   bias_along_m                                bias direction flag
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , false, c_local_buf, bias_along_m>, \
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , false, c_local_buf, bias_along_m>
// clang-format on
using
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_instances
=
std
::
tuple
<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
false
,
false
)
>
;
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
false
,
false
)
>
;
// clang-format on
// Use these instances for single-threaded runs where gemm_n is not a multiple of 8.
using
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_local_c_instances
=
std
::
tuple
<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
64
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
true
,
false
)
>
;
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
>
;
// clang-format on
// Use these instances in a multi-threaded environment (a local C buffer is needed to avoid
// cache-coherence overhead, although sometimes running without a local C buffer is faster...)
using
device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_mt_instances
=
std
::
tuple
<
// clang-format off
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
48
,
24
,
128
,
4
,
24
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
16
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
32
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
32
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
64
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
32
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
64
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
// DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
true
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
true
,
true
,
false
)
>
;
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
24
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
32
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
40
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
48
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
48
,
48
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
56
,
24
,
256
,
4
,
24
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
16
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
16
,
256
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
72
,
32
,
256
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
96
,
64
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
32
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
120
,
64
,
128
,
6
,
16
,
false
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
256
,
128
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
128
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
240
,
128
,
4
,
24
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
512
,
256
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
768
,
320
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
896
,
352
,
128
,
6
,
16
,
true
,
false
),
DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32
(
PT
,
PT
,
AddReluAdd
,
1024
,
416
,
128
,
6
,
16
,
true
,
false
)
>
;
// clang-format on
void
add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment