gaoqiong / composable_kernel

Commit 19c3c9a8, authored Dec 04, 2019 by Chao Liu

    updated bwd-data v1r1 and v2r1 to use gridwise gemm

Parent: 19a93dac

Showing 7 changed files with 245 additions and 240 deletions (+245 -240)
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp   +59 -52
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp   +49 -39
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp   +8 -8
composable_kernel/include/tensor_operation/gridwise_gemm.hpp   +32 -25
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp   +61 -59
driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp   +18 -14
driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp   +18 -43
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp
 #ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R1_NCHW_KCYX_NKHW_HPP
 #define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
...
...
@@ -17,11 +17,11 @@ template <index_t GridSize,
           typename OutGlobalDesc,
           typename ConvStrides,
           typename ConvDilations,
-          typename LeftPads,
-          typename RightPads,
-          index_t EPerBlock,
-          index_t BPerBlock,
-          index_t KPerBlock,
+          typename InLeftPads,
+          typename InRightPads,
+          index_t GemmMPerBlock,
+          index_t GemmNPerBlock,
+          index_t GemmKPerBlock,
           index_t GemmMPerThreadSubC,
           index_t GemmNPerThreadSubC,
           index_t GemmMLevel0Cluster,
...
...
@@ -31,13 +31,15 @@ template <index_t GridSize,
           index_t GemmKPerThreadLoop,
           index_t GemmThreadGemmDataPerReadM,
           index_t GemmThreadGemmDataPerReadN,
-          typename WeiBlockCopySubLengths_K_E,
-          typename WeiBlockCopyClusterLengths_K_E,
-          index_t WeiBlockCopyDataPerAccess_E,
-          typename OutBlockCopySubLengths_K_B,
-          typename OutBlockCopyClusterLengths_K_B,
-          index_t OutBlockCopyDataPerAccess_B,
-          index_t InThreadCopyDataPerAccess_B>
+          typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
+          typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
+          index_t GemmABlockCopySrcDataPerRead_GemmN,
+          index_t GemmABlockCopyDstDataPerWrite_GemmN,
+          typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
+          typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
+          index_t GemmBBlockCopySrcDataPerRead_GemmN,
+          index_t GemmBBlockCopyDstDataPerWrite_GemmN,
+          index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
 struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
 {
     __device__ void Run(Float* __restrict__ p_in_global,
...
...
@@ -49,8 +51,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};

-        constexpr auto True = integral_constant<bool, true>{};
-
         constexpr auto in_n_c_hi_wi_global_desc  = InGlobalDesc{};
         constexpr auto wei_k_c_y_x_global_desc   = WeiGlobalDesc{};
         constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
...
...
@@ -73,14 +73,13 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
         constexpr index_t ConvDilationH = ConvDilations{}[0];
         constexpr index_t ConvDilationW = ConvDilations{}[1];

         constexpr index_t E = C * Y * X;
         constexpr index_t B = N * Ho * Wo;

-        // sanity-check for vectorized memory load
-        static_assert((Wo == 1 || (ConvStrideW == 1 || InThreadCopyDataPerAccess_B == 1)) &&
-                          (X == 1 || ConvDilationW % InThreadCopyDataPerAccess_B == 0),
-                      "wrong! aligment requirement for vectorized global load of input tensor will "
-                      "be violated");
+        // TODO: this logic may not be correct for bwd-data
+        static_assert(
+            (Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) &&
+                (X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
+            "wrong! aligment requirement for vectorized global load of input tensor will "
+            "be violated");

         // output tensor
         constexpr auto out_n_k_howo_global_desc =
...
...
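The assertion above encodes the vectorization rule used throughout this commit: a vectorized access along W is legal only when the W dimension is trivial, the stride is 1, or the vector width is 1, and in addition the dilation must be a multiple of the vector width. A standalone sketch of the same check with made-up sizes (illustrative only, not values from this commit):

#include <cstdio>

int main()
{
    // Hypothetical sizes, chosen only to exercise the condition in the static_assert above.
    constexpr int Wo = 28, X = 3, ConvStrideW = 1, ConvDilationW = 2;
    constexpr int DataPerAccess = 2; // stands in for GemmCThreadCopyDstDataPerWrite_GemmN1

    constexpr bool ok = (Wo == 1 || (ConvStrideW == 1 || DataPerAccess == 1)) &&
                        (X == 1 || ConvDilationW % DataPerAccess == 0);

    std::printf("vectorized access allowed: %s\n", ok ? "yes" : "no"); // "yes" for these values
}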
@@ -99,8 +98,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
         // input tensor
         constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
             in_n_c_hi_wi_global_desc,
-            make_tuple(
-                PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
+            make_tuple(
+                PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
...
...
@@ -121,33 +121,40 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
         // GEMM: atomic add
         constexpr auto gridwise_gemm =
-            GridwiseGemmTransposedANormalBNormalC_v1r1<
-                GridSize, BlockSize, Float, AccFloat,
-                decltype(wei_k_e_global_desc), decltype(out_k_b_global_desc), decltype(in_e_b_global_desc),
-                InMemoryDataOperation::atomic_add,
-                EPerBlock, BPerBlock, KPerBlock,
-                GemmMPerThreadSubC, GemmNPerThreadSubC,
-                GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster,
-                GemmKPerThreadLoop, GemmThreadGemmDataPerReadM, GemmThreadGemmDataPerReadN,
-                WeiBlockCopySubLengths_K_E, WeiBlockCopyClusterLengths_K_E, WeiBlockCopyDataPerAccess_E,
-                OutBlockCopySubLengths_K_B, OutBlockCopyClusterLengths_K_B, OutBlockCopyDataPerAccess_B,
-                InThreadCopyDataPerAccess_B>{};
+            GridwiseGemmTransposedANormalBNormalC_v1<
+                GridSize, BlockSize, Float, AccFloat,
+                decltype(wei_k_e_global_desc), decltype(out_k_b_global_desc), decltype(in_e_b_global_desc),
+                InMemoryDataOperation::atomic_add,
+                GemmMPerBlock, GemmNPerBlock, GemmKPerBlock,
+                GemmMPerThreadSubC, GemmNPerThreadSubC,
+                GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster,
+                GemmKPerThreadLoop, GemmThreadGemmDataPerReadM, GemmThreadGemmDataPerReadN,
+                GemmABlockCopyThreadSliceLengths_GemmK_GemmM, GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
+                Sequence<0, 1>, 1,
+                GemmABlockCopySrcDataPerRead_GemmN, GemmABlockCopyDstDataPerWrite_GemmN,
+                GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
+                Sequence<0, 1>, 1,
+                GemmBBlockCopySrcDataPerRead_GemmN, GemmBBlockCopyDstDataPerWrite_GemmN,
+                3, GemmCThreadCopyDstDataPerWrite_GemmN1>{};

         gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
     }
...
...
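For orientation, the renaming in this file maps the old implicit-GEMM dimension names onto their GEMM roles: the former E (= C * Y * X) becomes GemmM, the former B (= N * Ho * Wo) becomes GemmN, and the reduction dimension K becomes GemmK, matching the wei_k_e / out_k_b / in_e_b descriptors fed to the transposed-A GEMM above. A small worked example with assumed tensor sizes (illustrative only, not from the commit):

#include <cstdio>

int main()
{
    // Assumed convolution sizes for illustration.
    constexpr int N = 128, C = 64, K = 256, Y = 3, X = 3, Ho = 28, Wo = 28;

    constexpr int GemmM = C * Y * X;   // formerly "E"
    constexpr int GemmN = N * Ho * Wo; // formerly "B"
    constexpr int GemmK = K;           // reduction dimension of the transposed-A GEMM

    std::printf("GemmM = %d, GemmN = %d, GemmK = %d\n", GemmM, GemmN, GemmK);
}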
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
...
...
@@ -34,14 +34,15 @@ template <index_t GridSize,
           index_t GemmKPerThreadLoop,
           index_t GemmThreadGemmDataPerReadM,
           index_t GemmThreadGemmDataPerReadN,
-          typename GemmABlockCopySubLengths,     // Gemm-K, Gemm-M
-          typename GemmABlockCopyClusterLengths, // Gemm-K, Gemm-M
-          index_t GemmABlockCopyDataPerAccess,   // Gemm-M
-          typename GemmBBlockCopySubLengths,     // Gemm-K, Gemm-N
-          typename GemmBBlockCopyClusterLengths, // Gemm-K, Gemm-N
-          index_t GemmBBlockCopyDataPerAccess,   // Gemm-N
-          index_t GemmCThreadCopyDataPerAccess>  // Gemm-N
+          typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
+          typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
+          index_t GemmABlockCopySrcDataPerRead_GemmM,
+          index_t GemmABlockCopyDstDataPerWrite_GemmM,
+          typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
+          typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
+          index_t GemmBBlockCopySrcDataPerRead_GemmN,
+          index_t GemmBBlockCopyDstDataPerWrite_GemmN,
+          index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
 struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
 {
     __device__ void Run(Float* __restrict__ p_in_global,
...
...
@@ -71,10 +72,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
         constexpr index_t ConvDilationW = ConvDilations{}[1];

-        // sanity-check for vectorized memory load
-        static_assert((Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDataPerAccess == 1)) &&
-                          (X == 1 || ConvDilationW % GemmCThreadCopyDataPerAccess == 0),
-                      "wrong! aligment requirement for vectorized global load of input tensor will "
-                      "be violated");
+        // TODO: this logic may not be correct for bwd-data
+        static_assert(
+            (Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) &&
+                (X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
+            "wrong! aligment requirement for vectorized global load of input tensor will "
+            "be violated");

         constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
         constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
...
...
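math::hcf above is the highest common factor, i.e. the greatest common divisor, of the convolution stride and dilation, which the v2r1 algorithm uses when re-tiling the filter taps. A trivial standalone stand-in with assumed values (std::gcd plays the role of math::hcf; none of this is code from the commit):

#include <cstdio>
#include <numeric>

int main()
{
    constexpr int ConvStrideW = 2, ConvDilationW = 4; // assumed values

    constexpr int hcf_stride_dilation_w = std::gcd(ConvStrideW, ConvDilationW);

    std::printf("hcf(stride, dilation) = %d\n", hcf_stride_dilation_w); // prints 2
}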
@@ -172,33 +175,40 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
         // GEMM
         constexpr auto gridwise_gemm =
-            GridwiseGemmTransposedANormalBNormalC_v1r1<
-                GridSize, BlockSize, Float, AccFloat,
-                decltype(wei_gemmk_gemmm_global_desc), decltype(out_gemmk_gemmn_global_desc),
-                decltype(in_gemmm_gemmn_global_desc),
-                InMemoryDataOperation::none,
-                GemmMPerBlock, GemmNPerBlock, GemmKPerBlock,
-                GemmMPerThreadSubC, GemmNPerThreadSubC,
-                GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster,
-                GemmKPerThreadLoop, GemmThreadGemmDataPerReadM, GemmThreadGemmDataPerReadN,
-                GemmABlockCopySubLengths, GemmABlockCopyClusterLengths, GemmABlockCopyDataPerAccess,
-                GemmBBlockCopySubLengths, GemmBBlockCopyClusterLengths, GemmBBlockCopyDataPerAccess,
-                GemmCThreadCopyDataPerAccess>{};
+            GridwiseGemmTransposedANormalBNormalC_v1<
+                GridSize, BlockSize, Float, AccFloat,
+                decltype(wei_gemmk_gemmm_global_desc), decltype(out_gemmk_gemmn_global_desc),
+                decltype(in_gemmm_gemmn_global_desc),
+                InMemoryDataOperation::none,
+                GemmMPerBlock, GemmNPerBlock, GemmKPerBlock,
+                GemmMPerThreadSubC, GemmNPerThreadSubC,
+                GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster,
+                GemmKPerThreadLoop, GemmThreadGemmDataPerReadM, GemmThreadGemmDataPerReadN,
+                GemmABlockCopyThreadSliceLengths_GemmK_GemmM, GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
+                Sequence<0, 1>, 1,
+                GemmABlockCopySrcDataPerRead_GemmM, GemmABlockCopyDstDataPerWrite_GemmM,
+                GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
+                Sequence<0, 1>, 1,
+                GemmBBlockCopySrcDataPerRead_GemmN, GemmBBlockCopyDstDataPerWrite_GemmN,
+                3, GemmCThreadCopyDstDataPerWrite_GemmN1>{};

         gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
     }
...
...
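Note the difference in the C-matrix store policy between the two kernels: v1r1 instantiates the gridwise GEMM with InMemoryDataOperation::atomic_add (partial results from different blocks accumulate into the input tensor), while v2r1 above uses InMemoryDataOperation::none (plain overwrite). A minimal host-side sketch of that distinction (illustrative only, not this library's implementation):

#include <cstdio>

enum class InMemoryDataOperation { none, atomic_add };

template <InMemoryDataOperation Op, typename T>
void store(T* dst, T value)
{
    if constexpr(Op == InMemoryDataOperation::none)
        *dst = value;  // overwrite, as in the v2r1 instantiation
    else
        *dst += value; // accumulate partial results, as in the v1r1 instantiation
                       // (done with atomic adds on the GPU)
}

int main()
{
    float c = 1.0f;
    store<InMemoryDataOperation::none>(&c, 2.0f);       // c == 2
    store<InMemoryDataOperation::atomic_add>(&c, 2.0f); // c == 4
    std::printf("c = %g\n", c);
}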
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp
...
...
@@ -18,10 +18,10 @@ template <index_t BlockSize,
           typename ThreadClusterArrangeOrder,
           typename SrcDimAccessOrder,
           typename DstDimAccessOrder,
-          index_t SrcVectorAccessDim,
-          index_t DstVectorAccessDim,
-          index_t SrcDataPerAccess,
-          index_t DstDataPerAccess,
+          index_t SrcVectorReadDim,
+          index_t DstVectorWriteDim,
+          index_t SrcDataPerRead,
+          index_t DstDataPerWrite,
           AddressSpace SrcAddressSpace          = AddressSpace::generic,
           AddressSpace ThreadBufferAddressSpace = AddressSpace::generic,
           AddressSpace DstAddressSpace          = AddressSpace::generic,
...
...
@@ -146,8 +146,8 @@ struct BlockwiseGenericTensorSliceCopy_v4
                                               ThreadBufferDesc,
                                               ThreadSliceLengths,
                                               SrcDimAccessOrder,
-                                              SrcVectorAccessDim,
-                                              SrcDataPerAccess,
+                                              SrcVectorReadDim,
+                                              SrcDataPerRead,
                                               1,
                                               SrcAddressSpace,
                                               ThreadBufferAddressSpace,
...
...
@@ -157,9 +157,9 @@ struct BlockwiseGenericTensorSliceCopy_v4
                                               BlockDstDesc,
                                               ThreadSliceLengths,
                                               DstDimAccessOrder,
-                                              DstVectorAccessDim,
+                                              DstVectorWriteDim,
                                               1,
-                                              DstDataPerAccess,
+                                              DstDataPerWrite,
                                               ThreadBufferAddressSpace,
                                               DstAddressSpace,
                                               DstInMemOp>;
...
...
composable_kernel/include/tensor_operation/gridwise_gemm.hpp
...
...
@@ -31,14 +31,21 @@ template <index_t GridSize,
           index_t KPerThreadLoop,
           index_t ThreadGemmDataPerReadM,
           index_t ThreadGemmDataPerReadN,
-          typename ABlockCopySubLengths_K_M,
-          typename ABlockCopyClusterLengths_K_M,
-          index_t ABlockCopyDataPerAccess_M,
-          typename BBlockCopySubLengths_K_N,
-          typename BBlockCopyClusterLengths_K_N,
-          index_t BBlockCopyDataPerAccess_N,
-          index_t CThreadCopyDataPerAccess_N>
-struct GridwiseGemmTransposedANormalBNormalC_v1r1
+          typename ABlockCopyThreadSliceLengths_K_M,
+          typename ABlockCopyThreadClusterLengths_K_M,
+          typename ABlockCopyThreadClusterArrangeOrder,
+          index_t ABlockCopySrcVectorReadDim,
+          index_t ABlockCopySrcDataPerRead,
+          index_t ABlockCopyDstDataPerWrite_M,
+          typename BBlockCopyThreadSliceLengths_K_N,
+          typename BBlockCopyThreadClusterLengths_K_N,
+          typename BBlockCopyThreadClusterArrangeOrder,
+          index_t BBlockCopySrcVectorReadDim,
+          index_t BBlockCopySrcDataPerRead,
+          index_t BBlockCopyDstDataPerWrite_N,
+          index_t CThreadCopyVectorReadWriteDim,
+          index_t CThreadCopyDstDataPerWrite>
+struct GridwiseGemmTransposedANormalBNormalC_v1
 {
     __device__ void Run(const Float* __restrict__ p_a_global,
                         const Float* __restrict__ p_b_global,
...
...
@@ -55,8 +62,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1
         constexpr auto N = b_k_n_global_desc.GetLengths()[1];

         // lds max alignment
-        constexpr index_t max_lds_align = math::lcm(ABlockCopyDataPerAccess_M,
-                                                    BBlockCopyDataPerAccess_N,
+        constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
+                                                    BBlockCopyDstDataPerWrite_N,
                                                     ThreadGemmDataPerReadM,
                                                     ThreadGemmDataPerReadN);
...
...
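The LDS alignment above is the least common multiple of the vector widths used to write LDS and to read it back for the thread GEMM, so every vectorized LDS access stays aligned. A standalone arithmetic check with assumed widths (std::lcm stands in for the kernel's variadic math::lcm; not code from this commit):

#include <cstdio>
#include <numeric>

int main()
{
    // Assumed vector widths in elements, not values from this commit.
    constexpr int ABlockCopyDstDataPerWrite_M = 4;
    constexpr int BBlockCopyDstDataPerWrite_N = 1;
    constexpr int ThreadGemmDataPerReadM = 4;
    constexpr int ThreadGemmDataPerReadN = 4;

    // std::lcm is binary, so fold it over the four widths.
    constexpr int max_lds_align =
        std::lcm(std::lcm(ABlockCopyDstDataPerWrite_M, BBlockCopyDstDataPerWrite_N),
                 std::lcm(ThreadGemmDataPerReadM, ThreadGemmDataPerReadN));

    std::printf("max_lds_align = %d elements\n", max_lds_align); // 4
}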
@@ -86,15 +93,15 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1
                                                decltype(a_k_m_global_desc),
                                                decltype(a_k_m_block_desc),
                                                decltype(a_k_m_block_desc.GetLengths()),
-                                               ABlockCopySubLengths_K_M,
-                                               ABlockCopyClusterLengths_K_M,
-                                               Sequence<0, 1>,
+                                               ABlockCopyThreadSliceLengths_K_M,
+                                               ABlockCopyThreadClusterLengths_K_M,
+                                               ABlockCopyThreadClusterArrangeOrder,
                                                Sequence<0, 1>,
                                                Sequence<0, 1>,
-                                               1,
+                                               ABlockCopySrcVectorReadDim,
                                                1,
-                                               ABlockCopyDataPerAccess_M,
-                                               ABlockCopyDataPerAccess_M,
+                                               ABlockCopySrcDataPerRead,
+                                               ABlockCopyDstDataPerWrite_M,
                                                AddressSpace::global,
                                                AddressSpace::vgpr,
                                                AddressSpace::lds,
...
...
@@ -112,15 +119,15 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1
                                                decltype(b_k_n_global_desc),
                                                decltype(b_k_n_block_desc),
                                                decltype(b_k_n_block_desc.GetLengths()),
-                                               BBlockCopySubLengths_K_N,
-                                               BBlockCopyClusterLengths_K_N,
-                                               Sequence<0, 1>,
+                                               BBlockCopyThreadSliceLengths_K_N,
+                                               BBlockCopyThreadClusterLengths_K_N,
+                                               BBlockCopyThreadClusterArrangeOrder,
                                                Sequence<0, 1>,
                                                Sequence<0, 1>,
-                                               1,
+                                               BBlockCopySrcVectorReadDim,
                                                1,
-                                               BBlockCopyDataPerAccess_N,
-                                               BBlockCopyDataPerAccess_N,
+                                               BBlockCopySrcDataPerRead,
+                                               BBlockCopyDstDataPerWrite_N,
                                                AddressSpace::global,
                                                AddressSpace::vgpr,
                                                AddressSpace::lds,
...
...
@@ -305,9 +312,9 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1
             decltype(c_m0_m1_n0_n1_global_desc),
             decltype(c_m0_m1_n0_n1_thread_desc.GetLengths()),
             Sequence<0, 1, 2, 3>,
-            3,
-            CThreadCopyDataPerAccess_N,
-            CThreadCopyDataPerAccess_N,
+            CThreadCopyVectorReadWriteDim,
+            1,
+            CThreadCopyDstDataPerWrite,
             AddressSpace::vgpr,
             AddressSpace::global,
             CGlobalMemoryDataOperation>(
...
...
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp
...
...
@@ -15,13 +15,15 @@ namespace ck {
 // The dimension access order should be the same on src and dst.
 // It is designed for cases, where one of src and dst is register, and
 // the other is device memory or LDS
+// Will do valid mapping check on src data: Read 0 if src data has a invalid mapping
+// Will do valid mapping check on dst data: No write if dst data has a invalid mapping
 template <typename SrcDesc,
           typename DstDesc,
           typename SliceLengths,
           typename DimAccessOrder,
-          index_t VectorAccessDim,
-          index_t SrcDataPerAccess,
-          index_t DstDataPerAccess,
+          index_t VectorReadWriteDim,
+          index_t SrcDataPerRead,
+          index_t DstDataPerWrite,
           AddressSpace SrcAddressSpace     = AddressSpace::generic,
           AddressSpace DstAddressSpace     = AddressSpace::generic,
           InMemoryDataOperation DstInMemOp = InMemoryDataOperation::none>
...
...
@@ -45,10 +47,10 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
         static_assert(is_valid_sequence_map<DimAccessOrder>{}, "wrong! map is not valid");

-        static_assert(SliceLengths{}[VectorAccessDim] %
-                              math::lcm(SrcDataPerAccess, DstDataPerAccess) ==
-                          0,
+        static_assert(SliceLengths{}[VectorReadWriteDim] %
+                              math::lcm(SrcDataPerRead, DstDataPerWrite) ==
+                          0,
                       "wrong! cannot evenly divide");

-        // TODO:: sanity-check if vectorized memory access is allowed on src and dst
+        // TODO:: sanity-check if vectorized memory read/write is allowed on src and dst
     }

     __device__ constexpr ThreadwiseGenericTensorSliceCopy_v4r2()
...
...
@@ -67,17 +69,15 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
         mDstSliceOrigin = dst_slice_origin;
     }

-    // Will do padding check on src data: Read 0 if src data is in padding area.
-    // Will do padding check on dst data: No write if dst data is in paddin area.
     template <typename SrcData, typename DstData>
     __device__ void Run(const SrcData* p_src, DstData* p_dst) const
     {
-        constexpr auto vector_access_dim = Number<VectorAccessDim>{};
+        constexpr auto vector_access_dim = Number<VectorReadWriteDim>{};

-        constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
-        constexpr auto dst_data_per_access = Number<DstDataPerAccess>{};
+        constexpr auto src_data_per_access = Number<SrcDataPerRead>{};
+        constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};

-        constexpr auto long_vector_size = Number<math::lcm(SrcDataPerAccess, DstDataPerAccess)>{};
+        constexpr auto long_vector_size = Number<math::lcm(SrcDataPerRead, DstDataPerWrite)>{};

         constexpr auto long_vector_access_lengths = SliceLengths::Modify(
             vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
...
...
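The long-vector length above is the least common multiple of the source read width and the destination write width, so each inner step can be serviced by whole vector loads and whole vector stores. A numeric sketch with assumed widths (illustrative only, not values from this commit):

#include <cstdio>
#include <numeric>

int main()
{
    // Assumed widths: e.g. 4-wide global reads and 2-wide LDS writes.
    constexpr int SrcDataPerRead  = 4;
    constexpr int DstDataPerWrite = 2;

    constexpr int long_vector_size = std::lcm(SrcDataPerRead, DstDataPerWrite); // 4

    // The slice length along VectorReadWriteDim must divide evenly
    // (this is the "cannot evenly divide" static_assert earlier in the file).
    constexpr int slice_length = 8;
    static_assert(slice_length % long_vector_size == 0, "cannot evenly divide");

    std::printf("per long vector: %d loads, %d stores\n",
                long_vector_size / SrcDataPerRead,   // 1
                long_vector_size / DstDataPerWrite); // 2
}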
@@ -109,13 +109,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                 const auto src_coord = mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id);

-                // Check src vector's padding situation, only check the first data in this src
+                // Check src data's valid mapping situation, only check the first data in this src
                 // vector. It's user's responsiblity to make sure all data in the src vector
-                // has the same padding situation
+                // has the valid/invalid mapping situation
                 if(src_coord.IsUpperIndexMappedToValidOffset())
                 {
                     move_data<SrcData,
-                              SrcDataPerAccess,
+                              SrcDataPerRead,
                               SrcAddressSpace,
                               AddressSpace::vgpr,
                               InMemoryDataOperation::none>(
...
@@ -141,13 +141,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
const
auto
dst_coord
=
mDstSliceOrigin
+
(
long_vector_data_begin_id
+
scalar_id
);
// Check dst
vector's padd
ing situation, only check the first data in this dst
// Check dst
data's valid mapp
ing situation, only check the first data in this dst
// vector. It's user's responsiblity to make sure all data in the dst vector
// has the
same padd
ing situation
// has the
valid/invalid mapp
ing situation
if
(
dst_coord
.
IsUpperIndexMappedToValidOffset
())
{
move_data
<
DstData
,
DstDataPer
Access
,
DstDataPer
Write
,
AddressSpace
::
vgpr
,
DstAddressSpace
,
DstInMemOp
>
(
...
...
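The pattern in the two hunks above (and repeated in the optimized Run_* variants below) is: compute the coordinate of the first element of each vector, and let its validity decide the whole vector's fate, reading zeros for an invalid source mapping and skipping the store for an invalid destination mapping. A simplified host-side sketch of that idea (not this library's API; plain bounds checks stand in for IsUpperIndexMappedToValidOffset):

#include <cstdio>
#include <vector>

static bool is_valid(int offset, int size) { return offset >= 0 && offset < size; }

int main()
{
    const std::vector<float> src{1.f, 2.f, 3.f, 4.f};
    std::vector<float> dst(6, -1.f);

    constexpr int vec = 2; // vector width; it divides both sizes, so checking the
                           // first element of each vector is enough (mirrors the
                           // "user's responsibility" comment above)
    for(int i = -vec; i <= 6; i += vec) // some offsets fall outside src and/or dst
    {
        float buf[vec] = {0.f, 0.f}; // "read 0" default for an invalid src mapping
        if(is_valid(i, static_cast<int>(src.size())))
            for(int j = 0; j < vec; ++j)
                buf[j] = src[i + j];

        if(is_valid(i, static_cast<int>(dst.size()))) // "no write" for an invalid dst mapping
            for(int j = 0; j < vec; ++j)
                dst[i + j] = buf[j];
    }

    for(float x : dst)
        std::printf("%g ", x); // 1 2 3 4 0 0
    std::printf("\n");
}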
@@ -165,20 +165,20 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
         return Sequence<(Mask ? Lengths : 1)...>{};
     }

-    // Will do padding check on src data: Read 0 if src data is in padding area.
-    // Will do padding check on dst data: No write if dst data is in paddin area.
+    // Will do valid mapping check on src data: Read 0 if src data has a invalid mapping
+    // Will do valid mapping check on dst data: No write if dst data has a invalid mapping
     // This version is optimized for address calculation of src tensor
     // TODO: this function is not compiled to expected ISA
     template <typename SrcData, typename DstData>
     __device__ void Run_optimized_src_address_calculation(const SrcData* p_src, DstData* p_dst) const
     {
-        constexpr auto vector_access_dim = Number<VectorAccessDim>{};
+        constexpr auto vector_access_dim = Number<VectorReadWriteDim>{};

-        constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
-        constexpr auto dst_data_per_access = Number<DstDataPerAccess>{};
+        constexpr auto src_data_per_access = Number<SrcDataPerRead>{};
+        constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};

-        constexpr auto long_vector_size = Number<math::lcm(SrcDataPerAccess, DstDataPerAccess)>{};
+        constexpr auto long_vector_size = Number<math::lcm(SrcDataPerRead, DstDataPerWrite)>{};

         constexpr auto long_vector_access_lengths = SliceLengths::Modify(
             vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
...
...
@@ -187,9 +187,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
         constexpr auto src_linear_dim_mask    = SrcDesc::GetLinearDimensionMask();
         constexpr auto src_nonlinear_dim_mask = SrcDesc::GetNonLinearDimensionMask();

-        static_assert(src_linear_dim_mask.At(VectorAccessDim) || long_vector_size == SrcDataPerAccess,
-                      "Warning! VectorAccessDim is not SrcDesc's linear dimension, performance "
+        static_assert(src_linear_dim_mask.At(VectorReadWriteDim) || long_vector_size == SrcDataPerRead,
+                      "Warning! VectorReadWriteDim is not SrcDesc's linear dimension, performance "
                       "would drop");

         // separate steps into linear and non-linear components, accoording to src tensor
...
...
@@ -230,13 +230,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                 p_src_long_vector[i] = 0;
             }

-            // Loop over VectorAccessDim, and load data from src to the
+            // Loop over VectorReadWriteDim, and load data from src to the
             // long-vector buffer.
-            // If VectorAccessDim is src's linear dimension, then src's
+            // If VectorReadWriteDim is src's linear dimension, then src's
             // offset-diff due to this looping is known at compile-time. If
-            // VectorAccessDim is src's nonlinear dimension, then src's
+            // VectorReadWriteDim is src's nonlinear dimension, then src's
             // offset-diff due to this looping is only known at run-time. For best
-            // performance, VectorAccessDim, should be src's linear dimension
+            // performance, VectorReadWriteDim, should be src's linear dimension
             for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i)
             {
                 auto scalar_id = make_zero_array<index_t, nDim>();
...
...
@@ -258,13 +258,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                         src_coord.GetOffset() - src_nonlinear_coord.GetOffset();
 #endif

-                    // Check src vector's padding situation, only check the first data in
-                    // this src vector. It's user's responsiblity to make sure all data in
-                    // the src vector has the same padding situation
+                    // Check src data's valid mapping situation, only check the first data in this
+                    // src
+                    // vector. It's user's responsiblity to make sure all data in the src vector
+                    // has the valid/invalid mapping situation
                     if(src_coord.IsUpperIndexMappedToValidOffset())
                     {
                         move_data<SrcData,
-                                  SrcDataPerAccess,
+                                  SrcDataPerRead,
                                   SrcAddressSpace,
                                   AddressSpace::vgpr,
                                   InMemoryDataOperation::none>(
                             p_src,
...
...
@@ -296,13 +297,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                     const auto dst_coord = mDstSliceOrigin + (nonlinear_dim_data_steps +
                                                               linear_dim_data_steps + scalar_id);

-                    // Check dst vector's padding situation, only check the first data in
-                    // this dst vector. It's user's responsiblity to make sure all data in
-                    // the dst vector has the same padding situation
+                    // Check dst data's valid mapping situation, only check the first data in this
+                    // dst
+                    // vector. It's user's responsiblity to make sure all data in the dst vector
+                    // has the valid/invalid mapping situation
                     if(dst_coord.IsUpperIndexMappedToValidOffset())
                     {
                         move_data<DstData,
-                                  DstDataPerAccess,
+                                  DstDataPerWrite,
                                   AddressSpace::vgpr,
                                   DstAddressSpace,
                                   DstInMemOp>(
...
@@ -313,20 +315,18 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
         });
     }

-    // Will do padding check on src data: Read 0 if src data is in padding area.
-    // Will do padding check on dst data: No write if dst data is in paddin area.
     // This version is optimized for address calculation of dst tensor
     // TODO: this function is not compiled to expected ISA
     template <typename SrcData, typename DstData>
     __device__ void Run_optimized_dst_address_calculation(const SrcData* p_src, DstData* p_dst) const
     {
-        constexpr auto vector_access_dim = Number<VectorAccessDim>{};
+        constexpr auto vector_access_dim = Number<VectorReadWriteDim>{};

-        constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
-        constexpr auto dst_data_per_access = Number<DstDataPerAccess>{};
+        constexpr auto src_data_per_access = Number<SrcDataPerRead>{};
+        constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};

-        constexpr auto long_vector_size = Number<math::lcm(SrcDataPerAccess, DstDataPerAccess)>{};
+        constexpr auto long_vector_size = Number<math::lcm(SrcDataPerRead, DstDataPerWrite)>{};

         constexpr auto long_vector_access_lengths = SliceLengths::Modify(
             vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
...
...
@@ -335,9 +335,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
         constexpr auto dst_linear_dim_mask    = DstDesc::GetLinearDimensionMask();
         constexpr auto dst_nonlinear_dim_mask = DstDesc::GetNonLinearDimensionMask();

-        static_assert(dst_linear_dim_mask.At(VectorAccessDim) || long_vector_size == DstDataPerAccess,
-                      "Warning! VectorAccessDim is not DstDesc's linear dimension, performance "
+        static_assert(dst_linear_dim_mask.At(VectorReadWriteDim) || long_vector_size == DstDataPerWrite,
+                      "Warning! VectorReadWriteDim is not DstDesc's linear dimension, performance "
                       "would drop");

         // separate steps into linear and non-linear components, accoording to dst tensor
...
...
@@ -378,13 +378,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                 p_src_long_vector[i] = 0;
             }

-            // Loop over VectorAccessDim, and load data from src to the
+            // Loop over VectorReadWriteDim, and load data from src to the
             // long-vector buffer.
-            // If VectorAccessDim is dst's linear dimension, then dst's
+            // If VectorReadWriteDim is dst's linear dimension, then dst's
             // offset-diff due to this looping is known at compile-time. If
-            // VectorAccessDim is dst's nonlinear dimension, then dst's
+            // VectorReadWriteDim is dst's nonlinear dimension, then dst's
             // offset-diff due to this looping is only known at run-time. For best
-            // performance, VectorAccessDim, should be dst's linear dimension
+            // performance, VectorReadWriteDim, should be dst's linear dimension
             for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i)
             {
                 auto scalar_id = make_zero_array<index_t, nDim>();
...
...
@@ -397,13 +397,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                     const auto src_coord = mSrcSliceOrigin + (nonlinear_dim_data_steps +
                                                               linear_dim_data_steps + scalar_id);

-                    // Check src vector's padding situation, only check the first data in
-                    // this src vector. It's user's responsiblity to make sure all data in
-                    // the src vector has the same padding situation
+                    // Check src data's valid mapping situation, only check the first data in this
+                    // src
+                    // vector. It's user's responsiblity to make sure all data in the src vector
+                    // has the valid/invalid mapping situation
                     if(src_coord.IsUpperIndexMappedToValidOffset())
                     {
                         move_data<SrcData,
-                                  SrcDataPerAccess,
+                                  SrcDataPerRead,
                                   SrcAddressSpace,
                                   AddressSpace::vgpr,
                                   InMemoryDataOperation::none>(
...
...
@@ -441,13 +442,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                         dst_coord.GetOffset() - dst_nonlinear_coord.GetOffset();
 #endif

-                    // Check dst vector's padding situation, only check the first data in
-                    // this dst vector. It's user's responsiblity to make sure all data in
-                    // the dst vector has the same padding situation
+                    // Check dst data's valid mapping situation, only check the first data in this
+                    // dst
+                    // vector. It's user's responsiblity to make sure all data in the dst vector
+                    // has the valid/invalid mapping situation
                     if(dst_coord.IsUpperIndexMappedToValidOffset())
                     {
                         move_data<DstData,
-                                  DstDataPerAccess,
+                                  DstDataPerWrite,
                                   AddressSpace::vgpr,
                                   DstAddressSpace,
                                   DstInMemOp>(
                             p_dst_long_vector,
...
...
driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp
...
...
@@ -62,17 +62,19 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
     constexpr index_t GemmThreadGemmDataPerReadM = 4;
     constexpr index_t GemmThreadGemmDataPerReadN = 4;

-    using GemmABlockCopySubLengths     = Sequence<1, 4>;  // Gemm-K, Gemm-M
-    using GemmABlockCopyClusterLengths = Sequence<8, 32>; // Gemm-K, Gemm-M
+    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM   = Sequence<1, 4>;  // Gemm-K, Gemm-M
+    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>; // Gemm-K, Gemm-M

-    constexpr index_t GemmABlockCopyDataPerAccess = 4; // Gemm-M
+    constexpr index_t GemmABlockCopySrcDataPerRead_GemmM  = 4; // Gemm-M
+    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4; // Gemm-M

-    using GemmBBlockCopySubLengths     = Sequence<4, 1>;   // Gemm-K, Gemm-N
-    using GemmBBlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-N
+    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN   = Sequence<4, 1>;   // Gemm-K, Gemm-N
+    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; // Gemm-K, Gemm-N

-    constexpr index_t GemmBBlockCopyDataPerAccess = 1; // Gemm-N
+    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN  = 1;
+    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;

-    constexpr index_t GemmCThreadCopyDataPerAccess = 1; // Gemm-N
+    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
 #endif

     constexpr index_t GemmM = C * Y * X;
...
...
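As a sanity check on these copy parameters: the per-thread slice lengths multiplied element-wise by the thread-cluster lengths give the block tile each copy covers per iteration, and the cluster size gives the number of participating threads. A worked example with the A-copy values above (the tile sizes it should match are the GemmKPerBlock/GemmMPerBlock/BlockSize chosen elsewhere in this driver, which are not shown in this hunk):

#include <cstdio>

int main()
{
    // A-copy configuration from above: ThreadSliceLengths = Sequence<1, 4>,
    // ThreadClusterLengths = Sequence<8, 32> (order: Gemm-K, Gemm-M).
    constexpr int SliceK = 1, SliceM = 4;
    constexpr int ClusterK = 8, ClusterM = 32;

    constexpr int TileK   = SliceK * ClusterK;   // 8   -> expected to equal GemmKPerBlock
    constexpr int TileM   = SliceM * ClusterM;   // 128 -> expected to equal GemmMPerBlock
    constexpr int threads = ClusterK * ClusterM; // 256 -> expected to equal BlockSize

    std::printf("A block copy: %d x %d tile, %d threads\n", TileK, TileM, threads);
}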
@@ -107,13 +109,15 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
         GemmKPerThreadLoop,
         GemmThreadGemmDataPerReadM,
         GemmThreadGemmDataPerReadN,
-        GemmABlockCopySubLengths,
-        GemmABlockCopyClusterLengths,
-        GemmABlockCopyDataPerAccess,
-        GemmBBlockCopySubLengths,
-        GemmBBlockCopyClusterLengths,
-        GemmBBlockCopyDataPerAccess,
-        GemmCThreadCopyDataPerAccess>{};
+        GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
+        GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
+        GemmABlockCopySrcDataPerRead_GemmM,
+        GemmABlockCopyDstDataPerWrite_GemmM,
+        GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
+        GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
+        GemmBBlockCopySrcDataPerRead_GemmN,
+        GemmBBlockCopyDstDataPerWrite_GemmN,
+        GemmCThreadCopyDstDataPerWrite_GemmN1>{};

     for(index_t i = 0; i < nrepeat; ++i)
     {
...
...
driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
...
...
@@ -68,45 +68,19 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
     constexpr index_t GemmThreadGemmDataPerReadM = 4;
     constexpr index_t GemmThreadGemmDataPerReadN = 4;

-    using GemmABlockCopySubLengths     = Sequence<4, 1>;   // Gemm-K, Gemm-M
-    using GemmABlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-M
+    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM   = Sequence<4, 1>;
+    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;

-    constexpr index_t GemmABlockCopyDataPerAccess = 1; // Gemm-M
+    constexpr index_t GemmABlockCopySrcDataPerRead_GemmM  = 1;
+    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;

-    using GemmBBlockCopySubLengths     = Sequence<4, 1>;   // Gemm-K, Gemm-N
-    using GemmBBlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-N
+    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN   = Sequence<4, 1>;
+    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;

-    constexpr index_t GemmBBlockCopyDataPerAccess = 1; // Gemm-N
+    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN  = 1;
+    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;

-    constexpr index_t GemmCThreadCopyDataPerAccess = 1; // Gemm-N
-#elif 0
-    // BlockSize = 256, each thread hold 64 data
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 8;
-
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
-    constexpr index_t GemmMLevel0Cluster = 4;
-    constexpr index_t GemmNLevel0Cluster = 4;
-    constexpr index_t GemmMLevel1Cluster = 4;
-    constexpr index_t GemmNLevel1Cluster = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
-
-    constexpr index_t GemmThreadGemmDataPerReadM = 4;
-    constexpr index_t GemmThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopySubLengths     = Sequence<1, 4>;  // Gemm-K, Gemm-M
-    using GemmABlockCopyClusterLengths = Sequence<8, 32>; // Gemm-K, Gemm-M
-
-    constexpr index_t GemmABlockCopyDataPerAccess = 4; // Gemm-M
-
-    using GemmBBlockCopySubLengths     = Sequence<4, 1>;   // Gemm-K, Gemm-N
-    using GemmBBlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-N
-
-    constexpr index_t GemmBBlockCopyDataPerAccess = 1; // Gemm-N
-
-    constexpr index_t GemmCThreadCopyDataPerAccess = 1; // Gemm-N
+    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
 #endif

     // TODO: this algo support any stride and dilation. But for now, let's fix them to be 1 for
...
...
@@ -126,7 +100,6 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
     constexpr index_t Htilda = Ho + right_pad_ho;
     constexpr index_t Wtilda = Wo + right_pad_wo;

-    constexpr index_t GemmK = K * Ydot * Xdot;
     constexpr index_t GemmM = C * Ytilda * Xtilda;
     constexpr index_t GemmN = N * Htilda * Wtilda;
...
...
@@ -159,13 +132,15 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
         GemmKPerThreadLoop,
         GemmThreadGemmDataPerReadM,
         GemmThreadGemmDataPerReadN,
-        GemmABlockCopySubLengths,
-        GemmABlockCopyClusterLengths,
-        GemmABlockCopyDataPerAccess,
-        GemmBBlockCopySubLengths,
-        GemmBBlockCopyClusterLengths,
-        GemmBBlockCopyDataPerAccess,
-        GemmCThreadCopyDataPerAccess>{};
+        GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
+        GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
+        GemmABlockCopySrcDataPerRead_GemmM,
+        GemmABlockCopyDstDataPerWrite_GemmM,
+        GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
+        GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
+        GemmBBlockCopySrcDataPerRead_GemmN,
+        GemmBBlockCopyDstDataPerWrite_GemmN,
+        GemmCThreadCopyDstDataPerWrite_GemmN1>{};

     for(index_t i = 0; i < nrepeat; ++i)
     {
...
...