gaoqiong / composable_kernel · Commits · 71254ddd

Commit 71254ddd, authored Jun 05, 2022 by carlushuang

    optimize multi-thread case by support not using LocalA/LocalB

Parent: dc536427

Showing 12 changed files with 3241 additions and 2949 deletions (+3241, -2949)
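What the change does, in short: the AVX2 convolution device ops carry UseALocalBuffer / UseBLocalBuffer / UseCLocalBuffer template flags, and this commit wires the A/B flags through the block descriptors and threadwise copies so that a disabled flag makes the per-block descriptor simply alias the grid descriptor; the blockwise GEMM then reads the global tensor directly instead of first packing a thread-local A/B tile. Below is a minimal, self-contained sketch of that dispatch pattern; GridDesc, LocalDesc and the free GetInputBlockDescriptor are illustrative stand-ins, not the real CK types.

#include <cstdio>

// Illustrative stand-ins (not CK's real types): the point is only the
// if-constexpr selection between a packed local tile and the grid view.
struct GridDesc  { static constexpr const char* name = "grid descriptor (no local copy)"; };
struct LocalDesc { static constexpr const char* name = "packed MPerBlock x KPerBlock tile"; };

template <bool UseALocalBuffer>
constexpr auto GetInputBlockDescriptor()
{
    if constexpr(UseALocalBuffer)
        return LocalDesc{}; // copy a tile into a thread-local buffer first
    else
        return GridDesc{};  // operate directly on the global tensor
}

int main()
{
    std::printf("UseALocalBuffer=true : %s\n", decltype(GetInputBlockDescriptor<true>())::name);
    std::printf("UseALocalBuffer=false: %s\n", decltype(GetInputBlockDescriptor<false>())::name);
}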
include/ck/tensor_operation/cpu/block/blockwise_gemm_avx2.hpp                                            +15    -3
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp                         +41    -9
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxck8_nhwk.hpp                       +931   -899
include/ck/tensor_operation/cpu/device/device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk.hpp     +1016  -984
include/ck/tensor_operation/cpu/device/device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk.hpp   +41    -9
include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp                                              +97    -55
include/ck/tensor_operation/cpu/grid/gridwise_gemm_bias_activation_add_avx2.hpp                          +698   -655
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp          +9     -5
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp  +94    -77
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_instance.cpp  +102  -77
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_bias_activation_add_avx2_nhwc_kyxc_nhwk_instance.cpp    +142  -133
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_bias_activation_add_avx2_nhwc_kyxck8_nhwk_instance.cpp  +55   -43
include/ck/tensor_operation/cpu/block/blockwise_gemm_avx2.hpp (view file @ 71254ddd)

@@ -41,6 +41,10 @@ struct BlockwiseGemmAvx2_MxN
     using IndexB = MultiIndex<nDimB>;
     using IndexC = MultiIndex<nDimC>;

+    using ASliceLengths = MultiIndex<nDimA>;
+    using BSliceLengths = MultiIndex<nDimB>;
+    using CSliceLengths = MultiIndex<nDimC>;
+
     using ACoord = decltype(make_tensor_coordinate(ABlockDesc{}, IndexA{}));
     using BCoord = decltype(make_tensor_coordinate(BBlockDesc{}, IndexB{}));
     using CCoord = decltype(make_tensor_coordinate(CDesc{}, IndexC{}));

@@ -89,6 +93,7 @@ struct BlockwiseGemmAvx2_MxN
         return c_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
     }

+#if 0
     static ck::index_t GetMPerBlock(const ABlockDesc& a_block_desc)
     {
         if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,

@@ -134,6 +139,7 @@ struct BlockwiseGemmAvx2_MxN
                 b_block_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
         }
     }
+#endif

     static ck::index_t
     GetABlockStartOffset(const ABlockDesc& a_block_desc, const index_t i_m, const index_t)

@@ -175,14 +181,17 @@ struct BlockwiseGemmAvx2_MxN
     static void Run(const ABlockDesc& a_block_desc,
                     const ABlockBuffer& a_block_buf,
                     const IndexA& /* a_origin */,
+                    const ASliceLengths& a_slice_length,
                     const BBlockDesc& b_block_desc,
                     const BBlockBuffer& b_block_buf,
                     const IndexB& /* b_origin */,
+                    const BSliceLengths& b_slice_length,
                     const CDesc& c_desc,
                     CBuffer& c_buf,
                     const IndexC& /* c_origin */,
+                    const CSliceLengths& c_slice_length,
                     bool is_accumulate_c = true)
     {

@@ -192,9 +201,9 @@ struct BlockwiseGemmAvx2_MxN
         // printf("lda:%d, ldb:%d, ldc:%d\n", lda, ldb, ldc);

-        const auto k_per_block = GetKPerBlock(a_block_desc);
-        const auto m_per_block = GetMPerBlock(a_block_desc);
-        const auto n_per_block = GetNPerBlock(b_block_desc);
+        const auto k_per_block = a_slice_length[Number<1>{}];
+        const auto m_per_block = c_slice_length[Number<0>{}];
+        const auto n_per_block = c_slice_length[Number<1>{}];

         const auto m_per_thread = ThreadwiseGemm_Dispatch::ThreadMaxMr;
         const auto n_per_thread = ThreadwiseGemm_Dispatch::ThreadMaxNr;

@@ -206,6 +215,9 @@ struct BlockwiseGemmAvx2_MxN
         param.alpha       = 1.0f; // TODO
         param.accmulate_c = is_accumulate_c ? 1 : 0;

+        // printf("xxx lda:%u, ldb:%u, ldc:%u, mpb:%u, npb:%u, kpb:%u\n", lda, ldb, ldc,
+        // m_per_block, n_per_block, k_per_block);
+
         if constexpr(std::is_same<ThreadMNAccessOrder, ck::Sequence<0, 1>>::value)
         {
             for(ck::index_t i_m = 0; i_m < m_per_block; i_m += m_per_thread)
...
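A note on the blockwise change above: once the A/B "block buffer" can be the raw grid tensor, its descriptor no longer encodes the per-block tile size, so Run() now receives explicit slice lengths and derives k_per_block / m_per_block / n_per_block from them rather than from the descriptors. A small stand-alone sketch of that idea, with illustrative names rather than the CK API:

#include <array>
#include <cstdio>

// Illustration only: loop bounds come from explicit slice lengths passed by
// the caller, not from the buffer descriptor, so the same code works whether
// the "block" is a packed local tile or a window into the full grid tensor.
using SliceLengths = std::array<int, 2>; // {rows, cols} of the slice to process

void run_block(const float* a, const SliceLengths& a_slice, const SliceLengths& c_slice)
{
    const int k_per_block = a_slice[1]; // was GetKPerBlock(a_block_desc)
    const int m_per_block = c_slice[0]; // was GetMPerBlock(a_block_desc)
    const int n_per_block = c_slice[1]; // was GetNPerBlock(b_block_desc)
    std::printf("m=%d n=%d k=%d\n", m_per_block, n_per_block, k_per_block);
    (void)a;
}

int main()
{
    float a[32] = {};
    run_block(a, /*a_slice=*/{4, 8}, /*c_slice=*/{4, 24});
}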
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp (view file @ 71254ddd)

@@ -108,20 +108,41 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
     static constexpr auto GetInputBlockDescriptor()
     {
-        return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
+        if constexpr(UseALocalBuffer)
+        {
+            return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
+        }
+        else
+        {
+            return AGridDesc{};
+        }
     }

     static constexpr auto GetWeightBlockDescriptor()
     {
-        return make_naive_tensor_descriptor_packed(make_tuple(
-            math::integer_divide_ceil(NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
-            KPerBlock,
-            ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
+        if constexpr(UseBLocalBuffer)
+        {
+            return make_naive_tensor_descriptor_packed(make_tuple(
+                math::integer_divide_ceil(NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
+                KPerBlock,
+                ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
+        }
+        else
+        {
+            return BGridDesc{};
+        }
     }

     static constexpr auto GetOutputBlockDescriptor()
     {
-        return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
+        if constexpr(UseCLocalBuffer)
+        {
+            return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
+        }
+        else
+        {
+            return CGridDesc{};
+        }
     }

     static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)

@@ -564,7 +585,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
         AGridDesc,
         decltype(GetInputBlockDescriptor()),
         InElementwiseOperation,
-        false,
+        !UseALocalBuffer,
         ConvForwardSpecialization,
         GemmKSpecialization>;

@@ -575,7 +596,7 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
         BGridDesc,
         decltype(GetWeightBlockDescriptor()),
         WeiElementwiseOperation,
-        false,
+        !UseBLocalBuffer,
         ConvForwardSpecialization,
         GemmKSpecialization>;

@@ -786,12 +807,23 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
         }

-        if constexpr(GemmKSpecialization ==
-                     ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
+        if constexpr(GemmKSpecialization ==
+                         ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
+                     ConvForwardSpecialization !=
+                         ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
         {
             if(!(arg.Conv_C_ % KPerBlock == 0))
                 return false;
         }

+        if constexpr((!UseALocalBuffer || !UseBLocalBuffer) &&
+                     ConvForwardSpecialization !=
+                         ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+        {
+            // TODO: We can support this in the future, as long as figure out how to express tensor
+            // transform
+            return false;
+        }
+
         // Gridwise GEMM size
         return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
     }
...
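The new IsSupportedArgument guard above encodes the commit's main restriction: skipping the A or B local buffer is only accepted for the 1x1, stride-1, pad-0 convolution, i.e. the case where the NHWC input already is a plain GEMM view and no extra tensor transform is needed. A small illustrative sketch of that predicate (not CK's real API; names and the enum are stand-ins that mirror the condition in the diff):

#include <cstdio>

// Illustrative only: mirrors the support check added in this commit.
enum class ConvSpec { Default, Filter1x1Pad0, Filter1x1Stride1Pad0 };

constexpr bool IsSupported(bool use_a_local, bool use_b_local, ConvSpec spec)
{
    // Without a local A/B copy, only the pure-GEMM 1x1/stride-1/pad-0 case is expressible.
    if((!use_a_local || !use_b_local) && spec != ConvSpec::Filter1x1Stride1Pad0)
        return false;
    return true;
}

int main()
{
    std::printf("%d\n", IsSupported(false, true, ConvSpec::Default));              // 0
    std::printf("%d\n", IsSupported(false, true, ConvSpec::Filter1x1Stride1Pad0)); // 1
}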
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxck8_nhwk.hpp
View file @
71254ddd
#ifndef DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_HPP
#ifndef DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_HPP
#define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_HPP
#define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_HPP
#include <iostream>
#include <iostream>
#include <sstream>
#include <sstream>
#include <numeric>
#include <numeric>
#include "device.hpp"
#include "device.hpp"
#include "device_base_cpu.hpp"
#include "device_base_cpu.hpp"
#include "device_conv_fwd_cpu.hpp"
#include "device_conv_fwd_cpu.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "common_header.hpp"
#include "common_header.hpp"
#include "../../gpu/device/tensor_layout.hpp"
#include "../../gpu/device/tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_avx2.hpp"
#include "gridwise_gemm_avx2.hpp"
#include "threadwise_gemm_avx2.hpp"
#include "threadwise_gemm_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"
#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
cpu
{
namespace
cpu
{
namespace
device
{
namespace
device
{
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
template
<
typename
InDataType
,
template
<
typename
InDataType
,
typename
WeiDataType
,
typename
WeiDataType
,
typename
OutDataType
,
typename
OutDataType
,
typename
InElementwiseOperation
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
,
typename
OutElementwiseOperation
,
ConvolutionForwardSpecialization_t
ConvForwardSpecialization
,
ConvolutionForwardSpecialization_t
ConvForwardSpecialization
,
ConvolutionForwardGemmKSpecialization_t
GemmKSpecialization
,
ConvolutionForwardGemmKSpecialization_t
GemmKSpecialization
,
ConvolutionForwardBlockLoopOverSpecialization_t
BlockLoopOverSpecialization
,
ConvolutionForwardBlockLoopOverSpecialization_t
BlockLoopOverSpecialization
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
MPerBlock
,
// block means data are designed to fit in cache (L1/L2/L3)
ck
::
index_t
MPerBlock
,
// block means data are designed to fit in cache (L1/L2/L3)
ck
::
index_t
NPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
MPerThread
,
ck
::
index_t
MPerThread
,
ck
::
index_t
NPerThread
,
ck
::
index_t
NPerThread
,
bool
UseALocalBuffer
,
bool
UseALocalBuffer
,
bool
UseBLocalBuffer
,
bool
UseBLocalBuffer
,
bool
UseCLocalBuffer
>
bool
UseCLocalBuffer
>
struct
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
struct
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
:
public
DeviceConvFwd
<
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>
:
public
DeviceConvFwd
<
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>
{
{
using
DeviceOp
=
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
;
using
DeviceOp
=
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
;
using
ADataType
=
InDataType
;
using
ADataType
=
InDataType
;
using
BDataType
=
WeiDataType
;
using
BDataType
=
WeiDataType
;
using
CDataType
=
OutDataType
;
using
CDataType
=
OutDataType
;
using
AElementwiseOperation
=
InElementwiseOperation
;
using
AElementwiseOperation
=
InElementwiseOperation
;
using
BElementwiseOperation
=
WeiElementwiseOperation
;
using
BElementwiseOperation
=
WeiElementwiseOperation
;
using
CElementwiseOperation
=
OutElementwiseOperation
;
using
CElementwiseOperation
=
OutElementwiseOperation
;
// TODO make A/B datatype different
// TODO make A/B datatype different
using
ABDataType
=
InDataType
;
using
ABDataType
=
InDataType
;
static
constexpr
index_t
NDimSpatial
=
NumDimSpatial
;
static
constexpr
index_t
NDimSpatial
=
NumDimSpatial
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
bool
NonTemporalStore
=
false
;
static
constexpr
bool
NonTemporalStore
=
false
;
static
constexpr
auto
GetBlockMNKAccessOrder
()
static
constexpr
auto
GetBlockMNKAccessOrder
()
{
{
if
constexpr
(
BlockLoopOverSpecialization
==
DefaultBlockLoopOver
||
if
constexpr
(
BlockLoopOverSpecialization
==
DefaultBlockLoopOver
||
BlockLoopOverSpecialization
==
LoopOver_MNK
)
BlockLoopOverSpecialization
==
LoopOver_MNK
)
return
ck
::
Sequence
<
0
,
1
,
2
>
{};
return
ck
::
Sequence
<
0
,
1
,
2
>
{};
else
if
constexpr
(
BlockLoopOverSpecialization
==
LoopOver_MKN
)
else
if
constexpr
(
BlockLoopOverSpecialization
==
LoopOver_MKN
)
return
ck
::
Sequence
<
0
,
2
,
1
>
{};
return
ck
::
Sequence
<
0
,
2
,
1
>
{};
}
}
using
BlockMNKAccessOrder
=
decltype
(
GetBlockMNKAccessOrder
());
using
BlockMNKAccessOrder
=
decltype
(
GetBlockMNKAccessOrder
());
static
constexpr
auto
GetThreadwiseGemm_Dispatch
()
static
constexpr
auto
GetThreadwiseGemm_Dispatch
()
{
{
if
constexpr
(
MPerThread
==
4
&&
NPerThread
==
24
)
if
constexpr
(
MPerThread
==
4
&&
NPerThread
==
24
)
{
{
return
ck
::
cpu
::
ThreadwiseGemmAvx2_MxN_4x24_Dispatch
<
return
ck
::
cpu
::
ThreadwiseGemmAvx2_MxN_4x24_Dispatch
<
InDataType
,
InDataType
,
WeiDataType
,
WeiDataType
,
OutDataType
,
OutDataType
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
NonTemporalStore
>
{};
NonTemporalStore
>
{};
}
}
else
if
constexpr
(
MPerThread
==
6
&&
NPerThread
==
16
)
else
if
constexpr
(
MPerThread
==
6
&&
NPerThread
==
16
)
{
{
return
ck
::
cpu
::
ThreadwiseGemmAvx2_MxN_6x16_Dispatch
<
return
ck
::
cpu
::
ThreadwiseGemmAvx2_MxN_6x16_Dispatch
<
InDataType
,
InDataType
,
WeiDataType
,
WeiDataType
,
OutDataType
,
OutDataType
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
NonTemporalStore
>
{};
NonTemporalStore
>
{};
}
}
else
else
{
{
// static_assert(false, "invalid Mr/Nr");
// static_assert(false, "invalid Mr/Nr");
}
}
}
}
using
ThreadwiseGemm_Dispatch
=
decltype
(
GetThreadwiseGemm_Dispatch
());
using
ThreadwiseGemm_Dispatch
=
decltype
(
GetThreadwiseGemm_Dispatch
());
static
constexpr
auto
GetInputBlockDescriptor
()
static
auto
GetWeightTensorDescriptor
(
ck
::
index_t
gemm_k
,
ck
::
index_t
gemm_n
)
{
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
KPerBlock
));
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
gemm_n
/
8
,
gemm_k
,
8
));
}
}
static
constexpr
auto
GetWeightBlockDescriptor
()
static
auto
GetOutputTensorDescriptor
(
ck
::
index_t
gemm_m
,
ck
::
index_t
gemm_n
)
{
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
const
auto
out_gemm_m_n_grid_desc
=
math
::
integer_divide_ceil
(
NPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
make_naive_tensor_descriptor_packed
(
make_tuple
(
gemm_m
,
gemm_n
));
KPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
));
return
out_gemm_m_n_grid_desc
;
}
}
static
constexpr
auto
GetOutputBlockDescriptor
()
template
<
ck
::
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
{
static
auto
GetInputTensorDescriptor
(
ck
::
index_t
N
,
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
NPerBlock
));
ck
::
index_t
C
,
}
ck
::
index_t
gemm_m
,
ck
::
index_t
gemm_k
,
static
auto
GetWeightTensorDescriptor
(
ck
::
index_t
gemm_k
,
ck
::
index_t
gemm_n
)
const
std
::
vector
<
ck
::
index_t
>&
input_spatial_lengths
,
{
const
std
::
vector
<
ck
::
index_t
>&
filter_spatial_lengths
,
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
gemm_n
/
8
,
gemm_k
,
8
));
const
std
::
vector
<
ck
::
index_t
>&
output_spatial_lengths
,
}
const
std
::
vector
<
ck
::
index_t
>&
conv_filter_strides
,
const
std
::
vector
<
ck
::
index_t
>&
conv_filter_dilations
,
static
auto
GetOutputTensorDescriptor
(
ck
::
index_t
gemm_m
,
ck
::
index_t
gemm_n
)
const
std
::
vector
<
ck
::
index_t
>&
input_left_pads
,
{
const
std
::
vector
<
ck
::
index_t
>&
input_right_pads
)
const
auto
out_gemm_m_n_grid_desc
=
{
make_naive_tensor_descriptor_packed
(
make_tuple
(
gemm_m
,
gemm_n
));
const
index_t
Wi
=
input_spatial_lengths
[
0
];
const
index_t
Wo
=
output_spatial_lengths
[
0
];
return
out_gemm_m_n_grid_desc
;
const
index_t
ConvStrideW
=
conv_filter_strides
[
0
];
}
if
constexpr
(
ConvForwardSpecialization
==
template
<
ck
::
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
static
auto
GetInputTensorDescriptor
(
ck
::
index_t
N
,
{
ck
::
index_t
C
,
const
auto
in_gemm_m_k_grid_desc
=
ck
::
index_t
gemm_m
,
make_naive_tensor_descriptor_packed
(
make_tuple
(
gemm_m
,
gemm_k
));
ck
::
index_t
gemm_k
,
const
std
::
vector
<
ck
::
index_t
>&
input_spatial_lengths
,
return
in_gemm_m_k_grid_desc
;
const
std
::
vector
<
ck
::
index_t
>&
filter_spatial_lengths
,
}
const
std
::
vector
<
ck
::
index_t
>&
output_spatial_lengths
,
else
if
constexpr
(
ConvForwardSpecialization
==
const
std
::
vector
<
ck
::
index_t
>&
conv_filter_strides
,
ConvolutionForwardSpecialization_t
::
Filter1x1Pad0
)
const
std
::
vector
<
ck
::
index_t
>&
conv_filter_dilations
,
{
const
std
::
vector
<
ck
::
index_t
>&
input_left_pads
,
const
auto
in_n_wi_c_grid_desc
=
const
std
::
vector
<
ck
::
index_t
>&
input_right_pads
)
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Wi
,
C
));
{
const
index_t
Wi
=
input_spatial_lengths
[
0
];
const
auto
in_n_wo_c_grid_desc
=
transform_tensor_descriptor
(
const
index_t
Wo
=
output_spatial_lengths
[
0
];
in_n_wi_c_grid_desc
,
const
index_t
ConvStrideW
=
conv_filter_strides
[
0
];
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Wo
),
make_tuple
(
ConvStrideW
)),
if
constexpr
(
ConvForwardSpecialization
==
make_pass_through_transform
(
C
)),
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
{
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
in_gemm_m_k_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
gemm_m
,
gemm_k
));
const
auto
in_gemm_m_k_grid_desc
=
transform_tensor_descriptor
(
in_n_wo_c_grid_desc
,
return
in_gemm_m_k_grid_desc
;
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Wo
)),
make_pass_through_transform
(
C
)),
}
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{}),
else
if
constexpr
(
ConvForwardSpecialization
==
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
ConvolutionForwardSpecialization_t
::
Filter1x1Pad0
)
{
return
in_gemm_m_k_grid_desc
;
const
auto
in_n_wi_c_grid_desc
=
}
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Wi
,
C
));
else
{
const
auto
in_n_wo_c_grid_desc
=
transform_tensor_descriptor
(
const
index_t
X
=
filter_spatial_lengths
[
0
];
in_n_wi_c_grid_desc
,
const
index_t
ConvDilationW
=
conv_filter_dilations
[
0
];
make_tuple
(
make_pass_through_transform
(
N
),
const
index_t
InLeftPadW
=
input_left_pads
[
0
];
make_embed_transform
(
make_tuple
(
Wo
),
make_tuple
(
ConvStrideW
)),
const
index_t
InRightPadW
=
input_right_pads
[
0
];
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
const
auto
in_n_wi_c_grid_desc
=
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Wi
,
C
));
const
auto
in_gemm_m_k_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_n_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_n_wo_c_grid_desc
,
in_n_wi_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Wo
)),
make_pass_through_transform
(
C
)),
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{}),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
return
in_gemm_m_k_grid_desc
;
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
}
else
const
auto
in_n_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
{
in_n_wip_c_grid_desc
,
const
index_t
X
=
filter_spatial_lengths
[
0
];
make_tuple
(
const
index_t
ConvDilationW
=
conv_filter_dilations
[
0
];
make_pass_through_transform
(
N
),
const
index_t
InLeftPadW
=
input_left_pads
[
0
];
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
const
index_t
InRightPadW
=
input_right_pads
[
0
];
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
const
auto
in_n_wi_c_grid_desc
=
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Wi
,
C
));
const
auto
in_gemm_m_k_grid_desc
=
const
auto
in_n_wip_c_grid_desc
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
in_n_x_wo_c_grid_desc
,
in_n_wi_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Wo
)),
make_tuple
(
make_pass_through_transform
(
N
),
make_merge_transform
(
make_tuple
(
X
,
C
))),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
,
3
>
{}),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
return
in_gemm_m_k_grid_desc
;
}
const
auto
in_n_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
}
in_n_wip_c_grid_desc
,
make_tuple
(
template
<
ck
::
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
make_pass_through_transform
(
N
),
static
auto
GetInputTensorDescriptor
(
ck
::
index_t
N
,
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
ck
::
index_t
C
,
make_pass_through_transform
(
C
)),
ck
::
index_t
gemm_m
,
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
ck
::
index_t
gemm_k
,
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
const
std
::
vector
<
ck
::
index_t
>&
input_spatial_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
filter_spatial_lengths
,
const
auto
in_gemm_m_k_grid_desc
=
const
std
::
vector
<
ck
::
index_t
>&
output_spatial_lengths
,
transform_tensor_descriptor
(
in_n_x_wo_c_grid_desc
,
const
std
::
vector
<
ck
::
index_t
>&
conv_filter_strides
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Wo
)),
const
std
::
vector
<
ck
::
index_t
>&
conv_filter_dilations
,
make_merge_transform
(
make_tuple
(
X
,
C
))),
const
std
::
vector
<
ck
::
index_t
>&
input_left_pads
,
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
,
3
>
{}),
const
std
::
vector
<
ck
::
index_t
>&
input_right_pads
)
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
{
const
index_t
Hi
=
input_spatial_lengths
[
0
];
return
in_gemm_m_k_grid_desc
;
const
index_t
Wi
=
input_spatial_lengths
[
1
];
}
}
const
index_t
Ho
=
output_spatial_lengths
[
0
];
const
index_t
Wo
=
output_spatial_lengths
[
1
];
template
<
ck
::
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
static
auto
GetInputTensorDescriptor
(
ck
::
index_t
N
,
const
index_t
ConvStrideH
=
conv_filter_strides
[
0
];
ck
::
index_t
C
,
const
index_t
ConvStrideW
=
conv_filter_strides
[
1
];
ck
::
index_t
gemm_m
,
ck
::
index_t
gemm_k
,
if
constexpr
(
ConvForwardSpecialization
==
const
std
::
vector
<
ck
::
index_t
>&
input_spatial_lengths
,
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
const
std
::
vector
<
ck
::
index_t
>&
filter_spatial_lengths
,
{
const
std
::
vector
<
ck
::
index_t
>&
output_spatial_lengths
,
const
auto
in_gemm_m_k_grid_desc
=
const
std
::
vector
<
ck
::
index_t
>&
conv_filter_strides
,
make_naive_tensor_descriptor_packed
(
make_tuple
(
gemm_m
,
gemm_k
));
const
std
::
vector
<
ck
::
index_t
>&
conv_filter_dilations
,
const
std
::
vector
<
ck
::
index_t
>&
input_left_pads
,
return
in_gemm_m_k_grid_desc
;
const
std
::
vector
<
ck
::
index_t
>&
input_right_pads
)
}
{
else
if
constexpr
(
ConvForwardSpecialization
==
const
index_t
Hi
=
input_spatial_lengths
[
0
];
ConvolutionForwardSpecialization_t
::
Filter1x1Pad0
)
const
index_t
Wi
=
input_spatial_lengths
[
1
];
{
const
auto
in_n_hi_wi_c_grid_desc
=
const
index_t
Ho
=
output_spatial_lengths
[
0
];
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Hi
,
Wi
,
C
));
const
index_t
Wo
=
output_spatial_lengths
[
1
];
const
auto
in_n_ho_wo_c_grid_desc
=
transform_tensor_descriptor
(
const
index_t
ConvStrideH
=
conv_filter_strides
[
0
];
in_n_hi_wi_c_grid_desc
,
const
index_t
ConvStrideW
=
conv_filter_strides
[
1
];
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Ho
),
make_tuple
(
ConvStrideH
)),
if
constexpr
(
ConvForwardSpecialization
==
make_embed_transform
(
make_tuple
(
Wo
),
make_tuple
(
ConvStrideW
)),
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
make_pass_through_transform
(
C
)),
{
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
const
auto
in_gemm_m_k_grid_desc
=
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
make_naive_tensor_descriptor_packed
(
make_tuple
(
gemm_m
,
gemm_k
));
const
auto
in_gemm_m_k_grid_desc
=
return
in_gemm_m_k_grid_desc
;
transform_tensor_descriptor
(
in_n_ho_wo_c_grid_desc
,
}
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
)),
else
if
constexpr
(
ConvForwardSpecialization
==
make_pass_through_transform
(
C
)),
ConvolutionForwardSpecialization_t
::
Filter1x1Pad0
)
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
>
{}),
{
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_n_hi_wi_c_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Hi
,
Wi
,
C
));
return
in_gemm_m_k_grid_desc
;
}
const
auto
in_n_ho_wo_c_grid_desc
=
transform_tensor_descriptor
(
else
in_n_hi_wi_c_grid_desc
,
{
make_tuple
(
make_pass_through_transform
(
N
),
const
index_t
Y
=
filter_spatial_lengths
[
0
];
make_embed_transform
(
make_tuple
(
Ho
),
make_tuple
(
ConvStrideH
)),
const
index_t
X
=
filter_spatial_lengths
[
1
];
make_embed_transform
(
make_tuple
(
Wo
),
make_tuple
(
ConvStrideW
)),
make_pass_through_transform
(
C
)),
const
index_t
ConvDilationH
=
conv_filter_dilations
[
0
];
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
const
index_t
ConvDilationW
=
conv_filter_dilations
[
1
];
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
index_t
InLeftPadH
=
input_left_pads
[
0
];
const
auto
in_gemm_m_k_grid_desc
=
const
index_t
InLeftPadW
=
input_left_pads
[
1
];
transform_tensor_descriptor
(
in_n_ho_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
)),
const
index_t
InRightPadH
=
input_right_pads
[
0
];
make_pass_through_transform
(
C
)),
const
index_t
InRightPadW
=
input_right_pads
[
1
];
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_n_hi_wi_c_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Hi
,
Wi
,
C
));
return
in_gemm_m_k_grid_desc
;
}
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
else
in_n_hi_wi_c_grid_desc
,
{
make_tuple
(
make_pass_through_transform
(
N
),
const
index_t
Y
=
filter_spatial_lengths
[
0
];
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
const
index_t
X
=
filter_spatial_lengths
[
1
];
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
C
)),
const
index_t
ConvDilationH
=
conv_filter_dilations
[
0
];
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
const
index_t
ConvDilationW
=
conv_filter_dilations
[
1
];
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
index_t
InLeftPadH
=
input_left_pads
[
0
];
const
auto
in_n_y_ho_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
const
index_t
InLeftPadW
=
input_left_pads
[
1
];
in_n_hip_wip_c_grid_desc
,
make_tuple
(
const
index_t
InRightPadH
=
input_right_pads
[
0
];
make_pass_through_transform
(
N
),
const
index_t
InRightPadW
=
input_right_pads
[
1
];
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
const
auto
in_n_hi_wi_c_grid_desc
=
make_pass_through_transform
(
C
)),
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Hi
,
Wi
,
C
));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_grid_desc
,
const
auto
in_gemm_m_k_grid_desc
=
make_tuple
(
make_pass_through_transform
(
N
),
transform_tensor_descriptor
(
in_n_y_ho_x_wo_c_grid_desc
,
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
)),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
))),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
return
in_gemm_m_k_grid_desc
;
const
auto
in_n_y_ho_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
}
in_n_hip_wip_c_grid_desc
,
}
make_tuple
(
make_pass_through_transform
(
N
),
template
<
ck
::
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
static
auto
GetInputTensorDescriptor
(
ck
::
index_t
N
,
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
ck
::
index_t
C
,
make_pass_through_transform
(
C
)),
ck
::
index_t
gemm_m
,
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
ck
::
index_t
gemm_k
,
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
ck
::
index_t
gemm_m_pad
,
const
std
::
vector
<
ck
::
index_t
>&
input_spatial_lengths
,
const
auto
in_gemm_m_k_grid_desc
=
const
std
::
vector
<
ck
::
index_t
>&
filter_spatial_lengths
,
transform_tensor_descriptor
(
in_n_y_ho_x_wo_c_grid_desc
,
const
std
::
vector
<
ck
::
index_t
>&
output_spatial_lengths
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
)),
const
std
::
vector
<
ck
::
index_t
>&
conv_filter_strides
,
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
))),
const
std
::
vector
<
ck
::
index_t
>&
conv_filter_dilations
,
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}),
const
std
::
vector
<
ck
::
index_t
>&
input_left_pads
,
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
std
::
vector
<
ck
::
index_t
>&
input_right_pads
)
{
return
in_gemm_m_k_grid_desc
;
const
index_t
Di
=
input_spatial_lengths
[
0
];
}
const
index_t
Hi
=
input_spatial_lengths
[
1
];
}
const
index_t
Wi
=
input_spatial_lengths
[
2
];
template
<
ck
::
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
const
index_t
Do
=
output_spatial_lengths
[
0
];
static
auto
GetInputTensorDescriptor
(
ck
::
index_t
N
,
const
index_t
Ho
=
output_spatial_lengths
[
1
];
ck
::
index_t
C
,
const
index_t
Wo
=
output_spatial_lengths
[
2
];
ck
::
index_t
gemm_m
,
ck
::
index_t
gemm_k
,
const
index_t
ConvStrideD
=
conv_filter_strides
[
0
];
ck
::
index_t
gemm_m_pad
,
const
index_t
ConvStrideH
=
conv_filter_strides
[
1
];
const
std
::
vector
<
ck
::
index_t
>&
input_spatial_lengths
,
const
index_t
ConvStrideW
=
conv_filter_strides
[
2
];
const
std
::
vector
<
ck
::
index_t
>&
filter_spatial_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
output_spatial_lengths
,
if
constexpr
(
ConvForwardSpecialization
==
const
std
::
vector
<
ck
::
index_t
>&
conv_filter_strides
,
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
const
std
::
vector
<
ck
::
index_t
>&
conv_filter_dilations
,
{
const
std
::
vector
<
ck
::
index_t
>&
input_left_pads
,
const
auto
in_gemm_m_k_grid_desc
=
const
std
::
vector
<
ck
::
index_t
>&
input_right_pads
)
make_naive_tensor_descriptor_packed
(
make_tuple
(
gemm_m
,
gemm_k
));
{
const
index_t
Di
=
input_spatial_lengths
[
0
];
return
in_gemm_m_k_grid_desc
;
const
index_t
Hi
=
input_spatial_lengths
[
1
];
}
const
index_t
Wi
=
input_spatial_lengths
[
2
];
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization_t
::
Filter1x1Pad0
)
const
index_t
Do
=
output_spatial_lengths
[
0
];
{
const
index_t
Ho
=
output_spatial_lengths
[
1
];
const
auto
in_n_di_hi_wi_c_grid_desc
=
const
index_t
Wo
=
output_spatial_lengths
[
2
];
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Di
,
Hi
,
Wi
,
C
));
const
index_t
ConvStrideD
=
conv_filter_strides
[
0
];
const
auto
in_n_do_ho_wo_c_grid_desc
=
transform_tensor_descriptor
(
const
index_t
ConvStrideH
=
conv_filter_strides
[
1
];
in_n_di_hi_wi_c_grid_desc
,
const
index_t
ConvStrideW
=
conv_filter_strides
[
2
];
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Do
),
make_tuple
(
ConvStrideD
)),
if
constexpr
(
ConvForwardSpecialization
==
make_embed_transform
(
make_tuple
(
Ho
),
make_tuple
(
ConvStrideH
)),
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
make_embed_transform
(
make_tuple
(
Wo
),
make_tuple
(
ConvStrideW
)),
{
make_pass_through_transform
(
C
)),
const
auto
in_gemm_m_k_grid_desc
=
make_tuple
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
gemm_m
,
gemm_k
));
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
return
in_gemm_m_k_grid_desc
;
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
}
else
if
constexpr
(
ConvForwardSpecialization
==
const
auto
in_gemm_m_k_grid_desc
=
transform_tensor_descriptor
(
ConvolutionForwardSpecialization_t
::
Filter1x1Pad0
)
in_n_do_ho_wo_c_grid_desc
,
{
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Do
,
Ho
,
Wo
)),
const
auto
in_n_di_hi_wi_c_grid_desc
=
make_pass_through_transform
(
C
)),
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Di
,
Hi
,
Wi
,
C
));
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_n_do_ho_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_n_di_hi_wi_c_grid_desc
,
return
in_gemm_m_k_grid_desc
;
make_tuple
(
make_pass_through_transform
(
N
),
}
make_embed_transform
(
make_tuple
(
Do
),
make_tuple
(
ConvStrideD
)),
else
make_embed_transform
(
make_tuple
(
Ho
),
make_tuple
(
ConvStrideH
)),
{
make_embed_transform
(
make_tuple
(
Wo
),
make_tuple
(
ConvStrideW
)),
const
index_t
Z
=
filter_spatial_lengths
[
0
];
make_pass_through_transform
(
C
)),
const
index_t
Y
=
filter_spatial_lengths
[
1
];
make_tuple
(
const
index_t
X
=
filter_spatial_lengths
[
2
];
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
const
index_t
ConvDilationD
=
conv_filter_dilations
[
0
];
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
const
index_t
ConvDilationH
=
conv_filter_dilations
[
1
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
2
];
const
auto
in_gemm_m_k_grid_desc
=
transform_tensor_descriptor
(
in_n_do_ho_wo_c_grid_desc
,
const
index_t
InLeftPadD
=
input_left_pads
[
0
];
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Do
,
Ho
,
Wo
)),
const
index_t
InLeftPadH
=
input_left_pads
[
1
];
make_pass_through_transform
(
C
)),
const
index_t
InLeftPadW
=
input_left_pads
[
2
];
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
index_t
InRightPadD
=
input_right_pads
[
0
];
const
index_t
InRightPadH
=
input_right_pads
[
1
];
return
in_gemm_m_k_grid_desc
;
const
index_t
InRightPadW
=
input_right_pads
[
2
];
}
else
const
auto
in_n_di_hi_wi_c_grid_desc
=
{
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Di
,
Hi
,
Wi
,
C
));
const
index_t
Z
=
filter_spatial_lengths
[
0
];
const
index_t
Y
=
filter_spatial_lengths
[
1
];
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
const
index_t
X
=
filter_spatial_lengths
[
2
];
in_n_di_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
const
index_t
ConvDilationD
=
conv_filter_dilations
[
0
];
make_pad_transform
(
Di
,
InLeftPadD
,
InRightPadD
),
const
index_t
ConvDilationH
=
conv_filter_dilations
[
1
];
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
const
index_t
ConvDilationW
=
conv_filter_dilations
[
2
];
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
C
)),
const
index_t
InLeftPadD
=
input_left_pads
[
0
];
make_tuple
(
const
index_t
InLeftPadH
=
input_left_pads
[
1
];
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
const
index_t
InLeftPadW
=
input_left_pads
[
2
];
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
const
index_t
InRightPadD
=
input_right_pads
[
0
];
const
index_t
InRightPadH
=
input_right_pads
[
1
];
const
auto
in_n_z_do_y_ho_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
const
index_t
InRightPadW
=
input_right_pads
[
2
];
in_n_hip_wip_c_grid_desc
,
make_tuple
(
const
auto
in_n_di_hi_wi_c_grid_desc
=
make_pass_through_transform
(
N
),
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Di
,
Hi
,
Wi
,
C
));
make_embed_transform
(
make_tuple
(
Z
,
Do
),
make_tuple
(
ConvDilationD
,
ConvStrideD
)),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
in_n_di_hi_wi_c_grid_desc
,
make_pass_through_transform
(
C
)),
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pad_transform
(
Di
,
InLeftPadD
,
InRightPadD
),
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_tuple
(
Sequence
<
0
>
{},
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
Sequence
<
1
,
2
>
{},
make_pass_through_transform
(
C
)),
Sequence
<
3
,
4
>
{},
make_tuple
(
Sequence
<
5
,
6
>
{},
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
Sequence
<
7
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
const
auto
in_gemm_m_k_grid_desc
=
transform_tensor_descriptor
(
in_n_z_do_y_ho_x_wo_c_grid_desc
,
const
auto
in_n_z_do_y_ho_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Do
,
Ho
,
Wo
)),
in_n_hip_wip_c_grid_desc
,
make_merge_transform
(
make_tuple
(
Z
,
Y
,
X
,
C
))),
make_tuple
(
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
>
{},
Sequence
<
1
,
3
,
5
,
7
>
{}),
make_pass_through_transform
(
N
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_embed_transform
(
make_tuple
(
Z
,
Do
),
make_tuple
(
ConvDilationD
,
ConvStrideD
)),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
return
in_gemm_m_k_grid_desc
;
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
}
make_pass_through_transform
(
C
)),
}
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
static
index_t
GetGemmM
(
ck
::
index_t
N
,
const
std
::
vector
<
ck
::
index_t
>&
output_spatial_lengths
)
make_tuple
(
Sequence
<
0
>
{},
{
Sequence
<
1
,
2
>
{},
return
N
*
std
::
accumulate
(
std
::
begin
(
output_spatial_lengths
),
Sequence
<
3
,
4
>
{},
std
::
end
(
output_spatial_lengths
),
Sequence
<
5
,
6
>
{},
1
,
Sequence
<
7
>
{}));
std
::
multiplies
<
ck
::
index_t
>
());
}
const
auto
in_gemm_m_k_grid_desc
=
transform_tensor_descriptor
(
in_n_z_do_y_ho_x_wo_c_grid_desc
,
static
index_t
GetGemmK
(
ck
::
index_t
C
,
const
std
::
vector
<
ck
::
index_t
>&
filter_spatial_lengths
)
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Do
,
Ho
,
Wo
)),
{
make_merge_transform
(
make_tuple
(
Z
,
Y
,
X
,
C
))),
return
C
*
std
::
accumulate
(
std
::
begin
(
filter_spatial_lengths
),
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
>
{},
Sequence
<
1
,
3
,
5
,
7
>
{}),
std
::
end
(
filter_spatial_lengths
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
1
,
std
::
multiplies
<
ck
::
index_t
>
());
return
in_gemm_m_k_grid_desc
;
}
}
}
static
index_t
GetGemmN
(
ck
::
index_t
K
)
{
static
index_t
GetGemmM
(
ck
::
index_t
N
,
const
std
::
vector
<
ck
::
index_t
>&
output_spatial_lengths
)
// return ck::math::integer_least_multiple(K,
{
// ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
return
N
*
std
::
accumulate
(
std
::
begin
(
output_spatial_lengths
),
return
K
;
std
::
end
(
output_spatial_lengths
),
}
1
,
std
::
multiplies
<
ck
::
index_t
>
());
static
auto
MakeABCGridDescriptor
(
ck
::
index_t
N
,
}
ck
::
index_t
K
,
ck
::
index_t
C
,
static
index_t
GetGemmK
(
ck
::
index_t
C
,
const
std
::
vector
<
ck
::
index_t
>&
filter_spatial_lengths
)
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
{
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
return
C
*
std
::
accumulate
(
std
::
begin
(
filter_spatial_lengths
),
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
end
(
filter_spatial_lengths
),
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
1
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
multiplies
<
ck
::
index_t
>
());
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
}
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
{
static
index_t
GetGemmN
(
ck
::
index_t
K
)
using
namespace
ck
;
{
// return ck::math::integer_least_multiple(K,
const
index_t
GemmM
=
GetGemmM
(
N
,
output_spatial_lengths
);
// ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
const
index_t
GemmN
=
GetGemmN
(
K
);
return
K
;
const
index_t
GemmK
=
GetGemmK
(
C
,
filter_spatial_lengths
);
}
// A:
static
auto
MakeABCGridDescriptor
(
ck
::
index_t
N
,
const
auto
in_gemm_m_k_grid_desc
=
ck
::
index_t
K
,
GetInputTensorDescriptor
<
NumDimSpatial
>
(
N
,
ck
::
index_t
C
,
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
GemmM
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
GemmK
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
conv_filter_dilations
,
{
input_left_pads
,
using
namespace
ck
;
input_right_pads
);
// B:
const
index_t
GemmM
=
GetGemmM
(
N
,
output_spatial_lengths
);
const
auto
wei_gemm_n0_k_n1_grid_desc
=
GetWeightTensorDescriptor
(
GemmK
,
GemmN
);
const
index_t
GemmN
=
GetGemmN
(
K
);
// C:
const
index_t
GemmK
=
GetGemmK
(
C
,
filter_spatial_lengths
);
const
auto
out_gemm_m_n_grid_desc
=
GetOutputTensorDescriptor
(
GemmM
,
GemmN
);
// A:
return
make_tuple
(
const
auto
in_gemm_m_k_grid_desc
=
in_gemm_m_k_grid_desc
,
wei_gemm_n0_k_n1_grid_desc
,
out_gemm_m_n_grid_desc
);
GetInputTensorDescriptor
<
NumDimSpatial
>
(
N
,
}
C
,
GemmM
,
template
<
ck
::
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
GemmK
,
static
auto
GetABCGridDesc
()
input_spatial_lengths
,
{
filter_spatial_lengths
,
return
MakeABCGridDescriptor
(
1
,
1
,
1
,
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
1
});
output_spatial_lengths
,
}
conv_filter_strides
,
conv_filter_dilations
,
template
<
ck
::
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
input_left_pads
,
static
auto
GetABCGridDesc
()
input_right_pads
);
{
// B:
return
MakeABCGridDescriptor
(
const
auto
wei_gemm_n0_k_n1_grid_desc
=
GetWeightTensorDescriptor
(
GemmK
,
GemmN
);
1
,
1
,
1
,
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
});
// C:
}
const
auto
out_gemm_m_n_grid_desc
=
GetOutputTensorDescriptor
(
GemmM
,
GemmN
);
template
<
ck
::
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
return
make_tuple
(
static
auto
GetABCGridDesc
()
in_gemm_m_k_grid_desc
,
wei_gemm_n0_k_n1_grid_desc
,
out_gemm_m_n_grid_desc
);
{
}
return
MakeABCGridDescriptor
(
1
,
1
,
1
,
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
});
template
<
ck
::
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
}
static
auto
GetABCGridDesc
()
{
using
ABCGridDescs
=
decltype
(
GetABCGridDesc
<
NumDimSpatial
>
());
return
MakeABCGridDescriptor
(
1
,
1
,
1
,
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
1
});
}
using
AGridDesc
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I0
])
>
;
using
BGridDesc
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I1
])
>
;
template
<
ck
::
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
using
CGridDesc
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I2
])
>
;
static
auto
GetABCGridDesc
()
{
static
constexpr
auto
GetInputBlockDescriptor
()
return
MakeABCGridDescriptor
(
{
1
,
1
,
1
,
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
});
if
constexpr
(
UseALocalBuffer
)
}
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
KPerBlock
));
template
<
ck
::
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
}
static
auto
GetABCGridDesc
()
else
{
{
return
MakeABCGridDescriptor
(
return
AGridDesc
{};
1
,
1
,
1
,
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
});
}
}
}
using
ABCGridDescs
=
decltype
(
GetABCGridDesc
<
NumDimSpatial
>
());
static
constexpr
auto
GetWeightBlockDescriptor
()
{
using
AGridDesc
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I0
])
>
;
if
constexpr
(
UseBLocalBuffer
)
using
BGridDesc
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I1
])
>
;
{
using
CGridDesc
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I2
])
>
;
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
math
::
integer_divide_ceil
(
NPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
// static constexpr bool UseCLocalBuffer = false;
KPerBlock
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
));
using
AThreadwiseCopy
=
}
ck
::
cpu
::
ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
<
else
InDataType
,
{
InDataType
,
return
BGridDesc
{};
AGridDesc
,
}
decltype
(
GetInputBlockDescriptor
()),
}
InElementwiseOperation
,
false
,
static
constexpr
auto
GetOutputBlockDescriptor
()
ConvForwardSpecialization
,
{
GemmKSpecialization
>
;
if
constexpr
(
UseCLocalBuffer
)
{
using
BThreadwiseCopy
=
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
MPerBlock
,
NPerBlock
));
ck
::
cpu
::
ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8
<
}
WeiDataType
,
else
WeiDataType
,
{
BGridDesc
,
return
CGridDesc
{};
decltype
(
GetWeightBlockDescriptor
()),
}
WeiElementwiseOperation
,
}
false
,
ConvForwardSpecialization
,
// static constexpr bool UseCLocalBuffer = false;
GemmKSpecialization
>
;
using
AThreadwiseCopy
=
using
CThreadwiseCopy
=
ck
::
cpu
::
ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
<
ck
::
cpu
::
ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
<
OutDataType
,
InDataType
,
OutDataType
,
InDataType
,
CGridDesc
,
AGridDesc
,
decltype
(
GetOutputBlockDescriptor
()),
decltype
(
GetInputBlockDescriptor
()),
OutElementwiseOperation
,
InElementwiseOperation
,
!
UseCLocalBuffer
,
!
UseALocalBuffer
,
ConvForwardSpecialization
,
ConvForwardSpecialization
,
GemmKSpecialization
>
;
GemmKSpecialization
>
;
using
GridwiseGemm
=
using
BThreadwiseCopy
=
ck
::
cpu
::
GridwiseGemmAvx2_MxN
<
InDataType
,
// InDataType,
ck
::
cpu
::
ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8
<
WeiDataType
,
// WeiDataType,
WeiDataType
,
OutDataType
,
// OutDataType,
WeiDataType
,
AGridDesc
,
// AGridDesc,
BGridDesc
,
BGridDesc
,
// BGridDesc,
decltype
(
GetWeightBlockDescriptor
()),
CGridDesc
,
// CGridDesc,
WeiElementwiseOperation
,
AElementwiseOperation
,
// AElementwiseOperation,
!
UseBLocalBuffer
,
BElementwiseOperation
,
// BElementwiseOperation,
ConvForwardSpecialization
,
CElementwiseOperation
,
// CElementwiseOperation,
GemmKSpecialization
>
;
MPerBlock
,
// MPerBlock,
NPerBlock
,
// NPerBlock,
using
CThreadwiseCopy
=
ck
::
cpu
::
ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
<
KPerBlock
,
// KPerBlock,
OutDataType
,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
OutDataType
,
AThreadwiseCopy
,
// AThreadwiseCopy
CGridDesc
,
BThreadwiseCopy
,
// BThreadwiseCopy
decltype
(
GetOutputBlockDescriptor
()),
CThreadwiseCopy
,
// CThreadwiseCopy
OutElementwiseOperation
,
BlockMNKAccessOrder
,
// BlockMNKAccessOrder,
!
UseCLocalBuffer
,
ck
::
Sequence
<
0
,
1
>
,
// ThreadMNAccessOrder
ConvForwardSpecialization
,
UseALocalBuffer
,
// UseALocalBuffer
GemmKSpecialization
>
;
UseBLocalBuffer
,
// UseBLocalBuffer
UseCLocalBuffer
    using GridwiseGemm =
        ck::cpu::GridwiseGemmAvx2_MxN<InDataType,              // InDataType,
                                      WeiDataType,             // WeiDataType,
                                      OutDataType,             // OutDataType,
                                      AGridDesc,               // AGridDesc,
                                      BGridDesc,               // BGridDesc,
                                      CGridDesc,               // CGridDesc,
                                      AElementwiseOperation,   // AElementwiseOperation,
                                      BElementwiseOperation,   // BElementwiseOperation,
                                      CElementwiseOperation,   // CElementwiseOperation,
                                      MPerBlock,               // MPerBlock,
                                      NPerBlock,               // NPerBlock,
                                      KPerBlock,               // KPerBlock,
                                      ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
                                      AThreadwiseCopy,         // AThreadwiseCopy
                                      BThreadwiseCopy,         // BThreadwiseCopy
                                      CThreadwiseCopy,         // CThreadwiseCopy
                                      BlockMNKAccessOrder,     // BlockMNKAccessOrder,
                                      ck::Sequence<0, 1>,      // ThreadMNAccessOrder
                                      UseALocalBuffer,         // UseALocalBuffer
                                      UseBLocalBuffer,         // UseBLocalBuffer
                                      UseCLocalBuffer          // UseCLocalBuffer
                                      >;

    // Argument
    struct Argument : public BaseArgument
    {
        Argument(const InDataType* p_in_grid,
                 const WeiDataType* p_wei_grid,
                 OutDataType* p_out_grid,
                 ck::index_t N,
                 ck::index_t K,
                 ck::index_t C,
                 std::vector<ck::index_t> input_spatial_lengths,
                 std::vector<ck::index_t> filter_spatial_lengths,
                 std::vector<ck::index_t> output_spatial_lengths,
                 std::vector<ck::index_t> conv_filter_strides,
                 std::vector<ck::index_t> conv_filter_dilations,
                 std::vector<ck::index_t> input_left_pads,
                 std::vector<ck::index_t> input_right_pads,
                 InElementwiseOperation in_element_op,
                 WeiElementwiseOperation wei_element_op,
                 OutElementwiseOperation out_element_op)
            : p_a_grid_{p_in_grid},
              p_b_grid_{p_wei_grid},
              p_c_grid_{p_out_grid},
              a_grid_desc_{},
              b_grid_desc_{},
              c_grid_desc_{},
              a_element_op_{in_element_op},
              b_element_op_{wei_element_op},
              c_element_op_{out_element_op},
              Conv_N_{N},
              Conv_K_{K},
              Conv_C_{C},
              filter_spatial_lengths_{filter_spatial_lengths},
              conv_filter_strides_{conv_filter_strides},
              input_left_pads_{input_left_pads},
              input_right_pads_{input_right_pads}
        {
            const auto descs = DeviceOp::MakeABCGridDescriptor(N,
                                                               K,
                                                               C,
                                                               input_spatial_lengths,
                                                               filter_spatial_lengths,
                                                               output_spatial_lengths,
                                                               conv_filter_strides,
                                                               conv_filter_dilations,
                                                               input_left_pads,
                                                               input_right_pads);

            a_grid_desc_ = descs[I0];
            b_grid_desc_ = descs[I1];
            c_grid_desc_ = descs[I2];
        }

        // private:
        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        CDataType* p_c_grid_;
        AGridDesc a_grid_desc_;
        BGridDesc b_grid_desc_;
        CGridDesc c_grid_desc_;

        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CElementwiseOperation c_element_op_;

        // for checking IsSupportedArgument()
        index_t Conv_N_;
        index_t Conv_K_;
        index_t Conv_C_;
        std::vector<index_t> filter_spatial_lengths_;
        std::vector<index_t> conv_filter_strides_;
        std::vector<index_t> input_left_pads_;
        std::vector<index_t> input_right_pads_;
    };
    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceOp::Argument;

        float Run(const Argument& arg,
                  const StreamConfig& stream_config = StreamConfig{},
                  int nrepeat = 1)
        {
            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_))
            {
                throw std::runtime_error("wrong! GridwiseGemmAvx2_MxN has invalid setting");
            }

            memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());

            const auto kernel = ck::cpu::kernel_gemm_avx_mxn<GridwiseGemm,
                                                             InDataType,
                                                             WeiDataType,
                                                             OutDataType,
                                                             AGridDesc,
                                                             BGridDesc,
                                                             CGridDesc,
                                                             AElementwiseOperation,
                                                             BElementwiseOperation,
                                                             CElementwiseOperation>;

            float ave_time = 0;

            if(nrepeat != 1)
                ave_time = launch_and_time_cpu_kernel(kernel,
                                                      nrepeat,
                                                      arg.p_a_grid_,
                                                      arg.p_b_grid_,
                                                      arg.p_c_grid_,
                                                      arg.a_grid_desc_,
                                                      arg.b_grid_desc_,
                                                      arg.c_grid_desc_,
                                                      arg.a_element_op_,
                                                      arg.b_element_op_,
                                                      arg.c_element_op_);

            // TODO: this is for benchmark purpose, so last time we clear c buffer and calculate the
            // result
            memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());

            launch_cpu_kernel(kernel,
                              arg.p_a_grid_,
                              arg.p_b_grid_,
                              arg.p_c_grid_,
                              arg.a_grid_desc_,
                              arg.b_grid_desc_,
                              arg.c_grid_desc_,
                              arg.a_element_op_,
                              arg.b_element_op_,
                              arg.c_element_op_);

            return ave_time;
        }

        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{},
                  int nrepeat = 1) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config, nrepeat);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if constexpr(ConvForwardSpecialization ==
                     ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            // check if it's 1x1, stride=1 conv
            if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
                 arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
                 arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
                 arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
            {
                return false;
            }
        }
        else if constexpr(ConvForwardSpecialization ==
                          ConvolutionForwardSpecialization_t::Filter1x1Pad0)
        {
            // check if it's 1x1 conv
            if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
                 arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
                 arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
            {
                return false;
            }
        }

        if constexpr(GemmKSpecialization ==
                         ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
                     ConvForwardSpecialization !=
                         ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            if(!(arg.Conv_C_ % KPerBlock == 0))
                return false;
        }

        if(!(arg.Conv_K_ % 8 == 0))
            return false;

        if constexpr(!UseALocalBuffer &&
                     ConvForwardSpecialization !=
                         ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            // TODO: We can support this in the future, as long as figure out how to express tensor
            // transform
            return false;
        }

        // Gridwise GEMM size
        return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
    }

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }
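    // The Conv_K_ % 8 == 0 check above is presumably tied to the KYXCK8 weight layout this
    // device op consumes: the output-channel dimension K is packed in groups of 8 floats
    // (one ymm register wide), so a K that is not a multiple of 8 has no representation in
    // the packed weight tensor.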
    static auto MakeArgument(const InDataType* p_in_grid,
                             const WeiDataType* p_wei_grid,
                             OutDataType* p_out_grid,
                             ck::index_t N,
                             ck::index_t K,
                             ck::index_t C,
                             std::vector<ck::index_t> input_spatial_lengths,
                             std::vector<ck::index_t> filter_spatial_lengths,
                             std::vector<ck::index_t> output_spatial_lengths,
                             std::vector<ck::index_t> conv_filter_strides,
                             std::vector<ck::index_t> conv_filter_dilations,
                             std::vector<ck::index_t> input_left_pads,
                             std::vector<ck::index_t> input_right_pads,
                             InElementwiseOperation in_element_op,
                             WeiElementwiseOperation wei_element_op,
                             OutElementwiseOperation out_element_op)
    {
        return Argument{p_in_grid,
                        p_wei_grid,
                        p_out_grid,
                        N,
                        K,
                        C,
                        input_spatial_lengths,
                        filter_spatial_lengths,
                        output_spatial_lengths,
                        conv_filter_strides,
                        conv_filter_dilations,
                        input_left_pads,
                        input_right_pads,
                        in_element_op,
                        wei_element_op,
                        out_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_in_grid,
                        const void* p_wei_grid,
                        void* p_out_grid,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t C,
                        std::vector<ck::index_t> input_spatial_lengths,
                        std::vector<ck::index_t> filter_spatial_lengths,
                        std::vector<ck::index_t> output_spatial_lengths,
                        std::vector<ck::index_t> conv_filter_strides,
                        std::vector<ck::index_t> conv_filter_dilations,
                        std::vector<ck::index_t> input_left_pads,
                        std::vector<ck::index_t> input_right_pads,
                        InElementwiseOperation in_element_op,
                        WeiElementwiseOperation wei_element_op,
                        OutElementwiseOperation out_element_op) override
    {
        return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
                                          static_cast<const WeiDataType*>(p_wei_grid),
                                          static_cast<OutDataType*>(p_out_grid),
                                          N,
                                          K,
                                          C,
                                          input_spatial_lengths,
                                          filter_spatial_lengths,
                                          output_spatial_lengths,
                                          conv_filter_strides,
                                          conv_filter_dilations,
                                          input_left_pads,
                                          input_right_pads,
                                          in_element_op,
                                          wei_element_op,
                                          out_element_op);
    }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        auto string_local_buffer = [](bool is_local_buffer) {
            if(is_local_buffer)
                return "L";
            else
                return "G";
        };

        // clang-format off
        str << "DeviceConv" << std::to_string(NumDimSpatial)
            << "DFwdAvx2_NHWC_KYXCK8"
            << "_FS" << static_cast<int>(ConvForwardSpecialization)
            << "_KS" << static_cast<int>(GemmKSpecialization)
            << "_BS" << static_cast<int>(BlockLoopOverSpecialization)
            << "_BT" << MPerBlock << "x" << NPerBlock << "x" << KPerBlock
            << "_TT" << MPerThread << "x" << NPerThread
            << "_A" << string_local_buffer(UseALocalBuffer)
            << "_B" << string_local_buffer(UseBLocalBuffer)
            << "_C" << string_local_buffer(UseCLocalBuffer)
            ;

        if constexpr(!std::is_same<OutElementwiseOperation,
                                   ck::tensor_operation::cpu::element_wise::PassThrough>::value)
        {
            str << "_" << OutElementwiseOperation::Name();
        }
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
#endif
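For orientation, below is a minimal usage sketch of a device op like the one defined above. It assumes `DeviceOp` names a concrete instantiation of this template (data types, specializations and tile sizes filled in, as the instance files later in this commit do), that `p_in`, `p_wei` and `p_out` point to preallocated NHWC input, packed weight and NHWC output buffers, and that all three element-wise operations are `PassThrough`; the sizes are illustrative only.

// Hypothetical usage sketch; DeviceOp, p_in, p_wei and p_out are assumed to exist.
using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;

auto device_op = DeviceOp{};

// N=1, K=256, C=64, 56x56 image, 3x3 filter, stride 1, dilation 1, pad 1 -> 56x56 output.
auto argument = device_op.MakeArgument(p_in,
                                       p_wei,
                                       p_out,
                                       1,        // N
                                       256,      // K (multiple of 8, see the check above)
                                       64,       // C
                                       {56, 56}, // input_spatial_lengths
                                       {3, 3},   // filter_spatial_lengths
                                       {56, 56}, // output_spatial_lengths
                                       {1, 1},   // conv_filter_strides
                                       {1, 1},   // conv_filter_dilations
                                       {1, 1},   // input_left_pads
                                       {1, 1},   // input_right_pads
                                       PassThrough{},
                                       PassThrough{},
                                       PassThrough{});

if(device_op.IsSupportedArgument(argument))
{
    auto invoker = device_op.MakeInvoker();
    invoker.Run(argument, StreamConfig{}, 1 /*nrepeat*/);
}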
include/ck/tensor_operation/cpu/device/device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk.hpp
View file @ 71254ddd

#ifndef DEVICE_CONV2D_FWD_BIAS_ACTIVATION_ADD_AVX2_NHWC_KYXC_NHWK_HPP
#define DEVICE_CONV2D_FWD_BIAS_ACTIVATION_ADD_AVX2_NHWC_KYXC_NHWK_HPP

#include <iostream>
#include <sstream>
#include <numeric>
#include "device.hpp"
#include "device_base_cpu.hpp"
#include "device_conv_fwd_cpu.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include "common_header.hpp"
#include "../../gpu/device/tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_bias_activation_add_avx2.hpp"
#include "threadwise_gemm_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"

namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
template <typename InDataType,
          typename WeiDataType,
          typename OutDataType,
          typename BiasDataType,
          typename AddDataType,
          typename InElementwiseOperation,
          typename WeiElementwiseOperation,
          typename OutElementwiseOperation,
          ConvolutionForwardSpecialization_t ConvForwardSpecialization,
          ConvolutionForwardGemmKSpecialization_t GemmKSpecialization,
          ConvolutionForwardBlockLoopOverSpecialization_t BlockLoopOverSpecialization,
          ck::index_t NumDimSpatial,
          ck::index_t MPerBlock, // block means data are designed to fit in cache (L1/L2/L3)
          ck::index_t NPerBlock,
          ck::index_t KPerBlock,
          ck::index_t MPerThread,
          ck::index_t NPerThread,
          bool UseALocalBuffer,
          bool UseBLocalBuffer,
          bool UseCLocalBuffer,
          bool BiasAlongGemmM>
struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
    : public DeviceConvFwdBiasActivationAdd<InElementwiseOperation,
                                            WeiElementwiseOperation,
                                            OutElementwiseOperation>
{
    using DeviceOp =
        DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K;

    using ADataType  = InDataType;
    using BDataType  = WeiDataType;
    using CDataType  = OutDataType;
    using C0DataType = BiasDataType;
    using C1DataType = AddDataType;

    using AElementwiseOperation = InElementwiseOperation;
    using BElementwiseOperation = WeiElementwiseOperation;
    using CElementwiseOperation = OutElementwiseOperation;

    // TODO make A/B datatype different
    using ABDataType = InDataType;

    static constexpr index_t NDimSpatial = NumDimSpatial;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    static constexpr bool NonTemporalStore = false;
    static constexpr auto GetBlockMNKAccessOrder()
    {
        if constexpr(BlockLoopOverSpecialization == DefaultBlockLoopOver ||
                     BlockLoopOverSpecialization == LoopOver_MNK)
            return ck::Sequence<0, 1, 2>{};
        else if constexpr(BlockLoopOverSpecialization == LoopOver_MKN)
            return ck::Sequence<0, 2, 1>{};
    }
    using BlockMNKAccessOrder = decltype(GetBlockMNKAccessOrder());
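    // The Sequence returned above encodes the blockwise loop nest over the (M, N, K) GEMM
    // dimensions, matching the specialization names: LoopOver_MNK (Sequence<0, 1, 2>) walks M
    // outermost, then N, then K, while LoopOver_MKN (Sequence<0, 2, 1>) swaps the inner two
    // loops so K is traversed before N.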
    static constexpr auto GetThreadwiseGemm_Dispatch()
    {
        if constexpr(MPerThread == 4 && NPerThread == 24)
        {
            return ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<
                InDataType,
                WeiDataType,
                OutDataType,
                ck::tensor_layout::gemm::RowMajor,
                ck::tensor_layout::gemm::ColumnMajor,
                NonTemporalStore>{};
        }
        else if constexpr(MPerThread == 6 && NPerThread == 16)
        {
            return ck::cpu::ThreadwiseGemmAvx2_MxN_6x16_Dispatch<
                InDataType,
                WeiDataType,
                OutDataType,
                ck::tensor_layout::gemm::RowMajor,
                ck::tensor_layout::gemm::ColumnMajor,
                NonTemporalStore>{};
        }
        else
        {
            // static_assert(false, "invalid Mr/Nr");
        }
    }
    using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());
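    // Only the 4x24 and 6x16 register tiles are dispatched here; both fit within the 16 ymm
    // registers available with AVX2 (e.g. 6x16 needs 6 x 2 = 12 accumulator registers plus a
    // few registers for streaming A and B), which is presumably why other MPerThread/NPerThread
    // combinations fall through to the empty else branch above.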
    static constexpr auto GetInputBlockDescriptor()
    {
        if constexpr(UseALocalBuffer)
        {
            return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
        }
        else
        {
            return AGridDesc{};
        }
    }

    static constexpr auto GetWeightBlockDescriptor()
    {
        if constexpr(UseBLocalBuffer)
        {
            return make_naive_tensor_descriptor_packed(make_tuple(
                math::integer_divide_ceil(NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
                KPerBlock,
                ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
        }
        else
        {
            return BGridDesc{};
        }
    }

    static constexpr auto GetOutputBlockDescriptor()
    {
        if constexpr(UseCLocalBuffer)
        {
            return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
        }
        else
        {
            return CGridDesc{};
        }
    }

    static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
    {
        ck::index_t gemm_n_padded =
            math::integer_least_multiple(gemm_n, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);

        const auto wei_gemm_n_k_grid_desc =
            make_naive_tensor_descriptor_packed(make_tuple(gemm_n, gemm_k));

        const auto wei_gemm_padn_k_grid_desc = transform_tensor_descriptor(
            wei_gemm_n_k_grid_desc,
            make_tuple(make_right_pad_transform(gemm_n, gemm_n_padded - gemm_n),
                       make_pass_through_transform(gemm_k)),
            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));

        const auto wei_gemm_n0_k_n1_grid_desc = transform_tensor_descriptor(
            wei_gemm_padn_k_grid_desc,
            ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(
                               wei_gemm_padn_k_grid_desc.GetLength(I0) /
                                   ThreadwiseGemm_Dispatch::MatrixBMinVectorSize,
                               ThreadwiseGemm_Dispatch::MatrixBMinVectorSize)),
                           ck::make_pass_through_transform(wei_gemm_padn_k_grid_desc.GetLength(I1))),
            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
            ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));

        return wei_gemm_n0_k_n1_grid_desc;
    }

    static auto GetOutputTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n)
    {
        const auto out_gemm_m_n_grid_desc =
            make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_n));

        return out_gemm_m_n_grid_desc;
    }

    static auto MakeBiasTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n)
    {
        if constexpr(BiasAlongGemmM)
        {
            return make_naive_tensor_descriptor_packed(make_tuple(gemm_m));
        }
        else
        {
            return make_naive_tensor_descriptor_packed(make_tuple(gemm_n));
        }
    }

    template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
    static auto GetInputTensorDescriptor(ck::index_t N,
                                         ck::index_t C,
                                         ck::index_t gemm_m,
                                         ck::index_t gemm_k,
                                         const std::vector<ck::index_t>& input_spatial_lengths,
                                         const std::vector<ck::index_t>& filter_spatial_lengths,
                                         const std::vector<ck::index_t>& output_spatial_lengths,
                                         const std::vector<ck::index_t>& conv_filter_strides,
                                         const std::vector<ck::index_t>& conv_filter_dilations,
                                         const std::vector<ck::index_t>& input_left_pads,
                                         const std::vector<ck::index_t>& input_right_pads)
    {
        const index_t Wi          = input_spatial_lengths[0];
        const index_t Wo          = output_spatial_lengths[0];
        const index_t ConvStrideW = conv_filter_strides[0];

        if constexpr(ConvForwardSpecialization ==
                     ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            const auto in_gemm_m_k_grid_desc =
                make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));

            return in_gemm_m_k_grid_desc;
        }
        else if constexpr(ConvForwardSpecialization ==
                          ConvolutionForwardSpecialization_t::Filter1x1Pad0)
        {
            const auto in_n_wi_c_grid_desc =
                make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));

            const auto in_n_wo_c_grid_desc = transform_tensor_descriptor(
                in_n_wi_c_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
                in_n_wo_c_grid_desc,
                make_tuple(make_merge_transform(make_tuple(N, Wo)),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return in_gemm_m_k_grid_desc;
        }
        else
        {
            const index_t X             = filter_spatial_lengths[0];
            const index_t ConvDilationW = conv_filter_dilations[0];
            const index_t InLeftPadW    = input_left_pads[0];
            const index_t InRightPadW   = input_right_pads[0];

            const auto in_n_wi_c_grid_desc =
                make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));

            const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
                in_n_wi_c_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_pad_transform(Wi, InLeftPadW, InRightPadW),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
                in_n_wip_c_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_embed_transform(make_tuple(X, Wo),
                                                make_tuple(ConvDilationW, ConvStrideW)),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

            const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
                in_n_x_wo_c_grid_desc,
                make_tuple(make_merge_transform(make_tuple(N, Wo)),
                           make_merge_transform(make_tuple(X, C))),
                make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return in_gemm_m_k_grid_desc;
        }
    }
    template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
    static auto GetInputTensorDescriptor(ck::index_t N,
                                         ck::index_t C,
                                         ck::index_t gemm_m,
                                         ck::index_t gemm_k,
                                         const std::vector<ck::index_t>& input_spatial_lengths,
                                         const std::vector<ck::index_t>& filter_spatial_lengths,
                                         const std::vector<ck::index_t>& output_spatial_lengths,
                                         const std::vector<ck::index_t>& conv_filter_strides,
                                         const std::vector<ck::index_t>& conv_filter_dilations,
                                         const std::vector<ck::index_t>& input_left_pads,
                                         const std::vector<ck::index_t>& input_right_pads)
    {
        const index_t Hi = input_spatial_lengths[0];
        const index_t Wi = input_spatial_lengths[1];

        const index_t Ho = output_spatial_lengths[0];
        const index_t Wo = output_spatial_lengths[1];

        const index_t ConvStrideH = conv_filter_strides[0];
        const index_t ConvStrideW = conv_filter_strides[1];

        if constexpr(ConvForwardSpecialization ==
                     ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            const auto in_gemm_m_k_grid_desc =
                make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));

            return in_gemm_m_k_grid_desc;
        }
        else if constexpr(ConvForwardSpecialization ==
                          ConvolutionForwardSpecialization_t::Filter1x1Pad0)
        {
            const auto in_n_hi_wi_c_grid_desc =
                make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));

            const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
                in_n_hi_wi_c_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
                           make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

            const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
                in_n_ho_wo_c_grid_desc,
                make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return in_gemm_m_k_grid_desc;
        }
        else
        {
            const index_t Y = filter_spatial_lengths[0];
            const index_t X = filter_spatial_lengths[1];

            const index_t ConvDilationH = conv_filter_dilations[0];
            const index_t ConvDilationW = conv_filter_dilations[1];

            const index_t InLeftPadH = input_left_pads[0];
            const index_t InLeftPadW = input_left_pads[1];

            const index_t InRightPadH = input_right_pads[0];
            const index_t InRightPadW = input_right_pads[1];

            const auto in_n_hi_wi_c_grid_desc =
                make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));

            const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
                in_n_hi_wi_c_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_pad_transform(Hi, InLeftPadH, InRightPadH),
                           make_pad_transform(Wi, InLeftPadW, InRightPadW),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

            const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
                in_n_hip_wip_c_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_embed_transform(make_tuple(Y, Ho),
                                                make_tuple(ConvDilationH, ConvStrideH)),
                           make_embed_transform(make_tuple(X, Wo),
                                                make_tuple(ConvDilationW, ConvStrideW)),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

            const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
                in_n_y_ho_x_wo_c_grid_desc,
                make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
                           make_merge_transform(make_tuple(Y, X, C))),
                make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return in_gemm_m_k_grid_desc;
        }
    }
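    // The descriptor chain above (pad -> embed over (Y, Ho)/(X, Wo) -> merge) builds an
    // im2col-style [N*Ho*Wo, Y*X*C] view of the NHWC input purely through coordinate
    // transforms, so the implicit GEMM can read it without materializing an im2col buffer.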
    template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
    static auto GetInputTensorDescriptor(ck::index_t N,
                                         ck::index_t C,
                                         ck::index_t gemm_m,
                                         ck::index_t gemm_k,
                                         ck::index_t gemm_m_pad,
                                         const std::vector<ck::index_t>& input_spatial_lengths,
                                         const std::vector<ck::index_t>& filter_spatial_lengths,
                                         const std::vector<ck::index_t>& output_spatial_lengths,
                                         const std::vector<ck::index_t>& conv_filter_strides,
                                         const std::vector<ck::index_t>& conv_filter_dilations,
                                         const std::vector<ck::index_t>& input_left_pads,
                                         const std::vector<ck::index_t>& input_right_pads)
    {
        const index_t Di = input_spatial_lengths[0];
        const index_t Hi = input_spatial_lengths[1];
        const index_t Wi = input_spatial_lengths[2];

        const index_t Do = output_spatial_lengths[0];
        const index_t Ho = output_spatial_lengths[1];
        const index_t Wo = output_spatial_lengths[2];

        const index_t ConvStrideD = conv_filter_strides[0];
        const index_t ConvStrideH = conv_filter_strides[1];
        const index_t ConvStrideW = conv_filter_strides[2];

        if constexpr(ConvForwardSpecialization ==
                     ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            const auto in_gemm_m_k_grid_desc =
                make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));

            return in_gemm_m_k_grid_desc;
        }
        else if constexpr(ConvForwardSpecialization ==
                          ConvolutionForwardSpecialization_t::Filter1x1Pad0)
        {
            const auto in_n_di_hi_wi_c_grid_desc =
                make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));

            const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor(
                in_n_di_hi_wi_c_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
                           make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
                           make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
                           make_pass_through_transform(C)),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

            const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
                in_n_do_ho_wo_c_grid_desc,
                make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return in_gemm_m_k_grid_desc;
        }
        else
        {
            const index_t Z = filter_spatial_lengths[0];
            const index_t Y = filter_spatial_lengths[1];
            const index_t X = filter_spatial_lengths[2];

            const index_t ConvDilationD = conv_filter_dilations[0];
            const index_t ConvDilationH = conv_filter_dilations[1];
            const index_t ConvDilationW = conv_filter_dilations[2];

            const index_t InLeftPadD = input_left_pads[0];
            const index_t InLeftPadH = input_left_pads[1];
            const index_t InLeftPadW = input_left_pads[2];

            const index_t InRightPadD = input_right_pads[0];
            const index_t InRightPadH = input_right_pads[1];
            const index_t InRightPadW = input_right_pads[2];

            const auto in_n_di_hi_wi_c_grid_desc =
                make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));

            const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
                in_n_di_hi_wi_c_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_pad_transform(Di, InLeftPadD, InRightPadD),
                           make_pad_transform(Hi, InLeftPadH, InRightPadH),
                           make_pad_transform(Wi, InLeftPadW, InRightPadW),
                           make_pass_through_transform(C)),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

            const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
                in_n_hip_wip_c_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_embed_transform(make_tuple(Z, Do),
                                                make_tuple(ConvDilationD, ConvStrideD)),
                           make_embed_transform(make_tuple(Y, Ho),
                                                make_tuple(ConvDilationH, ConvStrideH)),
                           make_embed_transform(make_tuple(X, Wo),
                                                make_tuple(ConvDilationW, ConvStrideW)),
                           make_pass_through_transform(C)),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(Sequence<0>{},
                           Sequence<1, 2>{},
                           Sequence<3, 4>{},
                           Sequence<5, 6>{},
                           Sequence<7>{}));

            const auto in_gemm_m_k_grid_desc = transform_tensor_descriptor(
                in_n_z_do_y_ho_x_wo_c_grid_desc,
                make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
                           make_merge_transform(make_tuple(Z, Y, X, C))),
                make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return in_gemm_m_k_grid_desc;
        }
    }

    static index_t GetGemmM(ck::index_t N, const std::vector<ck::index_t>& output_spatial_lengths)
    {
        return N * std::accumulate(std::begin(output_spatial_lengths),
                                   std::end(output_spatial_lengths),
                                   1,
                                   std::multiplies<ck::index_t>());
    }

    static index_t GetGemmK(ck::index_t C, const std::vector<ck::index_t>& filter_spatial_lengths)
    {
        return C * std::accumulate(std::begin(filter_spatial_lengths),
                                   std::end(filter_spatial_lengths),
                                   1,
                                   std::multiplies<ck::index_t>());
    }

    static index_t GetGemmN(ck::index_t K)
    {
        // return ck::math::integer_least_multiple(K,
        // ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
        return K;
    }
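    // Worked example of the implicit-GEMM sizes computed above (illustrative numbers only):
    // with N = 1, output 56x56, C = 64, filter 3x3 and K = 256,
    //   GemmM = N * Ho * Wo = 1 * 56 * 56 = 3136,
    //   GemmK = C * Y * X   = 64 * 3 * 3  = 576,
    //   GemmN = K           = 256.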
    static auto MakeABCGridDescriptor(ck::index_t N,
                                      ck::index_t K,
                                      ck::index_t C,
                                      std::vector<ck::index_t> input_spatial_lengths,
                                      std::vector<ck::index_t> filter_spatial_lengths,
                                      std::vector<ck::index_t> output_spatial_lengths,
                                      std::vector<ck::index_t> conv_filter_strides,
                                      std::vector<ck::index_t> conv_filter_dilations,
                                      std::vector<ck::index_t> input_left_pads,
                                      std::vector<ck::index_t> input_right_pads)
    {
        using namespace ck;

        const index_t GemmM = GetGemmM(N, output_spatial_lengths);
        const index_t GemmN = GetGemmN(K);
        const index_t GemmK = GetGemmK(C, filter_spatial_lengths);

        // A:
        const auto in_gemm_m_k_grid_desc =
            GetInputTensorDescriptor<NumDimSpatial>(N,
                                                    C,
                                                    GemmM,
                                                    GemmK,
                                                    input_spatial_lengths,
                                                    filter_spatial_lengths,
                                                    output_spatial_lengths,
                                                    conv_filter_strides,
                                                    conv_filter_dilations,
                                                    input_left_pads,
                                                    input_right_pads);

        // B:
        const auto wei_gemm_n0_k_n1_grid_desc = GetWeightTensorDescriptor(GemmK, GemmN);

        // C:
        const auto out_gemm_m_n_grid_desc = GetOutputTensorDescriptor(GemmM, GemmN);

        return make_tuple(
            in_gemm_m_k_grid_desc, wei_gemm_n0_k_n1_grid_desc, out_gemm_m_n_grid_desc);
    }

    template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
    static auto GetABCGridDesc()
    {
        return MakeABCGridDescriptor(1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1});
    }

    template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
    static auto GetABCGridDesc()
    {
        return MakeABCGridDescriptor(
            1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
    }

    template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
    static auto GetABCGridDesc()
    {
        return MakeABCGridDescriptor(
            1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
    }

    using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>());

    using AGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
    using BGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
    using CGridDesc = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;

    using C0GridDesc = remove_cvref_t<decltype(MakeBiasTensorDescriptor(1, 1))>;
    using C1GridDesc = CGridDesc;

    // static constexpr bool UseCLocalBuffer = false;

    using AThreadwiseCopy =
        ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC<
            ADataType,
            ADataType,
            AGridDesc,
            decltype(GetInputBlockDescriptor()),
            InElementwiseOperation,
            !UseALocalBuffer,
            ConvForwardSpecialization,
            GemmKSpecialization>;

    using BThreadwiseCopy =
        ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXC<
            BDataType,
            BDataType,
            BGridDesc,
            decltype(GetWeightBlockDescriptor()),
            WeiElementwiseOperation,
            !UseBLocalBuffer,
            ConvForwardSpecialization,
            GemmKSpecialization>;

    using CThreadwiseCopy =
        ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN<
            CDataType,
            C0DataType,
            C1DataType,
            CDataType,
            CGridDesc,
            C0GridDesc,
            C1GridDesc,
            decltype(GetOutputBlockDescriptor()),
            OutElementwiseOperation,
            !UseCLocalBuffer,
            BiasAlongGemmM>;

    using GridwiseGemm = ck::cpu::GridwiseGemmBiasActivationAddAvx2_MxN<
        ADataType,               // InDataType,
        BDataType,               // WeiDataType,
        CDataType,               // OutDataType,
        C0DataType,              // C0DataType
        C1DataType,              // C1DataType
        AGridDesc,               // AGridDesc,
        BGridDesc,               // BGridDesc,
        CGridDesc,               // CGridDesc,
        C0GridDesc,              // C0GridDesc,
        C1GridDesc,              // C1GridDesc,
        AElementwiseOperation,   // AElementwiseOperation,
        BElementwiseOperation,   // BElementwiseOperation,
        CElementwiseOperation,   // CElementwiseOperation,
        MPerBlock,               // MPerBlock,
        NPerBlock,               // NPerBlock,
        KPerBlock,               // KPerBlock,
        ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
        AThreadwiseCopy,         // AThreadwiseCopy
        BThreadwiseCopy,         // BThreadwiseCopy
        CThreadwiseCopy,         // CThreadwiseCopy
BlockMNKAccessOrder
,
// BlockMNKAccessOrder,
using
GridwiseGemm
=
ck
::
cpu
::
GridwiseGemmBiasActivationAddAvx2_MxN
<
ck
::
Sequence
<
0
,
1
>
,
// ThreadMNAccessOrder
ADataType
,
// InDataType,
UseALocalBuffer
,
// UseALocalBuffer
BDataType
,
// WeiDataType,
UseBLocalBuffer
,
// UseBLocalBuffer
CDataType
,
// OutDataType,
UseCLocalBuffer
// UseCLocalBuffer
C0DataType
,
// C0DataType
>
;
C1DataType
,
// C1DataType
AGridDesc
,
// AGridDesc,
// Argument
BGridDesc
,
// BGridDesc,
struct
Argument
:
public
BaseArgument
CGridDesc
,
// CGridDesc,
{
C0GridDesc
,
// C0GridDesc,
Argument
(
const
InDataType
*
p_in_grid
,
C1GridDesc
,
// C1GridDesc,
const
WeiDataType
*
p_wei_grid
,
AElementwiseOperation
,
// AElementwiseOperation,
OutDataType
*
p_out_grid
,
BElementwiseOperation
,
// BElementwiseOperation,
const
BiasDataType
*
p_bias_grid
,
CElementwiseOperation
,
// CElementwiseOperation,
const
AddDataType
*
p_add_grid
,
MPerBlock
,
// MPerBlock,
ck
::
index_t
N
,
NPerBlock
,
// NPerBlock,
ck
::
index_t
K
,
KPerBlock
,
// KPerBlock,
ck
::
index_t
C
,
ThreadwiseGemm_Dispatch
,
// ThreadwiseGemm_Dispatch,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
AThreadwiseCopy
,
// AThreadwiseCopy
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
BThreadwiseCopy
,
// BThreadwiseCopy
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
CThreadwiseCopy
,
// CThreadwiseCopy
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
BlockMNKAccessOrder
,
// BlockMNKAccessOrder,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
ck
::
Sequence
<
0
,
1
>
,
// ThreadMNAccessOrder
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
UseALocalBuffer
,
// UseALocalBuffer
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
UseBLocalBuffer
,
// UseBLocalBuffer
InElementwiseOperation
in_element_op
,
UseCLocalBuffer
// UseCLocalBuffer
WeiElementwiseOperation
wei_element_op
,
>
;
OutElementwiseOperation
out_element_op
)
:
p_a_grid_
{
p_in_grid
},
// Argument
p_b_grid_
{
p_wei_grid
},
struct
Argument
:
public
BaseArgument
p_c_grid_
{
p_out_grid
},
{
p_c0_grid_
{
p_bias_grid
},
Argument
(
const
InDataType
*
p_in_grid
,
p_c1_grid_
{
p_add_grid
},
const
WeiDataType
*
p_wei_grid
,
a_grid_desc_
{},
OutDataType
*
p_out_grid
,
b_grid_desc_
{},
const
BiasDataType
*
p_bias_grid
,
c_grid_desc_
{},
const
AddDataType
*
p_add_grid
,
c0_grid_desc_
{},
ck
::
index_t
N
,
c1_grid_desc_
{},
ck
::
index_t
K
,
a_element_op_
{
in_element_op
},
ck
::
index_t
C
,
b_element_op_
{
wei_element_op
},
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
c_element_op_
{
out_element_op
},
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
Conv_N_
{
N
},
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
Conv_K_
{
K
},
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
Conv_C_
{
C
},
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
filter_spatial_lengths_
{
filter_spatial_lengths
},
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
conv_filter_strides_
{
conv_filter_strides
},
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
input_left_pads_
{
input_left_pads
},
InElementwiseOperation
in_element_op
,
input_right_pads_
{
input_right_pads
}
WeiElementwiseOperation
wei_element_op
,
{
OutElementwiseOperation
out_element_op
)
const
auto
descs
=
DeviceOp
::
MakeABCGridDescriptor
(
N
,
:
p_a_grid_
{
p_in_grid
},
K
,
p_b_grid_
{
p_wei_grid
},
C
,
p_c_grid_
{
p_out_grid
},
input_spatial_lengths
,
p_c0_grid_
{
p_bias_grid
},
filter_spatial_lengths
,
p_c1_grid_
{
p_add_grid
},
output_spatial_lengths
,
a_grid_desc_
{},
conv_filter_strides
,
b_grid_desc_
{},
conv_filter_dilations
,
c_grid_desc_
{},
input_left_pads
,
c0_grid_desc_
{},
input_right_pads
);
c1_grid_desc_
{},
a_grid_desc_
=
descs
[
I0
];
a_element_op_
{
in_element_op
},
b_grid_desc_
=
descs
[
I1
];
b_element_op_
{
wei_element_op
},
c_grid_desc_
=
descs
[
I2
];
c_element_op_
{
out_element_op
},
Conv_N_
{
N
},
c0_grid_desc_
=
DeviceOp
::
MakeBiasTensorDescriptor
(
GetGemmM
(
N
,
output_spatial_lengths
),
Conv_K_
{
K
},
GetGemmN
(
K
));
Conv_C_
{
C
},
c1_grid_desc_
=
descs
[
I2
];
filter_spatial_lengths_
{
filter_spatial_lengths
},
}
conv_filter_strides_
{
conv_filter_strides
},
input_left_pads_
{
input_left_pads
},
// private:
input_right_pads_
{
input_right_pads
}
const
ADataType
*
p_a_grid_
;
{
const
BDataType
*
p_b_grid_
;
const
auto
descs
=
DeviceOp
::
MakeABCGridDescriptor
(
N
,
CDataType
*
p_c_grid_
;
K
,
const
C0DataType
*
p_c0_grid_
;
C
,
const
C1DataType
*
p_c1_grid_
;
input_spatial_lengths
,
AGridDesc
a_grid_desc_
;
filter_spatial_lengths
,
BGridDesc
b_grid_desc_
;
output_spatial_lengths
,
CGridDesc
c_grid_desc_
;
conv_filter_strides
,
C0GridDesc
c0_grid_desc_
;
conv_filter_dilations
,
C1GridDesc
c1_grid_desc_
;
input_left_pads
,
input_right_pads
);
AElementwiseOperation
a_element_op_
;
a_grid_desc_
=
descs
[
I0
];
BElementwiseOperation
b_element_op_
;
b_grid_desc_
=
descs
[
I1
];
CElementwiseOperation
c_element_op_
;
c_grid_desc_
=
descs
[
I2
];
// for checking IsSupportedArgument()
index_t
Conv_N_
;
c0_grid_desc_
=
DeviceOp
::
MakeBiasTensorDescriptor
(
GetGemmM
(
N
,
output_spatial_lengths
),
index_t
Conv_K_
;
GetGemmN
(
K
));
index_t
Conv_C_
;
c1_grid_desc_
=
descs
[
I2
];
std
::
vector
<
index_t
>
filter_spatial_lengths_
;
}
std
::
vector
<
index_t
>
conv_filter_strides_
;
std
::
vector
<
index_t
>
input_left_pads_
;
// private:
std
::
vector
<
index_t
>
input_right_pads_
;
const
ADataType
*
p_a_grid_
;
};
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
// Invoker
const
C0DataType
*
p_c0_grid_
;
struct
Invoker
:
public
BaseInvoker
const
C1DataType
*
p_c1_grid_
;
{
AGridDesc
a_grid_desc_
;
using
Argument
=
DeviceOp
::
Argument
;
BGridDesc
b_grid_desc_
;
CGridDesc
c_grid_desc_
;
float
Run
(
const
Argument
&
arg
,
C0GridDesc
c0_grid_desc_
;
const
StreamConfig
&
stream_config
=
StreamConfig
{},
C1GridDesc
c1_grid_desc_
;
int
nrepeat
=
1
)
{
AElementwiseOperation
a_element_op_
;
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
))
BElementwiseOperation
b_element_op_
;
{
CElementwiseOperation
c_element_op_
;
throw
std
::
runtime_error
(
"wrong! GridwiseGemmAvx2_MxN has invalid setting"
);
// for checking IsSupportedArgument()
}
index_t
Conv_N_
;
index_t
Conv_K_
;
memset
(
arg
.
p_c_grid_
,
0
,
arg
.
c_grid_desc_
.
GetElementSpaceSize
());
index_t
Conv_C_
;
std
::
vector
<
index_t
>
filter_spatial_lengths_
;
const
auto
kernel
=
std
::
vector
<
index_t
>
conv_filter_strides_
;
ck
::
cpu
::
kernel_gemm_bias_activation_add_avx_mxn
<
GridwiseGemm
,
std
::
vector
<
index_t
>
input_left_pads_
;
ADataType
,
std
::
vector
<
index_t
>
input_right_pads_
;
BDataType
,
};
CDataType
,
C0DataType
,
// Invoker
C1DataType
,
struct
Invoker
:
public
BaseInvoker
AGridDesc
,
{
BGridDesc
,
using
Argument
=
DeviceOp
::
Argument
;
CGridDesc
,
C0GridDesc
,
float
Run
(
const
Argument
&
arg
,
C1GridDesc
,
const
StreamConfig
&
stream_config
=
StreamConfig
{},
AElementwiseOperation
,
int
nrepeat
=
1
)
BElementwiseOperation
,
{
CElementwiseOperation
>
;
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
))
{
float
ave_time
=
0
;
throw
std
::
runtime_error
(
"wrong! GridwiseGemmAvx2_MxN has invalid setting"
);
}
if
(
nrepeat
!=
1
)
ave_time
=
launch_and_time_cpu_kernel
(
kernel
,
memset
(
arg
.
p_c_grid_
,
0
,
arg
.
c_grid_desc_
.
GetElementSpaceSize
());
nrepeat
,
arg
.
p_a_grid_
,
const
auto
kernel
=
arg
.
p_b_grid_
,
ck
::
cpu
::
kernel_gemm_bias_activation_add_avx_mxn
<
GridwiseGemm
,
arg
.
p_c_grid_
,
ADataType
,
arg
.
p_c0_grid_
,
BDataType
,
arg
.
p_c1_grid_
,
CDataType
,
arg
.
a_grid_desc_
,
C0DataType
,
arg
.
b_grid_desc_
,
C1DataType
,
arg
.
c_grid_desc_
,
AGridDesc
,
arg
.
c0_grid_desc_
,
BGridDesc
,
arg
.
c1_grid_desc_
,
CGridDesc
,
arg
.
a_element_op_
,
C0GridDesc
,
arg
.
b_element_op_
,
C1GridDesc
,
arg
.
c_element_op_
);
AElementwiseOperation
,
BElementwiseOperation
,
// TODO: this is for benchmark purpose, so last time we clear c buffer and calculate the
CElementwiseOperation
>
;
// result
memset
(
arg
.
p_c_grid_
,
0
,
arg
.
c_grid_desc_
.
GetElementSpaceSize
());
float
ave_time
=
0
;
launch_cpu_kernel
(
kernel
,
if
(
nrepeat
!=
1
)
arg
.
p_a_grid_
,
ave_time
=
launch_and_time_cpu_kernel
(
kernel
,
arg
.
p_b_grid_
,
nrepeat
,
arg
.
p_c_grid_
,
arg
.
p_a_grid_
,
arg
.
p_c0_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c1_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_
,
arg
.
p_c0_grid_
,
arg
.
b_grid_desc_
,
arg
.
p_c1_grid_
,
arg
.
c_grid_desc_
,
arg
.
a_grid_desc_
,
arg
.
c0_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c1_grid_desc_
,
arg
.
c_grid_desc_
,
arg
.
a_element_op_
,
arg
.
c0_grid_desc_
,
arg
.
b_element_op_
,
arg
.
c1_grid_desc_
,
arg
.
c_element_op_
);
arg
.
a_element_op_
,
arg
.
b_element_op_
,
return
ave_time
;
arg
.
c_element_op_
);
}
// TODO: this is for benchmark purpose, so last time we clear c buffer and calculate the
float
Run
(
const
BaseArgument
*
p_arg
,
// result
const
StreamConfig
&
stream_config
=
StreamConfig
{},
memset
(
arg
.
p_c_grid_
,
0
,
arg
.
c_grid_desc_
.
GetElementSpaceSize
());
int
nrepeat
=
1
)
override
{
launch_cpu_kernel
(
kernel
,
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
,
nrepeat
);
arg
.
p_a_grid_
,
}
arg
.
p_b_grid_
,
};
arg
.
p_c_grid_
,
arg
.
p_c0_grid_
,
static
constexpr
bool
IsValidCompilationParameter
()
arg
.
p_c1_grid_
,
{
arg
.
a_grid_desc_
,
// TODO: properly implement this check
arg
.
b_grid_desc_
,
return
true
;
arg
.
c_grid_desc_
,
}
arg
.
c0_grid_desc_
,
arg
.
c1_grid_desc_
,
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
arg
.
a_element_op_
,
{
arg
.
b_element_op_
,
if
constexpr
(
ConvForwardSpecialization
==
arg
.
c_element_op_
);
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
{
return
ave_time
;
// check if it's 1x1, stride=1 conv
}
if
(
!
(
arg
.
filter_spatial_lengths_
[
0
]
==
1
&&
arg
.
filter_spatial_lengths_
[
1
]
==
1
&&
arg
.
conv_filter_strides_
[
0
]
==
1
&&
arg
.
conv_filter_strides_
[
1
]
==
1
&&
float
Run
(
const
BaseArgument
*
p_arg
,
arg
.
input_left_pads_
[
0
]
==
0
&&
arg
.
input_left_pads_
[
1
]
==
0
&&
const
StreamConfig
&
stream_config
=
StreamConfig
{},
arg
.
input_right_pads_
[
0
]
==
0
&&
arg
.
input_right_pads_
[
1
]
==
0
))
int
nrepeat
=
1
)
override
{
{
return
false
;
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
,
nrepeat
);
}
}
}
};
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization_t
::
Filter1x1Pad0
)
static
constexpr
bool
IsValidCompilationParameter
()
{
{
// check if it's 1x1 conv
// TODO: properly implement this check
if
(
!
(
arg
.
filter_spatial_lengths_
[
0
]
==
1
&&
arg
.
filter_spatial_lengths_
[
1
]
==
1
&&
return
true
;
arg
.
input_left_pads_
[
0
]
==
0
&&
arg
.
input_left_pads_
[
1
]
==
0
&&
}
arg
.
input_right_pads_
[
0
]
==
0
&&
arg
.
input_right_pads_
[
1
]
==
0
))
{
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
return
false
;
{
}
if
constexpr
(
ConvForwardSpecialization
==
}
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
{
if
constexpr
(
GemmKSpecialization
==
// check if it's 1x1, stride=1 conv
ConvolutionForwardGemmKSpecialization_t
::
NHWC_GemmKLoopOverC
)
if
(
!
(
arg
.
filter_spatial_lengths_
[
0
]
==
1
&&
arg
.
filter_spatial_lengths_
[
1
]
==
1
&&
{
arg
.
conv_filter_strides_
[
0
]
==
1
&&
arg
.
conv_filter_strides_
[
1
]
==
1
&&
if
(
!
(
arg
.
Conv_C_
%
KPerBlock
==
0
))
arg
.
input_left_pads_
[
0
]
==
0
&&
arg
.
input_left_pads_
[
1
]
==
0
&&
return
false
;
arg
.
input_right_pads_
[
0
]
==
0
&&
arg
.
input_right_pads_
[
1
]
==
0
))
}
{
return
false
;
// Gridwise GEMM size
}
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
}
}
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization_t
::
Filter1x1Pad0
)
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
{
// check if it's 1x1 conv
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
if
(
!
(
arg
.
filter_spatial_lengths_
[
0
]
==
1
&&
arg
.
filter_spatial_lengths_
[
1
]
==
1
&&
}
arg
.
input_left_pads_
[
0
]
==
0
&&
arg
.
input_left_pads_
[
1
]
==
0
&&
arg
.
input_right_pads_
[
0
]
==
0
&&
arg
.
input_right_pads_
[
1
]
==
0
))
static
auto
MakeArgument
(
const
InDataType
*
p_in_grid
,
{
const
WeiDataType
*
p_wei_grid
,
return
false
;
OutDataType
*
p_out_grid
,
}
const
BiasDataType
*
p_bias_grid
,
}
const
AddDataType
*
p_add_grid
,
ck
::
index_t
N
,
if
constexpr
(
GemmKSpecialization
==
ck
::
index_t
K
,
ConvolutionForwardGemmKSpecialization_t
::
NHWC_GemmKLoopOverC
&&
ck
::
index_t
C
,
ConvForwardSpecialization
!=
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
{
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
if
(
!
(
arg
.
Conv_C_
%
KPerBlock
==
0
))
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
return
false
;
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
}
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
if
constexpr
((
!
UseALocalBuffer
||
!
UseBLocalBuffer
)
&&
InElementwiseOperation
in_element_op
,
ConvForwardSpecialization
!=
WeiElementwiseOperation
wei_element_op
,
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
OutElementwiseOperation
out_element_op
)
{
{
// TODO: We can support this in the future, as long as figure out how to express tensor
return
Argument
{
p_in_grid
,
// transform
p_wei_grid
,
return
false
;
p_out_grid
,
}
p_bias_grid
,
p_add_grid
,
// Gridwise GEMM size
N
,
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
K
,
}
C
,
input_spatial_lengths
,
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
filter_spatial_lengths
,
{
output_spatial_lengths
,
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
conv_filter_strides
,
}
conv_filter_dilations
,
input_left_pads
,
static
auto
MakeArgument
(
const
InDataType
*
p_in_grid
,
input_right_pads
,
const
WeiDataType
*
p_wei_grid
,
in_element_op
,
OutDataType
*
p_out_grid
,
wei_element_op
,
const
BiasDataType
*
p_bias_grid
,
out_element_op
};
const
AddDataType
*
p_add_grid
,
}
ck
::
index_t
N
,
ck
::
index_t
K
,
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
unique_ptr
<
BaseArgument
>
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
MakeArgumentPointer
(
const
void
*
p_in_grid
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
const
void
*
p_wei_grid
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
void
*
p_out_grid
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
const
void
*
p_bias_grid
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
const
void
*
p_add_grid
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
ck
::
index_t
N
,
InElementwiseOperation
in_element_op
,
ck
::
index_t
K
,
WeiElementwiseOperation
wei_element_op
,
ck
::
index_t
C
,
OutElementwiseOperation
out_element_op
)
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
{
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
return
Argument
{
p_in_grid
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
p_wei_grid
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
p_out_grid
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
p_bias_grid
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
p_add_grid
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
N
,
InElementwiseOperation
in_element_op
,
K
,
WeiElementwiseOperation
wei_element_op
,
C
,
OutElementwiseOperation
out_element_op
)
override
input_spatial_lengths
,
{
filter_spatial_lengths
,
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_in_grid
),
output_spatial_lengths
,
static_cast
<
const
WeiDataType
*>
(
p_wei_grid
),
conv_filter_strides
,
static_cast
<
OutDataType
*>
(
p_out_grid
),
conv_filter_dilations
,
static_cast
<
const
BiasDataType
*>
(
p_bias_grid
),
input_left_pads
,
static_cast
<
const
AddDataType
*>
(
p_add_grid
),
input_right_pads
,
N
,
in_element_op
,
K
,
wei_element_op
,
C
,
out_element_op
};
input_spatial_lengths
,
}
filter_spatial_lengths
,
output_spatial_lengths
,
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
conv_filter_strides
,
conv_filter_dilations
,
std
::
unique_ptr
<
BaseArgument
>
input_left_pads
,
MakeArgumentPointer
(
const
void
*
p_in_grid
,
input_right_pads
,
const
void
*
p_wei_grid
,
in_element_op
,
void
*
p_out_grid
,
wei_element_op
,
const
void
*
p_bias_grid
,
out_element_op
);
const
void
*
p_add_grid
,
}
ck
::
index_t
N
,
ck
::
index_t
K
,
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
ck
::
index_t
C
,
{
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
}
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
string
GetTypeString
()
const
override
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
{
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
auto
str
=
std
::
stringstream
();
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
auto
string_local_buffer
=
[](
bool
is_local_buffer
)
{
InElementwiseOperation
in_element_op
,
if
(
is_local_buffer
)
WeiElementwiseOperation
wei_element_op
,
return
"L"
;
OutElementwiseOperation
out_element_op
)
override
else
{
return
"G"
;
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_in_grid
),
};
static_cast
<
const
WeiDataType
*>
(
p_wei_grid
),
// clang-format off
static_cast
<
OutDataType
*>
(
p_out_grid
),
str
<<
"DeviceConv"
<<
std
::
to_string
(
NumDimSpatial
)
static_cast
<
const
BiasDataType
*>
(
p_bias_grid
),
<<
"DFwd_BAA_Avx2_NHWC_KYXC"
static_cast
<
const
AddDataType
*>
(
p_add_grid
),
<<
"_FS"
<<
static_cast
<
int
>
(
ConvForwardSpecialization
)
N
,
<<
"_KS"
<<
static_cast
<
int
>
(
GemmKSpecialization
)
K
,
<<
"_BS"
<<
static_cast
<
int
>
(
BlockLoopOverSpecialization
)
C
,
<<
"_BT"
<<
MPerBlock
<<
"x"
<<
NPerBlock
<<
"x"
<<
KPerBlock
input_spatial_lengths
,
<<
"_TT"
<<
MPerThread
<<
"x"
<<
NPerThread
filter_spatial_lengths
,
<<
"_A"
<<
string_local_buffer
(
UseALocalBuffer
)
output_spatial_lengths
,
<<
"_B"
<<
string_local_buffer
(
UseBLocalBuffer
)
conv_filter_strides
,
<<
"_C"
<<
string_local_buffer
(
UseCLocalBuffer
)
conv_filter_dilations
,
;
input_left_pads
,
if
constexpr
(
!
std
::
is_same
<
OutElementwiseOperation
,
input_right_pads
,
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
>::
value
)
in_element_op
,
{
wei_element_op
,
str
<<
"_"
<<
OutElementwiseOperation
::
Name
();
out_element_op
);
}
}
// clang-format on
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
return
str
.
str
();
{
}
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
};
}
}
// namespace device
std
::
string
GetTypeString
()
const
override
}
// namespace cpu
{
}
// namespace tensor_operation
auto
str
=
std
::
stringstream
();
}
// namespace ck
auto
string_local_buffer
=
[](
bool
is_local_buffer
)
{
if
(
is_local_buffer
)
#endif
return
"L"
;
else
return
"G"
;
};
// clang-format off
str
<<
"DeviceConv"
<<
std
::
to_string
(
NumDimSpatial
)
<<
"DFwd_BAA_Avx2_NHWC_KYXC"
<<
"_FS"
<<
static_cast
<
int
>
(
ConvForwardSpecialization
)
<<
"_KS"
<<
static_cast
<
int
>
(
GemmKSpecialization
)
<<
"_BS"
<<
static_cast
<
int
>
(
BlockLoopOverSpecialization
)
<<
"_BT"
<<
MPerBlock
<<
"x"
<<
NPerBlock
<<
"x"
<<
KPerBlock
<<
"_TT"
<<
MPerThread
<<
"x"
<<
NPerThread
<<
"_A"
<<
string_local_buffer
(
UseALocalBuffer
)
<<
"_B"
<<
string_local_buffer
(
UseBLocalBuffer
)
<<
"_C"
<<
string_local_buffer
(
UseCLocalBuffer
)
;
if
constexpr
(
!
std
::
is_same
<
OutElementwiseOperation
,
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
>::
value
)
{
str
<<
"_"
<<
OutElementwiseOperation
::
Name
();
}
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace cpu
}
// namespace tensor_operation
}
// namespace ck
#endif
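The "_A"/"_B"/"_C" suffixes that GetTypeString() appends encode the three new UseALocalBuffer/UseBLocalBuffer/UseCLocalBuffer switches, so each generated instance name records whether that tensor is packed into a local block buffer ("L") or streamed directly from/to global memory ("G"). A minimal, self-contained sketch of that naming scheme follows; the make_instance_name helper and the fixed block sizes are illustrative only, not part of this commit.

#include <iostream>
#include <sstream>
#include <string>

// Hypothetical illustration of the instance-name encoding used by GetTypeString():
// A/B/C each get "L" when a thread-local packing buffer is used and "G" when the
// kernel reads/writes the global tensor directly.
std::string make_instance_name(bool use_a_local, bool use_b_local, bool use_c_local)
{
    auto tag = [](bool is_local) { return is_local ? "L" : "G"; };

    std::stringstream str;
    str << "DeviceConv2DFwd_BAA_Avx2_NHWC_KYXC"
        << "_BT" << 256 << "x" << 128 << "x" << 64 // example MPerBlock x NPerBlock x KPerBlock
        << "_A" << tag(use_a_local) << "_B" << tag(use_b_local) << "_C" << tag(use_c_local);
    return str.str();
}

int main()
{
    // prints e.g. "DeviceConv2DFwd_BAA_Avx2_NHWC_KYXC_BT256x128x64_AG_BG_CL"
    std::cout << make_instance_name(false, false, true) << "\n";
}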
include/ck/tensor_operation/cpu/device/device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk.hpp
View file @
71254ddd
...
@@ -116,20 +116,41 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
     static constexpr auto GetInputBlockDescriptor()
     {
-        return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
+        if constexpr(UseALocalBuffer)
+        {
+            return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
+        }
+        else
+        {
+            return AGridDesc{};
+        }
     }
 
     static constexpr auto GetWeightBlockDescriptor()
     {
-        return make_naive_tensor_descriptor_packed(make_tuple(
-            math::integer_divide_ceil(NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
-            KPerBlock,
-            ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
+        if constexpr(UseBLocalBuffer)
+        {
+            return make_naive_tensor_descriptor_packed(make_tuple(
+                math::integer_divide_ceil(NPerBlock, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
+                KPerBlock,
+                ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
+        }
+        else
+        {
+            return BGridDesc{};
+        }
     }
 
     static constexpr auto GetOutputBlockDescriptor()
     {
-        return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
+        if constexpr(UseCLocalBuffer)
+        {
+            return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
+        }
+        else
+        {
+            return CGridDesc{};
+        }
     }
 
     static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
@@ -563,7 +584,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
             AGridDesc,
             decltype(GetInputBlockDescriptor()),
             InElementwiseOperation,
-            false,
+            !UseALocalBuffer,
             ConvForwardSpecialization,
             GemmKSpecialization>;
@@ -574,7 +595,7 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
             BGridDesc,
             decltype(GetWeightBlockDescriptor()),
             WeiElementwiseOperation,
-            false,
+            !UseBLocalBuffer,
             ConvForwardSpecialization,
             GemmKSpecialization>;
@@ -820,7 +841,9 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
         }
 
-        if constexpr(GemmKSpecialization ==
-                     ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
+        if constexpr(GemmKSpecialization ==
+                         ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
+                     ConvForwardSpecialization !=
+                         ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
         {
             if(!(arg.Conv_C_ % KPerBlock == 0))
                 return false;
@@ -829,6 +852,15 @@ struct DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Ou
         if(!(arg.Conv_K_ % 8 == 0))
             return false;
 
+        if constexpr((!UseALocalBuffer || !UseBLocalBuffer) &&
+                     ConvForwardSpecialization !=
+                         ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+        {
+            // TODO: We can support this in the future, as long as figure out how to express tensor
+            // transform
+            return false;
+        }
+
         // Gridwise GEMM size
         return GridwiseGemm::CheckValidity(arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_);
     }
...
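The hunks above all follow one pattern: when a local buffer is disabled, the block-descriptor helper no longer builds a packed per-block layout but simply hands back the grid descriptor, so the threadwise copy stage degenerates into a pass-through over the original tensor. Below is a toy, self-contained C++17 sketch of that compile-time selection; PackedBlockDesc and GridDesc are stand-in types, not CK types.

#include <cstddef>

// Minimal sketch (not CK code) of the block-descriptor selection pattern:
// pack into a small cache-resident buffer when requested, otherwise reuse the
// global layout directly and skip the packing copy.
struct PackedBlockDesc { std::size_t rows, cols; };
struct GridDesc        { std::size_t rows, cols, stride; };

template <bool UseLocalBuffer>
auto get_block_descriptor(std::size_t m_per_block, std::size_t k_per_block,
                          const GridDesc& grid_desc)
{
    if constexpr(UseLocalBuffer)
        return PackedBlockDesc{m_per_block, k_per_block}; // compact local copy
    else
        return grid_desc; // no packing: the "block" is a window of the grid itself
}

int main()
{
    GridDesc grid{1024, 512, 512};
    auto packed = get_block_descriptor<true>(256, 64, grid);  // PackedBlockDesc
    auto direct = get_block_descriptor<false>(256, 64, grid); // GridDesc
    (void)packed;
    (void)direct;
}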
include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp
View file @
71254ddd
...
@@ -80,46 +80,65 @@ struct GridwiseGemmAvx2_MxN
     // static constexpr auto Avx2RegisterVector = 8; // 8 floats
     static constexpr index_t MemAlignmentByte = 32; // 256bit
 
-    static auto GetABlockDescriptor(const ck::index_t m_per_blk, const ck::index_t k_per_blk)
+    static auto GetABlockDescriptor(const ck::index_t m_per_blk,
+                                    const ck::index_t k_per_blk,
+                                    const AGridDesc& a_grid_desc)
     {
-        if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
-                                  ck::tensor_layout::gemm::RowMajor>::value)
+        if constexpr(UseALocalBuffer)
         {
-            // A : M, K
-            auto a_block_desc_m_k =
-                make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, k_per_blk));
-            return a_block_desc_m_k;
+            if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
+                                      ck::tensor_layout::gemm::RowMajor>::value)
+            {
+                // A : M, K
+                auto a_block_desc_m_k =
+                    make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, k_per_blk));
+                return a_block_desc_m_k;
+            }
+            else
+            {
+                // A : K, M
+                auto a_block_desc_k_m = make_naive_tensor_descriptor_packed(
+                    make_tuple(k_per_blk,
+                               math::integer_least_multiple(
+                                   m_per_blk, ThreadwiseGemm_Dispatch::MatrixAMinVectorSize)));
+                return a_block_desc_k_m;
+            }
         }
         else
         {
-            // A : K, M
-            auto a_block_desc_k_m = make_naive_tensor_descriptor_packed(
-                make_tuple(k_per_blk,
-                           math::integer_least_multiple(
-                               m_per_blk, ThreadwiseGemm_Dispatch::MatrixAMinVectorSize)));
-            return a_block_desc_k_m;
+            return a_grid_desc;
         }
     }
 
-    static auto GetBBlockDescriptor(const ck::index_t k_per_blk, const ck::index_t n_per_blk)
+    static auto GetBBlockDescriptor(const ck::index_t k_per_blk,
+                                    const ck::index_t n_per_blk,
+                                    const BGridDesc& b_grid_desc)
     {
-        // n_per_blk should be 8x
-        if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
-                                  ck::tensor_layout::gemm::RowMajor>::value)
+        if constexpr(UseBLocalBuffer)
         {
-            // B : K, N
-            auto b_block_desc_k_n =
-                make_naive_tensor_descriptor_packed(make_tuple(k_per_blk, n_per_blk));
-            return b_block_desc_k_n;
+            // n_per_blk should be 8x
+            if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
+                                      ck::tensor_layout::gemm::RowMajor>::value)
+            {
+                // B : K, N
+                auto b_block_desc_k_n =
+                    make_naive_tensor_descriptor_packed(make_tuple(k_per_blk, n_per_blk));
+                return b_block_desc_k_n;
+            }
+            else
+            {
+                // B : N/8, K, N8
+                auto b_block_desc_n0_k_n1 = make_naive_tensor_descriptor_packed(make_tuple(
+                    math::integer_divide_ceil(n_per_blk,
+                                              ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
+                    k_per_blk,
+                    ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
+                return b_block_desc_n0_k_n1;
+            }
         }
         else
         {
-            // B : N/8, K, N8
-            auto b_block_desc_n0_k_n1 = make_naive_tensor_descriptor_packed(make_tuple(
-                math::integer_divide_ceil(n_per_blk, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
-                k_per_blk,
-                ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
-            return b_block_desc_n0_k_n1;
+            return b_grid_desc;
         }
     }
@@ -262,10 +281,10 @@ struct GridwiseGemmAvx2_MxN
         constexpr auto b_block_copy_dim = BGridDesc::GetNumOfDimension();
 
         auto a_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
-            reinterpret_cast<const FloatA*>(p_a_grid), a_grid_desc.GetElementSpaceSize());
+            const_cast<FloatA*>(p_a_grid), a_grid_desc.GetElementSpaceSize());
 
         auto b_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
-            reinterpret_cast<const FloatB*>(p_b_grid), b_grid_desc.GetElementSpaceSize());
+            const_cast<FloatB*>(p_b_grid), b_grid_desc.GetElementSpaceSize());
 
         auto c_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
             reinterpret_cast<FloatC*>(p_c_grid), c_grid_desc.GetElementSpaceSize());
@@ -274,8 +293,8 @@ struct GridwiseGemmAvx2_MxN
             FloatA, // FloatA,
             FloatB, // FloatB,
             FloatC, // FloatC,
-            decltype(GetABlockDescriptor(m_per_block, k_per_block)),              // ABlockDesc,
-            decltype(GetBBlockDescriptor(k_per_block, n_per_block)),              // BBlockDesc,
+            decltype(GetABlockDescriptor(m_per_block, k_per_block, a_grid_desc)), // ABlockDesc,
+            decltype(GetBBlockDescriptor(k_per_block, n_per_block, b_grid_desc)), // BBlockDesc,
             decltype(GetCBlockDescriptor(m_per_block, n_per_block, c_grid_desc)), // CBlockDesc,
             KPerBlock,               // KPerBlock,
             ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
@@ -320,14 +339,14 @@ struct GridwiseGemmAvx2_MxN
         auto a_threadwise_copy =
             AThreadwiseCopy(a_grid_desc,
                             ck::make_zero_multi_index<a_block_copy_dim>(),
-                            GetABlockDescriptor(m_per_block, k_per_block),
+                            GetABlockDescriptor(m_per_block, k_per_block, a_grid_desc),
                             ck::make_zero_multi_index<a_block_copy_dim>(),
                             AElementwiseOperation{});
 
         auto b_threadwise_copy =
             BThreadwiseCopy(b_grid_desc,
                             ck::make_zero_multi_index<b_block_copy_dim>(),
-                            GetBBlockDescriptor(k_per_block, n_per_block),
+                            GetBBlockDescriptor(k_per_block, n_per_block, b_grid_desc),
                             ck::make_zero_multi_index<b_block_copy_dim>(),
                             BElementwiseOperation{});
@@ -338,21 +357,27 @@ struct GridwiseGemmAvx2_MxN
                             ck::make_zero_multi_index<2>(),
                             CElementwiseOperation{});
 
-        DeviceAlignedMemCPU a_block_mem(m_per_block * k_per_block * sizeof(FloatA),
-                                        MemAlignmentByte);
-        DeviceAlignedMemCPU b_block_mem(k_per_block * n_per_block * sizeof(FloatB),
-                                        MemAlignmentByte);
+        DeviceAlignedMemCPU a_block_mem(
+            UseALocalBuffer ? m_per_block * k_per_block * sizeof(FloatA) : 0, MemAlignmentByte);
+        DeviceAlignedMemCPU b_block_mem(
+            UseBLocalBuffer ? k_per_block * n_per_block * sizeof(FloatB) : 0, MemAlignmentByte);
         DeviceAlignedMemCPU c_block_mem(
             UseCLocalBuffer ? (m_per_block * n_per_block * sizeof(FloatC)) : 0, MemAlignmentByte);
 
         auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
-            reinterpret_cast<FloatA*>(a_block_mem.mpDeviceBuf),
-            a_block_mem.mMemSize / sizeof(FloatA));
+            UseALocalBuffer ? reinterpret_cast<FloatA*>(a_block_mem.mpDeviceBuf)
+                            : const_cast<FloatA*>(p_a_grid),
+            UseALocalBuffer ? a_block_mem.mMemSize / sizeof(FloatA)
+                            : a_grid_desc.GetElementSpaceSize());
 
         auto b_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
-            reinterpret_cast<FloatB*>(b_block_mem.mpDeviceBuf),
-            b_block_mem.mMemSize / sizeof(FloatB));
+            UseBLocalBuffer ? reinterpret_cast<FloatB*>(b_block_mem.mpDeviceBuf)
+                            : const_cast<FloatB*>(p_b_grid),
+            UseBLocalBuffer ? b_block_mem.mMemSize / sizeof(FloatB)
+                            : b_grid_desc.GetElementSpaceSize());
 
         auto c_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
             UseCLocalBuffer ? reinterpret_cast<FloatC*>(c_block_mem.mpDeviceBuf)
@@ -395,8 +420,8 @@ struct GridwiseGemmAvx2_MxN
         {
             ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block);
-            auto a_block_desc = GetABlockDescriptor(mc_size, kc_size);
-            auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size);
+            auto a_block_desc = GetABlockDescriptor(mc_size, kc_size, a_grid_desc);
+            auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size, b_grid_desc);
 
             a_threadwise_copy.RunRead(a_grid_desc,
                                       a_grid_buf,
@@ -412,12 +437,17 @@ struct GridwiseGemmAvx2_MxN
             blockwise_gemm.Run(a_block_desc,
                                a_block_buf,
                                make_zero_multi_index<a_block_copy_dim>(),
+                               GetASliceLength(mc_size, kc_size),
                                b_block_desc,
                                b_block_buf,
                                make_zero_multi_index<b_block_copy_dim>(),
+                               GetBSliceLength(kc_size, nc_size),
                                c_block_desc,
                                c_block_buf,
                                make_zero_multi_index<2>(),
+                               GetCSliceLength(mc_size, nc_size),
                                i_kc != 0);
 
             if((i_kc + k_per_block) < GemmK)
@@ -450,14 +480,14 @@ struct GridwiseGemmAvx2_MxN
         auto a_threadwise_copy =
             AThreadwiseCopy(a_grid_desc,
                             ck::make_zero_multi_index<a_block_copy_dim>(),
-                            GetABlockDescriptor(m_per_block, k_per_block),
+                            GetABlockDescriptor(m_per_block, k_per_block, a_grid_desc),
                             ck::make_zero_multi_index<a_block_copy_dim>(),
                             AElementwiseOperation{});
 
         auto b_threadwise_copy =
             BThreadwiseCopy(b_grid_desc,
                             ck::make_zero_multi_index<b_block_copy_dim>(),
-                            GetBBlockDescriptor(k_per_block, n_per_block),
+                            GetBBlockDescriptor(k_per_block, n_per_block, b_grid_desc),
                             ck::make_zero_multi_index<b_block_copy_dim>(),
                             BElementwiseOperation{});
@@ -468,21 +498,27 @@ struct GridwiseGemmAvx2_MxN
                             ck::make_zero_multi_index<2>(),
                             CElementwiseOperation{});
 
-        DeviceAlignedMemCPU a_block_mem(m_per_block * k_per_block * sizeof(FloatA),
-                                        MemAlignmentByte);
-        DeviceAlignedMemCPU b_block_mem(k_per_block * n_per_block * sizeof(FloatB),
-                                        MemAlignmentByte);
+        DeviceAlignedMemCPU a_block_mem(
+            UseALocalBuffer ? m_per_block * k_per_block * sizeof(FloatA) : 0, MemAlignmentByte);
+        DeviceAlignedMemCPU b_block_mem(
+            UseBLocalBuffer ? k_per_block * n_per_block * sizeof(FloatB) : 0, MemAlignmentByte);
         DeviceAlignedMemCPU c_block_mem(
            UseCLocalBuffer ? (m_per_block * n_per_block * sizeof(FloatC)) : 0, MemAlignmentByte);
 
         auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
-            reinterpret_cast<FloatA*>(a_block_mem.mpDeviceBuf),
-            a_block_mem.mMemSize / sizeof(FloatA));
+            UseALocalBuffer ? reinterpret_cast<FloatA*>(a_block_mem.mpDeviceBuf)
+                            : const_cast<FloatA*>(p_a_grid),
+            UseALocalBuffer ? a_block_mem.mMemSize / sizeof(FloatA)
+                            : a_grid_desc.GetElementSpaceSize());
 
         auto b_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
-            reinterpret_cast<FloatB*>(b_block_mem.mpDeviceBuf),
-            b_block_mem.mMemSize / sizeof(FloatB));
+            UseBLocalBuffer ? reinterpret_cast<FloatB*>(b_block_mem.mpDeviceBuf)
+                            : const_cast<FloatB*>(p_b_grid),
+            UseBLocalBuffer ? b_block_mem.mMemSize / sizeof(FloatB)
+                            : b_grid_desc.GetElementSpaceSize());
 
         auto c_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
             UseCLocalBuffer ? reinterpret_cast<FloatC*>(c_block_mem.mpDeviceBuf)
@@ -503,7 +539,7 @@ struct GridwiseGemmAvx2_MxN
         {
             ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block);
-            auto a_block_desc = GetABlockDescriptor(mc_size, kc_size);
+            auto a_block_desc = GetABlockDescriptor(mc_size, kc_size, a_grid_desc);
 
             a_threadwise_copy.RunRead(a_grid_desc,
                                       a_grid_buf,
                                       a_block_desc,
@@ -519,7 +555,7 @@ struct GridwiseGemmAvx2_MxN
                 ck::math::min(GemmN - i_nc, n_per_block); // TODO: nc need be 8x
             nc_size = math::integer_least_multiple(
                 nc_size, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
-            auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size);
+            auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size, b_grid_desc);
 
             b_threadwise_copy.RunRead(b_grid_desc,
                                       b_grid_buf,
@@ -543,12 +579,18 @@ struct GridwiseGemmAvx2_MxN
             blockwise_gemm.Run(a_block_desc,
                                a_block_buf,
                                make_zero_multi_index<a_block_copy_dim>(),
+                               GetASliceLength(mc_size, kc_size),
                                b_block_desc,
                                b_block_buf,
                                make_zero_multi_index<b_block_copy_dim>(),
+                               GetBSliceLength(kc_size, nc_size),
                                c_block_desc,
                                c_block_buf,
                                make_zero_multi_index<2>(),
+                               GetCSliceLength(mc_size, nc_size),
                                i_kc != 0);
 
             if((i_nc + n_per_block) < GemmN)
...
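The same flags drive the buffer setup in the hunks above: per-block scratch is only allocated when packing is requested, and otherwise the "block buffer" simply aliases the global tensor, which is what the commit title means by "not using LocalA/LocalB" in the multi-thread case. A simplified, self-contained sketch of that choice follows; plain std::vector stands in for DeviceAlignedMemCPU and the shapes are made up.

#include <cstddef>
#include <vector>

// Toy sketch (not CK code) of the conditional local-buffer pattern: allocate
// scratch only when packing, otherwise point the working buffer at the global
// tensor so multiple threads can share read-only A/B without per-thread copies.
int main()
{
    constexpr bool use_a_local_buffer = false; // mirrors the UseALocalBuffer flag
    const std::size_t m_per_block = 256, k_per_block = 64;

    std::vector<float> a_grid(1024 * 512, 1.0f); // the "global" A tensor

    // Scratch sized to the block when packing, size 0 otherwise.
    std::vector<float> a_block_mem(use_a_local_buffer ? m_per_block * k_per_block : 0);

    // Working pointer: either the scratch or the grid itself.
    float* a_block_ptr = use_a_local_buffer ? a_block_mem.data() : a_grid.data();

    (void)a_block_ptr;
    return 0;
}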
include/ck/tensor_operation/cpu/grid/gridwise_gemm_bias_activation_add_avx2.hpp
View file @
71254ddd
#ifndef CK_GRIDWISE_GEMM_BIAS_ACTIVATION_ADD_AVX2_HPP
#ifndef CK_GRIDWISE_GEMM_BIAS_ACTIVATION_ADD_AVX2_HPP
#define CK_GRIDWISE_GEMM_BIAS_ACTIVATION_ADD_AVX2_HPP
#define CK_GRIDWISE_GEMM_BIAS_ACTIVATION_ADD_AVX2_HPP
#include "common_header.hpp"
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_avx2.hpp"
#include "blockwise_gemm_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2.hpp"
#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"
#include "threadwise_tensor_slice_transfer_avx2_specialization.hpp"
#include "dynamic_buffer_cpu.hpp"
#include "dynamic_buffer_cpu.hpp"
#include <utility>
#include <utility>
#include <unistd.h>
#include <unistd.h>
#include <omp.h>
#include <omp.h>
#include <pthread.h>
#include <pthread.h>
namespace
ck
{
namespace
ck
{
namespace
cpu
{
namespace
cpu
{
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
typename
FloatA
,
typename
FloatA
,
typename
FloatB
,
typename
FloatB
,
typename
FloatC
,
typename
FloatC
,
typename
FloatC0
,
typename
FloatC0
,
typename
FloatC1
,
typename
FloatC1
,
typename
AGridDesc
,
typename
AGridDesc
,
typename
BGridDesc
,
typename
BGridDesc
,
typename
CGridDesc
,
typename
CGridDesc
,
typename
C0GridDesc
,
typename
C0GridDesc
,
typename
C1GridDesc
,
typename
C1GridDesc
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
typename
CElementwiseOperation
>
void
kernel_gemm_bias_activation_add_avx_mxn
(
const
FloatA
*
__restrict__
p_a_grid
,
void
kernel_gemm_bias_activation_add_avx_mxn
(
const
FloatA
*
__restrict__
p_a_grid
,
const
FloatB
*
__restrict__
p_b_grid
,
const
FloatB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
FloatC0
*
__restrict__
p_c0_grid
,
const
FloatC0
*
__restrict__
p_c0_grid
,
const
FloatC1
*
__restrict__
p_c1_grid
,
const
FloatC1
*
__restrict__
p_c1_grid
,
const
AGridDesc
&
a_grid_desc
,
const
AGridDesc
&
a_grid_desc
,
const
BGridDesc
&
b_grid_desc
,
const
BGridDesc
&
b_grid_desc
,
const
CGridDesc
&
c_grid_desc
,
const
CGridDesc
&
c_grid_desc
,
const
C0GridDesc
&
c0_grid_desc
,
const
C0GridDesc
&
c0_grid_desc
,
const
C1GridDesc
&
c1_grid_desc
,
const
C1GridDesc
&
c1_grid_desc
,
const
AElementwiseOperation
&
a_element_op
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
)
const
CElementwiseOperation
&
c_element_op
)
{
{
GridwiseGemm
::
Run
(
p_a_grid
,
GridwiseGemm
::
Run
(
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
p_c0_grid
,
p_c0_grid
,
p_c1_grid
,
p_c1_grid
,
a_grid_desc
,
a_grid_desc
,
b_grid_desc
,
b_grid_desc
,
c_grid_desc
,
c_grid_desc
,
c0_grid_desc
,
c0_grid_desc
,
c1_grid_desc
,
c1_grid_desc
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
c_element_op
);
c_element_op
);
}
}
template
<
typename
FloatA
,
template
<
typename
FloatA
,
typename
FloatB
,
typename
FloatB
,
typename
FloatC
,
typename
FloatC
,
typename
FloatC0
,
typename
FloatC0
,
typename
FloatC1
,
typename
FloatC1
,
typename
AGridDesc
,
typename
AGridDesc
,
typename
BGridDesc
,
typename
BGridDesc
,
typename
CGridDesc
,
typename
CGridDesc
,
typename
C0GridDesc
,
typename
C0GridDesc
,
typename
C1GridDesc
,
typename
C1GridDesc
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
ck
::
index_t
MPerBlock
,
// block means data are designed to fit in cache (L1/L2/L3)
ck
::
index_t
MPerBlock
,
// block means data are designed to fit in cache (L1/L2/L3)
ck
::
index_t
NPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
KPerBlock
,
typename
ThreadwiseGemm_Dispatch
,
typename
ThreadwiseGemm_Dispatch
,
typename
AThreadwiseCopy
,
typename
AThreadwiseCopy
,
typename
BThreadwiseCopy
,
typename
BThreadwiseCopy
,
typename
CThreadwiseCopy
,
typename
CThreadwiseCopy
,
typename
BlockMNKAccessOrder
,
// how we accss gemm MNK to better fit in cache
typename
BlockMNKAccessOrder
,
// how we accss gemm MNK to better fit in cache
typename
ThreadMNAccessOrder
,
// how we acces gemm MN to utilize micro kernel
typename
ThreadMNAccessOrder
,
// how we acces gemm MN to utilize micro kernel
bool
UseALocalBuffer
,
bool
UseALocalBuffer
,
bool
UseBLocalBuffer
,
bool
UseBLocalBuffer
,
bool
UseCLocalBuffer
// if true, will allocate a buffer and write to it in kernel, then
bool
UseCLocalBuffer
// if true, will allocate a buffer and write to it in kernel, then
// copy back to block buffer (need CThreadwiseCopy).
// copy back to block buffer (need CThreadwiseCopy).
// if false, will write to C directly (no need CThreadwiseCopy)
// if false, will write to C directly (no need CThreadwiseCopy)
>
>
struct
GridwiseGemmBiasActivationAddAvx2_MxN
struct
GridwiseGemmBiasActivationAddAvx2_MxN
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
// static constexpr auto Avx2RegisterVector = 8; // 8 floats
// static constexpr auto Avx2RegisterVector = 8; // 8 floats
static
constexpr
index_t
MemAlignmentByte
=
32
;
// 256bit
static
constexpr
index_t
MemAlignmentByte
=
32
;
// 256bit
static
auto
GetABlockDescriptor
(
const
ck
::
index_t
m_per_blk
,
const
ck
::
index_t
k_per_blk
)
static
auto
GetABlockDescriptor
(
const
ck
::
index_t
m_per_blk
,
{
const
ck
::
index_t
k_per_blk
,
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixALayout
,
const
AGridDesc
&
a_grid_desc
)
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
{
if
constexpr
(
UseALocalBuffer
)
// A : M, K
{
auto
a_block_desc_m_k
=
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixALayout
,
make_naive_tensor_descriptor_packed
(
make_tuple
(
m_per_blk
,
k_per_blk
));
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
return
a_block_desc_m_k
;
{
}
// A : M, K
else
auto
a_block_desc_m_k
=
{
make_naive_tensor_descriptor_packed
(
make_tuple
(
m_per_blk
,
k_per_blk
));
// A : K, M
return
a_block_desc_m_k
;
auto
a_block_desc_k_m
=
make_naive_tensor_descriptor_packed
(
}
make_tuple
(
k_per_blk
,
else
math
::
integer_least_multiple
(
{
m_per_blk
,
ThreadwiseGemm_Dispatch
::
MatrixAMinVectorSize
)));
// A : K, M
return
a_block_desc_k_m
;
auto
a_block_desc_k_m
=
make_naive_tensor_descriptor_packed
(
}
make_tuple
(
k_per_blk
,
}
math
::
integer_least_multiple
(
m_per_blk
,
ThreadwiseGemm_Dispatch
::
MatrixAMinVectorSize
)));
static
auto
GetBBlockDescriptor
(
const
ck
::
index_t
k_per_blk
,
const
ck
::
index_t
n_per_blk
)
return
a_block_desc_k_m
;
{
}
// n_per_blk should be 8x
}
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixBLayout
,
else
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
{
return
a_grid_desc
;
// B : K, N
}
auto
b_block_desc_k_n
=
}
make_naive_tensor_descriptor_packed
(
make_tuple
(
k_per_blk
,
n_per_blk
));
return
b_block_desc_k_n
;
static
auto
GetBBlockDescriptor
(
const
ck
::
index_t
k_per_blk
,
}
const
ck
::
index_t
n_per_blk
,
else
const
BGridDesc
&
b_grid_desc
)
{
{
// B : N/8, K, N8
if
constexpr
(
UseBLocalBuffer
)
auto
b_block_desc_n0_k_n1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
{
math
::
integer_divide_ceil
(
n_per_blk
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
// n_per_blk should be 8x
k_per_blk
,
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixBLayout
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
));
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
return
b_block_desc_n0_k_n1
;
{
}
// B : K, N
}
auto
b_block_desc_k_n
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
k_per_blk
,
n_per_blk
));
static
auto
GetCBlockDescriptor
(
const
ck
::
index_t
m_per_blk
,
return
b_block_desc_k_n
;
const
ck
::
index_t
n_per_blk
,
}
const
CGridDesc
&
c_grid_desc
)
else
{
{
if
constexpr
(
UseCLocalBuffer
)
// B : N/8, K, N8
{
auto
b_block_desc_n0_k_n1
=
make_naive_tensor_descriptor_packed
(
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
m_per_blk
,
n_per_blk
));
make_tuple
(
math
::
integer_divide_ceil
(
}
n_per_blk
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
else
k_per_blk
,
return
c_grid_desc
;
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
));
}
return
b_block_desc_n0_k_n1
;
}
static
auto
GetASliceLength
(
const
ck
::
index_t
m_per_blk
,
const
ck
::
index_t
k_per_blk
)
}
{
else
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixALayout
,
{
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
return
b_grid_desc
;
{
}
// A : M, K
}
return
ck
::
make_multi_index
(
m_per_blk
,
k_per_blk
);
}
static
auto
GetCBlockDescriptor
(
const
ck
::
index_t
m_per_blk
,
else
const
ck
::
index_t
n_per_blk
,
{
const
CGridDesc
&
c_grid_desc
)
// A : K, M
{
return
ck
::
make_multi_index
(
if
constexpr
(
UseCLocalBuffer
)
k_per_blk
,
{
math
::
integer_least_multiple
(
m_per_blk
,
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
m_per_blk
,
n_per_blk
));
ThreadwiseGemm_Dispatch
::
MatrixAMinVectorSize
));
}
}
else
}
return
c_grid_desc
;
}
    static auto GetASliceLength(const ck::index_t m_per_blk, const ck::index_t k_per_blk)
    {
        if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
                                  ck::tensor_layout::gemm::RowMajor>::value)
        {
            // A : M, K
            return ck::make_multi_index(m_per_blk, k_per_blk);
        }
        else
        {
            // A : K, M
            return ck::make_multi_index(
                k_per_blk,
                math::integer_least_multiple(m_per_blk,
                                             ThreadwiseGemm_Dispatch::MatrixAMinVectorSize));
        }
    }

    static auto GetBSliceLength(const ck::index_t k_per_blk, const ck::index_t n_per_blk)
    {
        // n_per_blk should be 8x
        if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
                                  ck::tensor_layout::gemm::RowMajor>::value)
        {
            // B : K, N
            return ck::make_multi_index(
                k_per_blk,
                math::integer_least_multiple(n_per_blk,
                                             ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
        }
        else
        {
            // B : N/8, K, N8
            return ck::make_multi_index(
                math::integer_divide_ceil(n_per_blk,
                                          ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
                k_per_blk,
                ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
        }
    }

    static auto GetCSliceLength(const ck::index_t m_per_blk, const ck::index_t n_per_blk)
    {
        return ck::make_multi_index(m_per_blk, n_per_blk);
    }

    static auto GetAIndex(const ck::index_t i_m, const ck::index_t i_k)
    {
        if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
                                  ck::tensor_layout::gemm::RowMajor>::value)
        {
            // A : M, K
            return ck::make_multi_index(i_m, i_k);
        }
        else
        {
            // A : K, M
            return ck::make_multi_index(i_k, i_m);
        }
    }

    static auto GetBIndex(const ck::index_t i_k, const ck::index_t i_n)
    {
        // i_n should be 8x
        if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
                                  ck::tensor_layout::gemm::RowMajor>::value)
        {
            // B : K, N
            return ck::make_multi_index(i_k, i_n);
        }
        else
        {
            // B : N/8, K, N8
            return ck::make_multi_index(i_n / ThreadwiseGemm_Dispatch::MatrixBMinVectorSize,
                                        i_k,
                                        i_n % ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
        }
    }

    static auto GetCIndex(const ck::index_t i_m, const ck::index_t i_n)
    {
        return ck::make_multi_index(i_m, i_n);
    }
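The slice-length helpers above round the N extent of every B/C tile up to MatrixBMinVectorSize (8 floats, i.e. one ymm register), so the AVX2 micro-kernel never has to deal with a partial vector. A minimal standalone sketch of that rounding, using illustrative stand-ins (div_ceil, pad_to_vector) rather than the ck::math helpers:

#include <cstdint>
#include <cstdio>

// Illustrative stand-ins for ck::math::integer_divide_ceil / integer_least_multiple.
static int64_t div_ceil(int64_t x, int64_t y) { return (x + y - 1) / y; }
static int64_t pad_to_vector(int64_t x, int64_t vec) { return div_ceil(x, vec) * vec; }

int main()
{
    const int64_t ymm_floats = 8;  // MatrixBMinVectorSize for f32 on AVX2
    const int64_t n_tile     = 20; // leftover N of a tile
    std::printf("n tile %lld -> padded %lld (%lld ymm columns)\n",
                (long long)n_tile,
                (long long)pad_to_vector(n_tile, ymm_floats),
                (long long)div_ceil(n_tile, ymm_floats));
    return 0;
}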
    static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
                                        const BGridDesc& b_grid_desc,
                                        const CGridDesc& c_grid_desc)
    {
        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
        bool is_valid = true;

        const auto GemmN = c_grid_desc.GetLength(I1);

        if constexpr(UseCLocalBuffer)
        {
            if(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value &&
               NPerBlock < GemmN)
                is_valid &= false;
        }
        else
        {
            // TODO: need check c grid is simple transform?
            if(GemmN % 8 != 0)
                is_valid &= false;
        }

        return is_valid;
    }
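CheckValidity encodes when the C tile may skip the local buffer: the store path works on whole 8-float ymm vectors, so writing straight to the global output needs GemmN to be a multiple of 8, while a local C buffer combined with the M-K-N loop order needs the whole N range to fit into one block (the local tile is only flushed once per block). A hedged host-side mirror of the same selection, with an illustrative helper name that is not part of CK:

#include <cstdint>

// Illustrative mirror of the descriptor-level checks above (not a CK API).
// loop_over_mkn corresponds to BlockMNKAccessOrder == ck::Sequence<0, 2, 1>.
bool is_instance_valid(bool use_c_local_buffer,
                       bool loop_over_mkn,
                       std::int64_t n_per_block,
                       std::int64_t gemm_n)
{
    if(use_c_local_buffer)
        // the local C tile is written out once per (m, n) block, so the M-K-N order
        // must cover the whole N range within a single block
        return !(loop_over_mkn && n_per_block < gemm_n);
    else
        // writing straight into the global C relies on full 8-float ymm stores
        return gemm_n % 8 == 0;
}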
    static void Run(const FloatA* __restrict__ p_a_grid,
                    const FloatB* __restrict__ p_b_grid,
                    FloatC* __restrict__ p_c_grid,
                    const FloatC0* __restrict__ p_c0_grid,
                    const FloatC1* __restrict__ p_c1_grid,
                    const AGridDesc& a_grid_desc,
                    const BGridDesc& b_grid_desc,
                    const CGridDesc& c_grid_desc,
                    const C0GridDesc& c0_grid_desc,
                    const C1GridDesc& c1_grid_desc,
                    const AElementwiseOperation& a_element_op,
                    const BElementwiseOperation& b_element_op,
                    const CElementwiseOperation& c_element_op)
    {
        ck::index_t m_per_block = MPerBlock;
        ck::index_t n_per_block = NPerBlock;
        ck::index_t k_per_block = KPerBlock;

        const auto GemmM = c_grid_desc.GetLength(I0);
        const auto GemmN = c_grid_desc.GetLength(I1);
        const auto GemmK = a_grid_desc.GetLength(I1);

        constexpr auto a_block_copy_dim = AGridDesc::GetNumOfDimension();
        constexpr auto b_block_copy_dim = BGridDesc::GetNumOfDimension();

        auto a_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
            const_cast<FloatA*>(p_a_grid), a_grid_desc.GetElementSpaceSize());

        auto b_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
            const_cast<FloatB*>(p_b_grid), b_grid_desc.GetElementSpaceSize());

        auto c_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
            reinterpret_cast<FloatC*>(p_c_grid), c_grid_desc.GetElementSpaceSize());

        auto c0_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
            reinterpret_cast<const FloatC0*>(p_c0_grid), c0_grid_desc.GetElementSpaceSize());

        auto c1_grid_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
            reinterpret_cast<const FloatC1*>(p_c1_grid), c1_grid_desc.GetElementSpaceSize());

        auto blockwise_gemm = BlockwiseGemmAvx2_MxN<
            FloatA,                                                               // FloatA,
            FloatB,                                                               // FloatB,
            FloatC,                                                               // FloatC,
            decltype(GetABlockDescriptor(m_per_block, k_per_block, a_grid_desc)), // ABlockDesc,
            decltype(GetBBlockDescriptor(k_per_block, n_per_block, b_grid_desc)), // BBlockDesc,
            decltype(GetCBlockDescriptor(m_per_block, n_per_block, c_grid_desc)), // CBlockDesc,
            KPerBlock,                                                            // KPerBlock,
            ThreadwiseGemm_Dispatch,                                              // ThreadwiseGemm_Dispatch,
            ThreadMNAccessOrder>{}; // ThreadMNAccessOrder // how we access
                                    // gemm MN to utilize micro kernel>{};

        int total_threads = omp_get_max_threads();

#if 0
        if(total_threads > 1){
#pragma omp parallel
            {
                int tid = omp_get_thread_num();
                cpu_set_t set;
                CPU_ZERO(&set);
                CPU_SET(tid, &set);
                if (sched_setaffinity(0, sizeof(set), &set) == -1) {
                    throw std::runtime_error("wrong! fail to set thread affinity");
                }
            }
        }
#endif

        // TODO: openmp aware ordering
        //
        if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 1, 2>>::value)
        {
            auto a_move_k_step = GetAIndex(0, k_per_block);
            auto b_move_k_step = GetBIndex(k_per_block, 0);

            const ck::index_t grid_m    = math::integer_divide_ceil(GemmM, m_per_block);
            const ck::index_t grid_n    = math::integer_divide_ceil(GemmN, n_per_block);
            const ck::index_t grid_size = grid_m * grid_n;
            const ck::index_t grids_per_thread =
                math::integer_divide_ceil(grid_size, total_threads);

            // This version does not consider K panel re-usage. simple for openmp
#pragma omp parallel
            {
                auto a_threadwise_copy = AThreadwiseCopy(
                    a_grid_desc,
                    ck::make_zero_multi_index<a_block_copy_dim>(),
                    GetABlockDescriptor(m_per_block, k_per_block, a_grid_desc),
                    ck::make_zero_multi_index<a_block_copy_dim>(),
                    AElementwiseOperation{});

                auto b_threadwise_copy = BThreadwiseCopy(
                    b_grid_desc,
                    ck::make_zero_multi_index<b_block_copy_dim>(),
                    GetBBlockDescriptor(k_per_block, n_per_block, b_grid_desc),
                    ck::make_zero_multi_index<b_block_copy_dim>(),
                    BElementwiseOperation{});

                auto c_threadwise_copy = CThreadwiseCopy(
                    GetCBlockDescriptor(m_per_block, n_per_block, c_grid_desc),
                    ck::make_zero_multi_index<2>(),
                    c_grid_desc,
                    ck::make_zero_multi_index<2>(),
                    CElementwiseOperation{});

                DeviceAlignedMemCPU a_block_mem(
                    UseALocalBuffer ? m_per_block * k_per_block * sizeof(FloatA) : 0,
                    MemAlignmentByte);
                DeviceAlignedMemCPU b_block_mem(
                    UseBLocalBuffer ? k_per_block * n_per_block * sizeof(FloatB) : 0,
                    MemAlignmentByte);
                DeviceAlignedMemCPU c_block_mem(
                    UseCLocalBuffer ? (m_per_block * n_per_block * sizeof(FloatC)) : 0,
                    MemAlignmentByte);

                auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
                    UseALocalBuffer ? reinterpret_cast<FloatA*>(a_block_mem.mpDeviceBuf)
                                    : const_cast<FloatA*>(p_a_grid),
                    UseALocalBuffer ? a_block_mem.mMemSize / sizeof(FloatA)
                                    : a_grid_desc.GetElementSpaceSize());

                auto b_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
                    UseBLocalBuffer ? reinterpret_cast<FloatB*>(b_block_mem.mpDeviceBuf)
                                    : const_cast<FloatB*>(p_b_grid),
                    UseBLocalBuffer ? b_block_mem.mMemSize / sizeof(FloatB)
                                    : b_grid_desc.GetElementSpaceSize());

                auto c_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
                    UseCLocalBuffer ? reinterpret_cast<FloatC*>(c_block_mem.mpDeviceBuf)
                                    : reinterpret_cast<FloatC*>(p_c_grid),
                    UseCLocalBuffer ? c_block_mem.mMemSize / sizeof(FloatC)
                                    : c_grid_desc.GetElementSpaceSize());

                const ck::index_t tid = omp_get_thread_num();

                for(ck::index_t i_gpt = 0; i_gpt < grids_per_thread; i_gpt++)
                {
                    ck::index_t gid = i_gpt * total_threads + tid;
                    if(gid >= grid_size)
                        break;

                    ck::index_t i_mc = (gid / grid_n) * m_per_block;
                    ck::index_t i_nc = (gid % grid_n) * n_per_block;

                    ck::index_t mc_size = ck::math::min(GemmM - i_mc, m_per_block);
                    ck::index_t nc_size =
                        ck::math::min(GemmN - i_nc, n_per_block); // TODO: nc need be 8x
                    nc_size = math::integer_least_multiple(
                        nc_size, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);

                    a_threadwise_copy.SetSrcSliceOrigin(a_grid_desc, GetAIndex(i_mc, 0));
                    b_threadwise_copy.SetSrcSliceOrigin(b_grid_desc, GetBIndex(0, i_nc));

                    auto c_block_desc = GetCBlockDescriptor(mc_size, nc_size, c_grid_desc);

                    c_threadwise_copy.SetSrc1SliceOrigin(c_block_desc, GetCIndex(i_mc, i_nc));
                    c_threadwise_copy.SetSrc2SliceOrigin(c_block_desc, GetCIndex(i_mc, i_nc));

                    if constexpr(!UseCLocalBuffer)
                    {
                        c_threadwise_copy.SetSrcSliceOrigin(c_block_desc, GetCIndex(i_mc, i_nc));
                        c_threadwise_copy.RunRead(c_grid_desc, c_grid_buf,
                                                  c0_grid_desc, c0_grid_buf,
                                                  c1_grid_desc, c1_grid_buf,
                                                  c_block_desc, c_block_buf,
                                                  GetCSliceLength(mc_size, nc_size));
                    }

                    for(ck::index_t i_kc = 0; i_kc < GemmK; i_kc += k_per_block)
                    {
                        ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block);

                        auto a_block_desc = GetABlockDescriptor(mc_size, kc_size, a_grid_desc);
                        auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size, b_grid_desc);

                        a_threadwise_copy.RunRead(a_grid_desc, a_grid_buf,
                                                  a_block_desc, a_block_buf,
                                                  GetASliceLength(mc_size, kc_size));

                        b_threadwise_copy.RunRead(b_grid_desc, b_grid_buf,
                                                  b_block_desc, b_block_buf,
                                                  GetBSliceLength(kc_size, nc_size));

                        blockwise_gemm.Run(a_block_desc, a_block_buf,
                                           make_zero_multi_index<a_block_copy_dim>(),
                                           GetASliceLength(mc_size, kc_size),
                                           b_block_desc, b_block_buf,
                                           make_zero_multi_index<b_block_copy_dim>(),
                                           GetBSliceLength(kc_size, nc_size),
                                           c_block_desc, c_block_buf,
                                           make_zero_multi_index<2>(),
                                           GetCSliceLength(mc_size, nc_size),
                                           i_kc != 0);

                        if((i_kc + k_per_block) < GemmK)
                        {
                            a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc, a_move_k_step);
                            b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc, b_move_k_step);
                        }
                    }

                    c_threadwise_copy.SetDstSliceOrigin(c_grid_desc, GetCIndex(i_mc, i_nc));
                    c_threadwise_copy.RunWrite(c_block_desc, c_block_buf,
                                               c0_grid_desc, c0_grid_buf,
                                               c1_grid_desc, c1_grid_buf,
                                               c_grid_desc, c_grid_buf,
                                               GetCSliceLength(mc_size, nc_size));
                }
            }
        }
        else if constexpr(std::is_same<BlockMNKAccessOrder, ck::Sequence<0, 2, 1>>::value)
        {
            auto a_move_k_step = GetAIndex(0, k_per_block);
            auto b_move_k_step = GetBIndex(0, n_per_block);

            const ck::index_t grid_m = math::integer_divide_ceil(GemmM, m_per_block);
            const ck::index_t grid_m_per_thread =
                math::integer_divide_ceil(grid_m, total_threads);

            // only parallel in gemm m dim
#pragma omp parallel
            {
                auto a_threadwise_copy = AThreadwiseCopy(
                    a_grid_desc,
                    ck::make_zero_multi_index<a_block_copy_dim>(),
                    GetABlockDescriptor(m_per_block, k_per_block, a_grid_desc),
                    ck::make_zero_multi_index<a_block_copy_dim>(),
                    AElementwiseOperation{});

                auto b_threadwise_copy = BThreadwiseCopy(
                    b_grid_desc,
                    ck::make_zero_multi_index<b_block_copy_dim>(),
                    GetBBlockDescriptor(k_per_block, n_per_block, b_grid_desc),
                    ck::make_zero_multi_index<b_block_copy_dim>(),
                    BElementwiseOperation{});

                auto c_threadwise_copy = CThreadwiseCopy(
                    GetCBlockDescriptor(m_per_block, n_per_block, c_grid_desc),
                    ck::make_zero_multi_index<2>(),
                    c_grid_desc,
                    ck::make_zero_multi_index<2>(),
                    CElementwiseOperation{});

                DeviceAlignedMemCPU a_block_mem(
                    UseALocalBuffer ? m_per_block * k_per_block * sizeof(FloatA) : 0,
                    MemAlignmentByte);
                DeviceAlignedMemCPU b_block_mem(
                    UseBLocalBuffer ? k_per_block * n_per_block * sizeof(FloatB) : 0,
                    MemAlignmentByte);
                DeviceAlignedMemCPU c_block_mem(
                    UseCLocalBuffer ? (m_per_block * n_per_block * sizeof(FloatC)) : 0,
                    MemAlignmentByte);

                auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
                    UseALocalBuffer ? reinterpret_cast<FloatA*>(a_block_mem.mpDeviceBuf)
                                    : const_cast<FloatA*>(p_a_grid),
                    UseALocalBuffer ? a_block_mem.mMemSize / sizeof(FloatA)
                                    : a_grid_desc.GetElementSpaceSize());

                auto b_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
                    UseBLocalBuffer ? reinterpret_cast<FloatB*>(b_block_mem.mpDeviceBuf)
                                    : const_cast<FloatB*>(p_b_grid),
                    UseBLocalBuffer ? b_block_mem.mMemSize / sizeof(FloatB)
                                    : b_grid_desc.GetElementSpaceSize());

                auto c_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
                    UseCLocalBuffer ? reinterpret_cast<FloatC*>(c_block_mem.mpDeviceBuf)
                                    : reinterpret_cast<FloatC*>(p_c_grid),
                    UseCLocalBuffer ? c_block_mem.mMemSize / sizeof(FloatC)
                                    : c_grid_desc.GetElementSpaceSize());

                const ck::index_t tid = omp_get_thread_num();

                for(ck::index_t i_gmpt = 0; i_gmpt < grid_m_per_thread; i_gmpt++)
                {
                    ck::index_t i_mc = (i_gmpt * total_threads + tid) * m_per_block;
                    if(i_mc >= GemmM)
                        break;

                    ck::index_t mc_size = ck::math::min(GemmM - i_mc, m_per_block);

                    a_threadwise_copy.SetSrcSliceOrigin(a_grid_desc, GetAIndex(i_mc, 0));

                    for(ck::index_t i_kc = 0; i_kc < GemmK; i_kc += k_per_block)
                    {
                        ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block);

                        auto a_block_desc = GetABlockDescriptor(mc_size, kc_size, a_grid_desc);

                        a_threadwise_copy.RunRead(a_grid_desc, a_grid_buf,
                                                  a_block_desc, a_block_buf,
                                                  GetASliceLength(mc_size, kc_size));

                        b_threadwise_copy.SetSrcSliceOrigin(b_grid_desc, GetBIndex(i_kc, 0));

                        // TODO: if use local C buffer, then this nc loop need to loop only once
                        for(ck::index_t i_nc = 0; i_nc < GemmN; i_nc += n_per_block)
                        {
                            ck::index_t nc_size =
                                ck::math::min(GemmN - i_nc, n_per_block); // TODO: nc need be 8x
                            nc_size = math::integer_least_multiple(
                                nc_size, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);

                            auto b_block_desc =
                                GetBBlockDescriptor(kc_size, nc_size, b_grid_desc);

                            b_threadwise_copy.RunRead(b_grid_desc, b_grid_buf,
                                                      b_block_desc, b_block_buf,
                                                      GetBSliceLength(kc_size, nc_size));

                            auto c_block_desc =
                                GetCBlockDescriptor(mc_size, nc_size, c_grid_desc);

                            c_threadwise_copy.SetSrc1SliceOrigin(c_block_desc,
                                                                 GetCIndex(i_mc, i_nc));
                            c_threadwise_copy.SetSrc2SliceOrigin(c_block_desc,
                                                                 GetCIndex(i_mc, i_nc));

                            if constexpr(!UseCLocalBuffer)
                            {
                                c_threadwise_copy.SetSrcSliceOrigin(c_block_desc,
                                                                    GetCIndex(i_mc, i_nc));
                                c_threadwise_copy.RunRead(c_grid_desc, c_grid_buf,
                                                          c0_grid_desc, c0_grid_buf,
                                                          c1_grid_desc, c1_grid_buf,
                                                          c_block_desc, c_block_buf,
                                                          GetCSliceLength(mc_size, nc_size));
                            }

                            blockwise_gemm.Run(a_block_desc, a_block_buf,
                                               make_zero_multi_index<a_block_copy_dim>(),
                                               GetASliceLength(mc_size, kc_size),
                                               b_block_desc, b_block_buf,
                                               make_zero_multi_index<b_block_copy_dim>(),
                                               GetBSliceLength(kc_size, nc_size),
                                               c_block_desc, c_block_buf,
                                               make_zero_multi_index<2>(),
                                               GetCSliceLength(mc_size, nc_size),
                                               i_kc != 0);

                            if((i_nc + n_per_block) < GemmN)
                            {
                                b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc, b_move_k_step);
                            }

                            if constexpr(UseCLocalBuffer)
                            {
                                c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
                                                                    GetCIndex(i_mc, i_nc));
                                c_threadwise_copy.RunWrite(c_block_desc, c_block_buf,
                                                           c0_grid_desc, c0_grid_buf,
                                                           c1_grid_desc, c1_grid_buf,
                                                           c_grid_desc, c_grid_buf,
                                                           GetCSliceLength(mc_size, nc_size));
                            }
                            else
                            {
                                // only write for last K, since the RunWrite here is just doing
                                // elementwise op from global to global
                                if((i_kc + k_per_block) >= GemmK)
                                {
                                    c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
                                                                        GetCIndex(i_mc, i_nc));
                                    c_threadwise_copy.RunWrite(c_block_desc, c_block_buf,
                                                               c0_grid_desc, c0_grid_buf,
                                                               c1_grid_desc, c1_grid_buf,
                                                               c_grid_desc, c_grid_buf,
                                                               GetCSliceLength(mc_size, nc_size));
                                }
                            }
                        }

                        if((i_kc + k_per_block) < GemmK)
                            a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc, a_move_k_step);
                    }
                }
            }
        }
    }
};

} // namespace cpu
} // namespace ck
#endif
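The Sequence<0,1,2> path above numbers the (M, N) output tiles row-major and deals them out to OpenMP threads round-robin, each thread running its own K loop over its tiles. A standalone sketch of just that partitioning (sizes below are illustrative, not tuned values from the library; build with -fopenmp):

#include <cstdio>
#include <omp.h>

int main()
{
    const int gemm_m = 1000, gemm_n = 500;
    const int m_per_block = 256, n_per_block = 128;

    const int grid_m    = (gemm_m + m_per_block - 1) / m_per_block;
    const int grid_n    = (gemm_n + n_per_block - 1) / n_per_block;
    const int grid_size = grid_m * grid_n;

#pragma omp parallel
    {
        const int total_threads    = omp_get_num_threads();
        const int tid              = omp_get_thread_num();
        const int grids_per_thread = (grid_size + total_threads - 1) / total_threads;

        for(int i_gpt = 0; i_gpt < grids_per_thread; i_gpt++)
        {
            const int gid = i_gpt * total_threads + tid; // round-robin tile id
            if(gid >= grid_size)
                break;
            const int i_mc = (gid / grid_n) * m_per_block; // tile origin in M
            const int i_nc = (gid % grid_n) * n_per_block; // tile origin in N
#pragma omp critical
            std::printf("thread %d -> C tile at (m=%d, n=%d)\n", tid, i_mc, i_nc);
        }
    }
    return 0;
}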
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp
View file @ 71254ddd
...
@@ -519,7 +519,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
     template <typename SrcBuffer, typename DstBuffer, typename SliceLengths>
     void RunRead(const SrcDesc& src_desc,
-                 const SrcBuffer& src_buf,
+                 SrcBuffer& src_buf,
                  const DstDesc& dst_desc,
                  DstBuffer& dst_buf,
                  const SliceLengths& slice_length)
...
@@ -917,14 +917,15 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXC
     template <typename SrcBuffer, typename DstBuffer, typename SliceLengths>
     void RunRead(const SrcDesc&,
-                 const SrcBuffer& src_buf,
+                 SrcBuffer& src_buf,
                  const DstDesc& dst_desc,
                  DstBuffer& dst_buf,
                  const SliceLengths& slice_length)
     {
         if constexpr(BypassTransfer)
         {
-            // TODO: weight NHWC not support this
+            // KYXC weight should not support this
+            dst_buf.p_data_ = reinterpret_cast<float*>(src_buf.p_data_) + src_offset;
         }
         else
         {
...
@@ -1132,12 +1133,15 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8
     template <typename SrcBuffer, typename DstBuffer, typename SliceLengths>
     void RunRead(const SrcDesc&,
-                 const SrcBuffer& src_buf,
+                 SrcBuffer& src_buf,
                  const DstDesc& dst_desc,
                  DstBuffer& dst_buf,
                  const SliceLengths& slice_length)
     {
-        if constexpr(BypassTransfer) {}
+        if constexpr(BypassTransfer)
+        {
+            dst_buf.p_data_ = reinterpret_cast<float*>(src_buf.p_data_) + src_offset;
+        }
         else
         {
             const ck::index_t n0_per_block = slice_length[Number<0>{}];
...
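The BypassTransfer branches added above skip the packing copy entirely: when the source tile is already in the layout the micro-kernel expects, "reading" just repoints the destination view at the source plus the current offset. A minimal sketch of that idea with illustrative types (not the CK buffer classes):

#include <cstddef>

struct BufferView
{
    float*      p_data_ = nullptr;
    std::size_t size_   = 0;
};

template <bool BypassTransfer>
void run_read(const float* src, std::size_t src_offset, std::size_t elements, BufferView& dst)
{
    if constexpr(BypassTransfer)
    {
        // no copy: alias the "block buffer" straight onto the grid memory
        dst.p_data_ = const_cast<float*>(src) + src_offset;
        dst.size_   = elements;
    }
    else
    {
        // packing copy into a thread-local block buffer
        for(std::size_t i = 0; i < elements; ++i)
            dst.p_data_[i] = src[src_offset + i];
    }
}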
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp
View file @ 71254ddd
...
@@ -47,121 +47,138 @@ static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;

// clang-format off
#define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf) \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, true, c_local_buf>, \
    \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, true, c_local_buf>
// clang-format on
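Each invocation of the macro above now fans out into eight device instances that differ only in their compile-time flags (convolution specialization, K-loop style, block loop order, and the hard-coded UseALocalBuffer/UseBLocalBuffer pair), while the caller still picks only the tile shape and whether C uses a local buffer. A toy mirror of that fan-out pattern, with illustrative names rather than the CK templates:

#include <tuple>

template <int MPerBlock, int NPerBlock, bool UseALocal, bool UseBLocal, bool UseCLocal>
struct ToyConvInstance {};

// one invocation contributes two instantiations with different local-buffer flags
#define TOY_CONV_INSTANCES(m_per_block, n_per_block, c_local_buf)          \
    ToyConvInstance<m_per_block, n_per_block, true,  true, c_local_buf>,   \
    ToyConvInstance<m_per_block, n_per_block, false, true, c_local_buf>

using toy_instances = std::tuple<TOY_CONV_INSTANCES(256, 128, false),
                                 TOY_CONV_INSTANCES(512, 256, true)>;

static_assert(std::tuple_size_v<toy_instances> == 4, "each invocation adds two instances");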
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,  256, 128,  64, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,  256, 128, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,  128, 256, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,  512, 240, 128, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,  512, 256, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,  768, 320, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,  896, 352, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, false)
    // clang-format on
    >;

// use this in single thread, but gemm_n is not multiple of 8
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,  256, 128,  64, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,  256, 128, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,  128, 256, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,  512, 240, 128, 4, 24, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,  512, 256, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,  768, 320, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,  896, 352, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true)
    // clang-format on
    >;

// use this in multi thread environment (need local C buffer to avoid cache coherence, although some
// time no local c is better...)
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,   24,  24, 256, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,   32,  24, 256, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,   40,  24, 256, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,   48,  24, 256, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,   48,  48, 256, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,   56,  24, 256, 4, 24, false),
    // DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,   72,  16, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,   72,  16, 256, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,   72,  32, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,   72,  32, 256, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,   96,  32, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,   96,  64, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,  120,  32, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,  120,  64, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,  256, 128, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,  128, 256, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,  512, 240, 128, 4, 24, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,  512, 256, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,  768, 320, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT,  896, 352, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true)
    // clang-format on
    >;

using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_relu_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  256, 128,  64, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  256, 128, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  128, 256, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  512, 240, 128, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  512, 256, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  768, 320, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  896, 352, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, false)
    // clang-format on
    >;

// use this in single thread, but gemm_n is not multiple of 8
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_relu_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  256, 128,  64, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  256, 128, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  128, 256, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  512, 240, 128, 4, 24, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  512, 256, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  768, 320, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  896, 352, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)
    // clang-format on
    >;

// use this in multi thread environment (need local C buffer to avoid cache coherence, although some
// time no local c is better...)
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_relu_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,   24,  24, 256, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,   32,  24, 256, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,   40,  24, 256, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,   48,  24, 256, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,   48,  48, 256, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,   56,  24, 256, 4, 24, false),
    // DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,   72,  16, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,   72,  16, 256, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,   72,  32, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,   72,  32, 256, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,   96,  32, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,   96,  64, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  120,  32, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  120,  64, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  256, 128, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  128, 256, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  512, 240, 128, 4, 24, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  512, 256, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  768, 320, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  896, 352, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)
    // clang-format on
    >;

void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
...
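The "_mt_" instance lists above are meant for the multi-threaded case: each thread accumulates its C tile in a private, aligned buffer and only touches the shared output once per tile, so concurrent threads never store into the same cache lines while accumulating. A hedged standalone sketch of that pattern (tile sizes and the accumulate() body are illustrative only; build with -fopenmp):

#include <cstring>
#include <vector>
#include <omp.h>

void accumulate(float* tile, int elems) { for(int i = 0; i < elems; ++i) tile[i] += 1.0f; }

void run_tiles(float* c_global, int num_tiles, int tile_elems)
{
#pragma omp parallel
    {
        std::vector<float> c_local(tile_elems, 0.0f); // thread-private C tile
#pragma omp for
        for(int t = 0; t < num_tiles; ++t)
        {
            std::memset(c_local.data(), 0, tile_elems * sizeof(float));
            accumulate(c_local.data(), tile_elems);           // K loop would go here
            std::memcpy(c_global + t * tile_elems, c_local.data(),
                        tile_elems * sizeof(float));          // single flush per tile
        }
    }
}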
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_instance.cpp
View file @
71254ddd
...
@@ -40,121 +40,146 @@ static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver
...
@@ -40,121 +40,146 @@ static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver
static
constexpr
auto
LoopOver_MKN
=
ck
::
tensor_operation
::
cpu
::
device
::
LoopOver_MKN
;
static
constexpr
auto
LoopOver_MKN
=
ck
::
tensor_operation
::
cpu
::
device
::
LoopOver_MKN
;
// clang-format off
// clang-format off
#define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf) \
#define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf) \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, false, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
\
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, true, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf>, \
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true, false, c_local_buf>
// clang-format on

using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,  256, 128,  64, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,  256, 128, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,  128, 256, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,  512, 240, 128, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,  512, 256, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,  768, 320, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,  896, 352, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, false)
    >;
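Each DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(...) entry above is a macro invocation that expands into a whole family of device operations, one per specialization combination enumerated in the macro body (ConvFwdDefault / ConvFwd1x1S1P0, GemmKLoopOverC / DefaultGemmKLoop, LoopOver_MNK / LoopOver_MKN, and the hard-coded local-A/local-B buffer choices). As a rough sketch only, the first operation produced by the 256x128x64 entry is approximately:

// sketch only: first element of the expansion of
// DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, false)
DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<
    float, float, float,                          // input / weight / output data types
    PT, PT, PT,                                   // a / b / c element-wise operations
    ConvFwdDefault, GemmKLoopOverC, LoopOver_MNK, // conv / gemm-k / loop-order specializations
    2,                                            // 2 (presumably the number of spatial dimensions)
    256, 128, 64,                                 // m / n / k per block
    6, 16,                                        // m / n per thread
    true, true,                                   // use local A / local B buffers
    false>                                        // c_local_buf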
// clang-format on

// use this in single thread, when gemm_n is not a multiple of 8
using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,  256, 128,  64, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,  256, 128, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,  128, 256, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,  512, 240, 128, 4, 24, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,  512, 256, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,  768, 320, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,  896, 352, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true)
    >;
// clang-format on

// use this in multi-thread environment (a local C buffer is needed to avoid
// cache-coherence traffic, although sometimes no local C is faster...)
using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,   24,  24, 256, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,   32,  24, 256, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,   40,  24, 256, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,   48,  24, 256, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,   48,  48, 256, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,   56,  24, 256, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,   72,  16, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,   72,  16, 256, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,   72,  32, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,   72,  32, 256, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,   96,  32, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,   96,  64, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,  120,  32, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,  120,  64, 128, 6, 16, false),
    // DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,  256, 128, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,  128, 256, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,  512, 240, 128, 4, 24, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,  512, 256, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,  768, 320, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT,  896, 352, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true)
    >;
// clang-format on
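The two micro-kernel shapes that recur in these lists, 6x16 and 4x24 (m_per_thread x n_per_thread), both keep the C accumulator tile inside the AVX2 register file. Illustrative arithmetic only, assuming one 256-bit ymm register holds 8 floats (this reasoning is not stated in the commit itself):

// Illustrative only: accumulator budget of the two thread-tile shapes.
//   6 x 16 tile: 16 / 8 = 2 ymm per row -> 6 * 2 = 12 ymm accumulators
//   4 x 24 tile: 24 / 8 = 3 ymm per row -> 4 * 3 = 12 ymm accumulators
// Either way 12 of the 16 ymm registers hold C, leaving the rest for A broadcasts and B loads.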
using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_relu_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,  256, 128,  64, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,  256, 128, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,  128, 256, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,  512, 240, 128, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,  512, 256, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,  768, 320, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,  896, 352, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, false)
    >;
// clang-format on
// use this in single thread, when gemm_n is not a multiple of 8
using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_local_c_relu_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,  256, 128,  64, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,  256, 128, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,  128, 256, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,  512, 240, 128, 4, 24, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,  512, 256, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,  768, 320, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,  896, 352, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)
    >;
// clang-format on
// use this in multi-thread environment (a local C buffer is needed to avoid
// cache-coherence traffic, although sometimes no local C is faster...)
using device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk_f32_mt_relu_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,   24,  24, 256, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,   32,  24, 256, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,   40,  24, 256, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,   48,  24, 256, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,   48,  48, 256, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,   56,  24, 256, 4, 24, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,   72,  16, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,   72,  16, 256, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,   72,  32, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,   72,  32, 256, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,   96,  32, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,   96,  64, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,  120,  32, 128, 6, 16, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,  120,  64, 128, 6, 16, false),
    // DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,  256, 128, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,  128, 256, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,  512, 240, 128, 4, 24, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,  512, 256, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,  768, 320, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu,  896, 352, 128, 6, 16, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true)
    >;
// clang-format on
void add_device_conv2d_fwd_avx2_nhwc_kyxck8_nhwk(
...
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_bias_activation_add_avx2_nhwc_kyxc_nhwk_instance.cpp
#include <stdlib.h>
#include "convolution_forward_specialization_cpu.hpp"
#include "config.hpp"
#include "device_convnd_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation_cpu.hpp"
#include "device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
namespace device_conv2d_fwd_bias_activation_add_avx2_instance {

using InType  = float;
using WeiType = float;
using OutType = float;
using AccType = float;

using InLayout  = ck::tensor_layout::gemm::RowMajor;    // NHWC
using WeiLayout = ck::tensor_layout::gemm::ColumnMajor; // KYXC

static constexpr bool NonTemporalStore = false;

using PT         = ck::tensor_operation::cpu::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::cpu::element_wise::AddReluAdd;

static constexpr auto ConvFwdDefault =
    ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Default;
static constexpr auto ConvFwd1x1P0 =
    ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 =
    ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;

static constexpr auto DefaultGemmKLoop =
    ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop;
static constexpr auto GemmKLoopOverC =
    ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC;

static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK;
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
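The AddReluAdd element-wise operator is what gives this instance file its "bias + activation + add" behaviour. As a mental model only (an assumption here, not a quote of the functor in element_wise_operation_cpu.hpp), its per-element effect is roughly:

// Sketch of the assumed AddReluAdd semantics: add bias, apply ReLU, add the residual input.
float add_relu_add(float conv_out, float bias, float residual)
{
    float x = conv_out + bias;   // add bias
    x       = x > 0.f ? x : 0.f; // ReLU
    return x + residual;         // add second (residual) input
}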
// clang-format off
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, true , c_local_buf, bias_along_m>, \
    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
    DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, true , c_local_buf, bias_along_m>
// clang-format on
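For readability, a sketch of what the first operation of one macro entry looks like once expanded (the role comments are interpretations of the parameter names, not taken from the device operation's own documentation):

// sketch only: first element of the expansion of
// DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 256, 128, 64, 6, 16, false, false)
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
    float, float, float, float, float,            // data types (all float here)
    PT, PT, AddReluAdd,                           // a / b / c element-wise operations
    ConvFwdDefault, GemmKLoopOverC, LoopOver_MNK, // conv / gemm-k / loop-order specializations
    2,                                            // 2 (presumably the number of spatial dimensions)
    256, 128, 64,                                 // m / n / k per block
    6, 16,                                        // m / n per thread
    true, true,                                   // use local A / local B buffers
    false, false>                                 // c_local_buf, bias_along_m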
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,  256, 128,  64, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,  256, 128, 128, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,  128, 256, 128, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,  512, 240, 128, 4, 24, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,  512, 256, 128, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,  768, 320, 128, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,  896, 352, 128, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, false, false)
    >;
// clang-format on
// use this in single thread, when gemm_n is not a multiple of 8
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_local_c_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,  256, 128,  64, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,  256, 128, 128, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,  128, 256, 128, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,  512, 240, 128, 4, 24, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,  512, 256, 128, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,  768, 320, 128, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,  896, 352, 128, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)
    >;
// clang-format on
// use this in multi-thread environment (a local C buffer is needed to avoid
// cache-coherence traffic, although sometimes no local C is faster...)
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_mt_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,   24,  24, 256, 4, 24, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,   32,  24, 256, 4, 24, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,   40,  24, 256, 4, 24, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,   48,  24, 256, 4, 24, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,   48,  48, 256, 4, 24, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,   56,  24, 256, 4, 24, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,   72,  16, 128, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,   72,  16, 256, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,   72,  32, 128, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,   72,  32, 256, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,   96,  32, 128, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,   96,  64, 128, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,  120,  32, 128, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,  120,  64, 128, 6, 16, false, false),
    // DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,  256, 128, 128, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,  128, 256, 128, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,  512, 240, 128, 4, 24, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,  512, 256, 128, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,  768, 320, 128, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd,  896, 352, 128, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)
    >;
// clang-format on

void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk(
    std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
    ck::tensor_operation::device::add_device_operation_instances(
        instances, device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_instances{});
}

void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_local_c(
    std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
    ck::tensor_operation::device::add_device_operation_instances(
        instances,
        device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_local_c_instances{});
}

void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_mt(
    std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
    ck::tensor_operation::device::add_device_operation_instances(
        instances, device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_f32_mt_instances{});
}

} // namespace device_conv2d_fwd_bias_activation_add_avx2_instance
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
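A hypothetical caller-side sketch (not part of this commit) of how the three registration functions above might be selected, following the comments on the instance lists: the plain list for generic single-thread use, *_local_c when gemm_n is not a multiple of 8, and *_mt for multi-thread runs. The OpenMP thread-count check and the collect() helper are assumptions for illustration only; the includes at the top of this file are assumed to be in place.

#include <omp.h>
#include <vector>

using namespace ck::tensor_operation::cpu::device::device_conv2d_fwd_bias_activation_add_avx2_instance;

void collect(std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances,
             int gemm_n)
{
    if(omp_get_max_threads() > 1)
        add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_mt(instances);      // multi-thread list
    else if(gemm_n % 8 != 0)
        add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk_local_c(instances); // local C handles the N tail
    else
        add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxc_nhwk(instances);         // write C directly
}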
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_bias_activation_add_avx2_nhwc_kyxck8_nhwk_instance.cpp
...
@@ -40,69 +40,81 @@ static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
// clang-format off
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m) \
#define DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, c_local_buf, bias_along_m) \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , false, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
\
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf, bias_along_m>
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , true , c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1S1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, false, false, c_local_buf, bias_along_m>, \
DeviceConvNDFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K<float , float , float, float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, true , false, c_local_buf, bias_along_m>
// clang-format on

using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,  256, 128,  64, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,  256, 128, 128, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,  128, 256, 128, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,  512, 240, 128, 4, 24, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,  512, 256, 128, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,  768, 320, 128, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,  896, 352, 128, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, false, false)
    >;
// clang-format on
// use this in single thread, when gemm_n is not a multiple of 8
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_local_c_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,  256, 128,  64, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,  256, 128, 128, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,  128, 256, 128, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,  512, 240, 128, 4, 24, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,  512, 256, 128, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,  768, 320, 128, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,  896, 352, 128, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)
    >;
// clang-format on
// use this in multi-thread environment (a local C buffer is needed to avoid
// cache-coherence traffic, although sometimes no local C is faster...)
using device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_f32_mt_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,   24,  24, 256, 4, 24, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,   32,  24, 256, 4, 24, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,   40,  24, 256, 4, 24, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,   48,  24, 256, 4, 24, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,   48,  48, 256, 4, 24, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,   56,  24, 256, 4, 24, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,   72,  16, 128, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,   72,  16, 256, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,   72,  32, 128, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,   72,  32, 256, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,   96,  32, 128, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,   96,  64, 128, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,  120,  32, 128, 6, 16, false, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,  120,  64, 128, 6, 16, false, false),
    // DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,  256, 128, 128, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,  128, 256, 128, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,  512, 240, 128, 4, 24, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,  512, 256, 128, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,  768, 320, 128, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd,  896, 352, 128, 6, 16, true, false),
    DEVICE_CONV2D_FWD_BAA_AVX2_NHWC_KYXCK8_NHWK_F32(PT, PT, AddReluAdd, 1024, 416, 128, 6, 16, true, false)
    >;
// clang-format on
void add_device_conv2d_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk(
...