yangql / composable_kernel-1

Commit d2436a58, authored Apr 27, 2019 by Chao Liu

    v1r3 nchw*cyxk=nkhw lds double buffer

Parent: 63cdc6d2
Showing 7 changed files with 745 additions and 174 deletions (+745 −174)
driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp                                 +3   −3
driver/device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp                                 +148 −37
driver/driver.hip.cpp                                                                         +6   −6
src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp                    +15  −17
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_khwn.hip.hpp  +14  −16
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw.hip.hpp  +475 −0
src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hip.hpp                    +84  −95
driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp

@@ -95,7 +95,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
     constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load input for NCHW
     constexpr index_t InBlockReorderDataPerWrite_N = 2;

-    using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used
+    using WeiBlockCopyClusterLengths = void;

     constexpr index_t WeiBlockCopyDataPerRead_K = 4;
     constexpr index_t OutThreadCopyDataPerWrite_N = 2;
@@ -130,7 +130,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
     constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load input for NCHW
     constexpr index_t InBlockReorderDataPerWrite_N = 2;

-    using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used
+    using WeiBlockCopyClusterLengths = void;

     constexpr index_t WeiBlockCopyDataPerRead_K = 4;
     constexpr index_t OutThreadCopyDataPerWrite_N = 2;
@@ -200,7 +200,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
     constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load input for NCHW
     constexpr index_t InBlockReorderDataPerWrite_N = 1;

-    using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used
+    using WeiBlockCopyClusterLengths = void;

     constexpr index_t WeiBlockCopyDataPerRead_K = 4;
     constexpr index_t OutThreadCopyDataPerWrite_N = 2;
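The three hunks above all make the same change: the unused WeiBlockCopyClusterLengths alias is switched from a dummy Sequence<0, 0> to void. A minimal standalone sketch of that idiom follows; the names here are hypothetical, not from the repository. Passing void into a template slot that a given code path never uses makes "not used" explicit, and any accidental instantiation fails at compile time instead of silently running with dummy values.

    #include <cstdint>

    using index_t = std::int32_t;

    template <index_t... Is>
    struct Sequence
    {
    };

    // a config slot that some instantiations simply do not use
    template <class ClusterLengths>
    struct BlockCopyConfig
    {
    };

    using Before = BlockCopyConfig<Sequence<0, 0>>; // dummy value, looks meaningful
    using After  = BlockCopyConfig<void>;           // explicitly "not used"

    int main()
    {
        After a;
        (void)a;
    }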
driver/device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp
@@ -3,6 +3,7 @@
 #include "device.hpp"
 #include "gridwise_convolution_wrapper.hip.hpp"
 #include "gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hip.hpp"
+#include "gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw.hip.hpp"

 template <class T, class InDesc, class WeiDesc, class OutDesc>
 void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
@@ -92,7 +93,42 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
     constexpr index_t OutThreadCopyDataPerWrite_W = 2;
 #elif 0
-    // for 3x3, 34x34, v1r3, Vega 20
+    // for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 32
+    constexpr index_t BlockSize = 256;
+
+    constexpr index_t NPerBlock  = 1;
+    constexpr index_t KPerBlock  = 128;
+    constexpr index_t CPerBlock  = 8;
+    constexpr index_t HoPerBlock = 4;
+    constexpr index_t WoPerBlock = 32;
+
+    constexpr index_t NPerThread  = 1;
+    constexpr index_t KPerThread  = 8;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 8;
+
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 2;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmDataPerReadA   = 4;
+    constexpr index_t GemmDataPerReadB   = 4;
+
+    using InBlockReorderSrcSubLengths_NCHW     = Sequence<1, 2, 2, 1>;
+    using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 4, 2, 32>;
+
+    using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
+
+    constexpr index_t InBlockReorderDataPerRead_W  = 1; // v1r3 cannot do vector load NCHW
+    constexpr index_t InBlockReorderDataPerWrite_N = 1;
+
+    using WeiBlockCopyClusterLengths = void;
+    constexpr index_t WeiBlockCopyDataPerRead_K = 4;
+
+    constexpr index_t OutThreadCopyDataPerWrite_W = 4;
+#elif 0
+    // for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 16
     constexpr index_t BlockSize = 256;
     constexpr index_t NPerBlock = 2;
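A back-of-the-envelope consistency check for the WoPerBlock = 32 tuning block above. The repeat factors and the thread-count formula are my assumptions about the usual level0/level1 cluster decomposition; the diff does not state them explicitly, but the numbers do hang together under that reading.

    // GEMM M is the K (output channel) dimension; GEMM N is Wo * N.
    constexpr int KPerBlock = 128, KPerThread = 8;
    constexpr int NPerBlock = 1, WoPerBlock = 32, NPerThread = 1, WoPerThread = 8;
    constexpr int HoPerBlock = 4, HoPerThread = 1;
    constexpr int GemmMPerThreadSubC = 4, GemmNPerThreadSubC = 4;
    constexpr int GemmMLevel0Cluster = 4, GemmNLevel0Cluster = 2;
    constexpr int GemmMLevel1Cluster = 4, GemmNLevel1Cluster = 2;

    constexpr int GemmMRepeat = KPerThread / GemmMPerThreadSubC;                 // 2
    constexpr int GemmNRepeat = (WoPerThread * NPerThread) / GemmNPerThreadSubC; // 2

    static_assert(GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster * GemmMRepeat ==
                      KPerBlock,
                  "M tiling covers KPerBlock");
    static_assert(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster * GemmNRepeat ==
                      WoPerBlock * NPerBlock,
                  "N tiling covers Wo*N per block");

    // 16 x 4 = 64 threads per GEMM, times HoPerBlock/HoPerThread = 4 batches = 256 = BlockSize
    constexpr int ThreadsPerGemm = (GemmMLevel0Cluster * GemmMLevel1Cluster) *
                                   (GemmNLevel0Cluster * GemmNLevel1Cluster);
    static_assert(ThreadsPerGemm * (HoPerBlock / HoPerThread) == 256, "matches BlockSize");

    int main() {}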
@@ -125,9 +161,9 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
     using WeiBlockCopyClusterLengths = void;
     constexpr index_t WeiBlockCopyDataPerRead_K = 4;

-    constexpr index_t OutThreadCopyDataPerWrite_W = 4;
+    constexpr index_t OutThreadCopyDataPerWrite_W = 2;
 #elif 1
-    // for 3x3, 34x34, v1r3, Vega 20, try
+    // for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 8
     constexpr index_t BlockSize = 256;
     constexpr index_t NPerBlock = 4;
@@ -160,7 +196,77 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
     using WeiBlockCopyClusterLengths = void;
     constexpr index_t WeiBlockCopyDataPerRead_K = 4;

-    constexpr index_t OutThreadCopyDataPerWrite_W = 2;
+    constexpr index_t OutThreadCopyDataPerWrite_W = 1;
+#elif 0
+    // for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 4
+    constexpr index_t BlockSize = 256;
+
+    constexpr index_t NPerBlock  = 8;
+    constexpr index_t KPerBlock  = 128;
+    constexpr index_t CPerBlock  = 8;
+    constexpr index_t HoPerBlock = 4;
+    constexpr index_t WoPerBlock = 4;
+
+    constexpr index_t NPerThread  = 4;
+    constexpr index_t KPerThread  = 8;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 2;
+
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 2;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmDataPerReadA   = 4;
+    constexpr index_t GemmDataPerReadB   = 4;
+
+    using InBlockReorderSrcSubLengths_NCHW     = Sequence<4, 1, 1, 1>;
+    using InBlockReorderSrcClusterLengths_NCHW = Sequence<2, 8, 4, 4>;
+
+    using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
+
+    constexpr index_t InBlockReorderDataPerRead_W  = 1; // v1r3 cannot do vector load NCHW
+    constexpr index_t InBlockReorderDataPerWrite_N = 4;
+
+    using WeiBlockCopyClusterLengths = void;
+    constexpr index_t WeiBlockCopyDataPerRead_K = 4;
+
+    constexpr index_t OutThreadCopyDataPerWrite_W = 1;
+#elif 0
+    // for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 2
+    constexpr index_t BlockSize = 256;
+
+    constexpr index_t NPerBlock  = 32;
+    constexpr index_t KPerBlock  = 128;
+    constexpr index_t CPerBlock  = 8;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 2;
+
+    constexpr index_t NPerThread  = 4;
+    constexpr index_t KPerThread  = 8;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 2;
+
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 4;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 2;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmDataPerReadA   = 4;
+    constexpr index_t GemmDataPerReadB   = 4;
+
+    using InBlockReorderSrcSubLengths_NCHW     = Sequence<4, 1, 1, 1>;
+    using InBlockReorderSrcClusterLengths_NCHW = Sequence<8, 8, 2, 2>;
+
+    using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
+
+    constexpr index_t InBlockReorderDataPerRead_W  = 1; // v1r3 cannot do vector load NCHW
+    constexpr index_t InBlockReorderDataPerWrite_N = 4;
+
+    using WeiBlockCopyClusterLengths = void;
+    constexpr index_t WeiBlockCopyDataPerRead_K = 4;
+
+    constexpr index_t OutThreadCopyDataPerWrite_W = 1;
 #elif 1
     // for 3x3, 28x28, v1r3, Pascal
     constexpr index_t BlockSize = 128;
@@ -206,39 +312,44 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
     for(index_t i = 0; i < nrepeat; ++i)
     {
-        constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
-            <GridSize,
+        constexpr auto gridwise_conv =
+#if 0
+            GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
+#else
+            GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
+#endif
+            <GridSize,
              BlockSize,
              T,
              decltype(in_nchw_desc),
              decltype(wei_cyxk_desc),
              decltype(out_nkhw_desc),
              NPerBlock,
              KPerBlock,
              CPerBlock,
              HoPerBlock,
              WoPerBlock,
              NPerThread,
              KPerThread,
              HoPerThread,
              WoPerThread,
              GemmMPerThreadSubC,
              GemmNPerThreadSubC,
              GemmMLevel0Cluster,
              GemmNLevel0Cluster,
              GemmMLevel1Cluster,
              GemmNLevel1Cluster,
              GemmKPerThreadLoop,
              GemmDataPerReadA,
              GemmDataPerReadB,
              InBlockReorderSrcSubLengths_NCHW,
              InBlockReorderSrcClusterLengths_NCHW,
              InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
              InBlockReorderDataPerRead_W,
              InBlockReorderDataPerWrite_N,
              WeiBlockCopyClusterLengths,
              WeiBlockCopyDataPerRead_K,
              OutThreadCopyDataPerWrite_W>{};

         float time = launch_kernel(run_gridwise_convolution<decltype(gridwise_conv), T>,
                                    dim3(GridSize),
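The hunk above switches the driver to the new LDS double-buffer gridwise kernel by placing a preprocessor guard between the two class names; both templates share the same parameter list, so the argument list below the guard is untouched. A minimal sketch of that selection pattern with hypothetical names — the preprocessor runs before parsing, so splicing the name in front of the shared template-argument list is legal:

    #include <cstdio>

    template <int BlockSize>
    struct KernelBaseline
    {
        void Run() const { std::printf("baseline, BlockSize=%d\n", BlockSize); }
    };

    template <int BlockSize>
    struct KernelLdsDoubleBuffer
    {
        void Run() const { std::printf("lds double buffer, BlockSize=%d\n", BlockSize); }
    };

    int main()
    {
        constexpr auto kernel =
    #if 0
            KernelBaseline
    #else
            KernelLdsDoubleBuffer
    #endif
            <256>{};

        kernel.Run(); // prints the double-buffer variant
    }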
driver/driver.hip.cpp
@@ -371,7 +371,7 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
(the removed and added lines below differ only in whitespace)

                 std::size_t ho = HoPerTile * htile + j;

                 for(int i = 0; i < WoPerTile; ++i)
                 {
-                    std::size_t wo = WoPerTile * wtile + i;
+                    std::size_t wo = WoPerTile * wtile + i;

                     out_nkhw(n, k, ho, wo) = out_hold(n, k, htile, wtile, j, i);
                 }
             }
@@ -413,13 +413,13 @@ int main(int argc, char* argv[])
(the removed and added lines below differ only in whitespace)

 {
 #if 1
     // 3x3, 34x34
-    constexpr index_t N = 64;
-    constexpr index_t C = 256;
+    constexpr index_t N  = 64;
+    constexpr index_t C  = 256;
     constexpr index_t HI = 34;
     constexpr index_t WI = 34;
-    constexpr index_t K = 128;
-    constexpr index_t Y = 3;
-    constexpr index_t X = 3;
+    constexpr index_t K  = 128;
+    constexpr index_t Y  = 3;
+    constexpr index_t X  = 3;

     constexpr index_t HPad = 0;
     constexpr index_t WPad = 0;
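A quick sanity check on the "3x3, 34x34" problem above: with no padding and unit stride (my assumption; HPad = WPad = 0 is shown, the stride is not), a 3x3 filter over a 34x34 input yields the 32x32 output plane that the tuning blocks tile with HoPerBlock/WoPerBlock.

    constexpr int conv_out_size(int in, int filter, int pad = 0, int stride = 1)
    {
        return (in + 2 * pad - filter) / stride + 1;
    }

    static_assert(conv_out_size(34, 3) == 32, "Ho");
    static_assert(conv_out_size(34, 3) == 32, "Wo");

    int main() {}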
src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp
@@ -74,22 +74,20 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
                           Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
                       "wrong! cannot evenly divide work for workgroup ");

-        constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
-        constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
-        constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
-        constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
-
-        const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
-        index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
-        const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
-        itmp -= h_block_work_id * (WBlockWork * NBlockWork);
-        const index_t w_block_work_id = itmp / NBlockWork;
-        const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
-
-        const index_t k_block_data_begin  = k_block_work_id * KPerBlock;
-        const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
-        const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
-        const index_t n_block_data_begin  = n_block_work_id * NPerBlock;
+        constexpr index_t NBlockWork = mod_conv::integer_divide_ceil(N, NPerBlock);
+        constexpr index_t KBlockWork = mod_conv::integer_divide_ceil(K, KPerBlock);
+        constexpr index_t HBlockWork = mod_conv::integer_divide_ceil(Ho, HoPerBlock);
+        constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock);
+
+        constexpr auto block_work_desc = make_ConstantTensorDescriptor(
+            Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});
+
+        const auto block_work_multi_id = block_work_desc.GetMultiIndex(get_block_1d_id());
+
+        const index_t n_block_data_begin  = block_work_multi_id[0] * NPerBlock;
+        const index_t k_block_data_begin  = block_work_multi_id[1] * KPerBlock;
+        const index_t ho_block_data_begin = block_work_multi_id[2] * HoPerBlock;
+        const index_t wo_block_data_begin = block_work_multi_id[3] * WoPerBlock;

         const index_t hi_block_data_begin = ho_block_data_begin;
         const index_t wi_block_data_begin = wo_block_data_begin;
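The rewrite above replaces a hand-rolled div/mod chain with a row-major multi-index decomposition on a tensor descriptor. A standalone sketch of the equivalence; decompose() is a hypothetical stand-in for block_work_desc.GetMultiIndex(), assuming the descriptor is dense row-major with the last dimension varying fastest:

    #include <array>
    #include <cassert>

    std::array<int, 4> decompose(int id, const std::array<int, 4>& lengths)
    {
        std::array<int, 4> idx{};
        for(int d = 3; d >= 0; --d) // last dimension varies fastest (row-major)
        {
            idx[d] = id % lengths[d];
            id /= lengths[d];
        }
        return idx;
    }

    int main()
    {
        // NBlockWork, KBlockWork, HBlockWork, WBlockWork
        const std::array<int, 4> lengths{2, 3, 4, 5};

        for(int id = 0; id < 2 * 3 * 4 * 5; ++id)
        {
            const auto m = decompose(id, lengths);
            // reconstruct the 1d id to confirm the round trip
            const int r = ((m[0] * lengths[1] + m[1]) * lengths[2] + m[2]) * lengths[3] + m[3];
            assert(r == id);
        }
    }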
@@ -185,7 +183,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
         // choose GEMM implementation here
         const auto run_blockwise_batch_gemm = [&](auto... Xs) {
-#if 1
+#if 0
             return blockwise_batch_gemm.Run(Xs...);
 #elif 0
             return blockwise_batch_gemm.Run_asm(Xs...);
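The dispatch lambda above forwards its arguments to whichever GEMM implementation is compiled in; flipping `#if 1` to `#if 0` changes the default from Run to the `#else` branch without touching any call site. A standalone sketch of the idiom with hypothetical names — a variadic generic lambda keeps the call sites identical while the preprocessor switches the body:

    #include <cstdio>

    struct Gemm
    {
        void Run(int a, int b) const { std::printf("Run(%d, %d)\n", a, b); }
        void Run_asm_v2(int a, int b) const { std::printf("Run_asm_v2(%d, %d)\n", a, b); }
    };

    int main()
    {
        const Gemm gemm{};

        const auto run = [&](auto... xs) {
    #if 0
            return gemm.Run(xs...);
    #else
            return gemm.Run_asm_v2(xs...);
    #endif
        };

        run(1, 2); // calls Run_asm_v2 without the caller knowing which
    }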
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_khwn.hip.hpp

@@ -81,22 +81,20 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
                           Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
                       "wrong! cannot evenly divide work for workgroup ");

-        constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
-        constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
-        constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
-        constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
-
-        const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
-        index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
-        const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
-        itmp -= h_block_work_id * (WBlockWork * NBlockWork);
-        const index_t w_block_work_id = itmp / NBlockWork;
-        const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
-
-        const index_t k_block_data_begin  = k_block_work_id * KPerBlock;
-        const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
-        const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
-        const index_t n_block_data_begin  = n_block_work_id * NPerBlock;
+        constexpr index_t NBlockWork = mod_conv::integer_divide_ceil(N, NPerBlock);
+        constexpr index_t KBlockWork = mod_conv::integer_divide_ceil(K, KPerBlock);
+        constexpr index_t HBlockWork = mod_conv::integer_divide_ceil(Ho, HoPerBlock);
+        constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock);
+
+        constexpr auto block_work_desc = make_ConstantTensorDescriptor(
+            Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});
+
+        const auto block_work_multi_id = block_work_desc.GetMultiIndex(get_block_1d_id());
+
+        const index_t n_block_data_begin  = block_work_multi_id[0] * NPerBlock;
+        const index_t k_block_data_begin  = block_work_multi_id[1] * KPerBlock;
+        const index_t ho_block_data_begin = block_work_multi_id[2] * HoPerBlock;
+        const index_t wo_block_data_begin = block_work_multi_id[3] * WoPerBlock;

         const index_t hi_block_data_begin = ho_block_data_begin;
         const index_t wi_block_data_begin = wo_block_data_begin;
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw.hip.hpp (new file, 0 → 100644)
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "blockwise_nd_tensor_op.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_batched_gemm.hip.hpp"

template <index_t GridSize,
          index_t BlockSize,
          class Float,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t CPerBlock,
          index_t HoPerBlock,
          index_t WoPerBlock,
          index_t NPerThread,
          index_t KPerThread,
          index_t HoPerThread,
          index_t WoPerThread,
          index_t GemmMPerThreadSubC,
          index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB,
          class InBlockReorderSrcSubLengths_NCHW,
          class InBlockReorderSrcClusterLengths_NCHW,
          class InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
          index_t InBlockReorderDataPerRead_W,
          index_t InBlockReorderDataPerWrite_N,
          class WeiBlockCopyClusterLengths_CK, // not used
          index_t WeiBlockCopyDataPerRead_K,
          index_t OutThreadCopyDataPerWrite_W>
struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        // be careful of this assertion
        static_assert(
            NPerBlock % NPerThread == 0 &&
                ((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
                 (GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
                  GemmNPerThreadSubC % NPerThread == 0)),
            "wrong!");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto in_n_c_h_w_global_desc  = InGlobalDesc{};
        constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
        constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};

        constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);

        constexpr index_t N  = out_n_k_h_w_global_desc.GetLength(I0);
        constexpr index_t K  = out_n_k_h_w_global_desc.GetLength(I1);
        constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
        constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);

        constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
        constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);

        // assert for LDS double buffer
        static_assert(C % (2 * CPerBlock) == 0, "C cannot be evenly divided");

        // divide block work: [K, Ho, Wo, N]
        static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
                          Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
                      "wrong! cannot evenly divide work for workgroup ");

        constexpr index_t NBlockWork = mod_conv::integer_divide_ceil(N, NPerBlock);
        constexpr index_t KBlockWork = mod_conv::integer_divide_ceil(K, KPerBlock);
        constexpr index_t HBlockWork = mod_conv::integer_divide_ceil(Ho, HoPerBlock);
        constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock);

        constexpr auto block_work_desc = make_ConstantTensorDescriptor(
            Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});

        const auto block_work_multi_id = block_work_desc.GetMultiIndex(get_block_1d_id());

        const index_t n_block_data_begin  = block_work_multi_id[0] * NPerBlock;
        const index_t k_block_data_begin  = block_work_multi_id[1] * KPerBlock;
        const index_t ho_block_data_begin = block_work_multi_id[2] * HoPerBlock;
        const index_t wo_block_data_begin = block_work_multi_id[3] * WoPerBlock;

        const index_t hi_block_data_begin = ho_block_data_begin;
        const index_t wi_block_data_begin = wo_block_data_begin;
        // global tensor view
        constexpr auto wei_c_k_global_desc =
            make_ConstantTensorDescriptor(Sequence<C, K>{}, Sequence<Y * X * K, 1>{});

        // LDS tensor view: be careful of alignment
        constexpr index_t max_align = mod_conv::max(InBlockReorderDataPerWrite_N,
                                                    WeiBlockCopyDataPerRead_K,
                                                    GemmDataPerReadA,
                                                    GemmDataPerReadB);

        constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{},
            Number<InBlockReorderDataPerWrite_N>{});

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with alignment
        static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
                      "GemmDataPerReadB alignment requirement is not meet");

        constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, KPerBlock>{},
            Number<mod_conv::max(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});

        // tensor view of threadwise output in register
        constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(
            Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});

        // blockwise copy
        // input: reorder from [N, C, Hi, Wi] to [C, Hi, Wi, N]
        constexpr auto map_chwn2nchw = Sequence<1, 2, 3, 0>{};

        const auto blockwise_in_copy_reorder =
            BlockwiseNdTensorCopyReorder_v3<BlockSize,
                                            Float,
                                            decltype(in_n_c_h_w_global_desc),
                                            decltype(in_c_h_w_n_block_desc),
                                            Sequence<NPerBlock, CPerBlock, HoPerBlock, WoPerBlock>,
                                            InBlockReorderSrcSubLengths_NCHW,
                                            InBlockReorderSrcClusterLengths_NCHW,
                                            decltype(map_chwn2nchw),
                                            InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
                                            InBlockReorderDataPerRead_W,
                                            InBlockReorderDataPerWrite_N>{};

        // blockwise wei copy: format is [CPerBlock, KPerBlock]
        const auto blockwise_wei_copy =
            Blockwise2dTensorCopy3<BlockSize,
                                   Float,
                                   decltype(wei_c_k_global_desc),
                                   decltype(wei_c_k_block_desc),
                                   decltype(wei_c_k_block_desc.GetLengths()),
                                   WeiBlockCopyDataPerRead_K>{};

        // a series of blockwise batched GEMM
        //   C_matrix += transpose(A_matrix) * B_matrix
        //   A_matrix and B_matrix saved in LDS, C_matrix saved in register
        //   A_matrix[C,K] is a sub-matrix of wei_block[C,K]
        //   B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
        //   C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
        constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
            Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});

        constexpr auto b_c_wn_block_mtx_desc =
            make_ConstantMatrixDescriptor(Number<CPerBlock>{},
                                          Number<WoPerBlock * NPerBlock>{},
                                          Number<in_c_h_w_n_block_desc.GetStride(I0)>{});

        constexpr auto c_k_wn_thread_mtx_desc =
            make_ConstantMatrixDescriptor(Number<KPerThread>{},
                                          Number<WoPerThread * NPerThread>{},
                                          Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});

        const auto blockwise_batch_gemm =
            BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
                BlockSize,
                decltype(a_c_k_block_mtx_desc),
                decltype(b_c_wn_block_mtx_desc),
                decltype(c_k_wn_thread_mtx_desc),
                0,
                in_c_h_w_n_block_desc.GetStride(I1),
                out_k_h_w_n_thread_desc.GetStride(I1),
                HoPerBlock,
                GemmMPerThreadSubC,
                GemmNPerThreadSubC,
                GemmMLevel0Cluster,
                GemmNLevel0Cluster,
                GemmMLevel1Cluster,
                GemmNLevel1Cluster,
                GemmKPerThreadLoop,
                HoPerThread,
                GemmDataPerReadA,
                GemmDataPerReadB>{};

        // choose GEMM implementation here
        const auto run_blockwise_batch_gemm = [&](auto... Xs) {
#if 0
            return blockwise_batch_gemm.Run(Xs...);
#elif 0
            return blockwise_batch_gemm.Run_asm(Xs...);
#else
            return blockwise_batch_gemm.Run_asm_v2(Xs...);
#endif
        };
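The comment block above describes one (y, x) slice of the implicit GEMM: for each filter tap, the convolution reduces to a matrix product with A = wei[:, y, x, :] of shape C x K and B a shifted input slice of shape C x (Wo*N). A naive host reference of that decomposition, written for clarity rather than speed — unit stride, no padding, CHWN input layout matching the kernel's LDS view, and hypothetical names throughout:

    #include <vector>

    void conv_as_yx_gemms(int N, int C, int K, int Ho, int Wo, int Y, int X,
                          const std::vector<float>& in,  // [C][Hi][Wi][N], Hi = Ho + Y - 1
                          const std::vector<float>& wei, // [C][Y][X][K]
                          std::vector<float>& out)       // [K][Ho][Wo][N], assumed zeroed
    {
        const int Hi = Ho + Y - 1, Wi = Wo + X - 1;

        for(int y = 0; y < Y; ++y)
            for(int x = 0; x < X; ++x)
                // one "batched GEMM": C_out[k, (ho,wo,n)] += sum_c A[c,k] * B[c, (ho,wo,n)]
                for(int k = 0; k < K; ++k)
                    for(int ho = 0; ho < Ho; ++ho)
                        for(int wo = 0; wo < Wo; ++wo)
                            for(int n = 0; n < N; ++n)
                                for(int c = 0; c < C; ++c)
                                    out[((k * Ho + ho) * Wo + wo) * N + n] +=
                                        wei[((c * Y + y) * X + x) * K + k] *
                                        in[((c * Hi + (ho + y)) * Wi + (wo + x)) * N + n];
    }

This is also why the kernel below loops over y and x and offsets the input pointer by hi_block_data_begin + y and wi_block_data_begin + x: each (y, x) tap is its own accumulated GEMM.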
        // LDS: be careful of alignment
        constexpr index_t in_block_space =
            in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
        constexpr index_t wei_block_space =
            wei_c_k_block_desc.GetElementSpace(Number<max_align>{});

        // LDS double buffer
        __shared__ Float p_in_block_double[2 * in_block_space];
        __shared__ Float p_wei_block_double[2 * wei_block_space];

        // register
        //   C++ lambda doesn't capture array, use pointer instead
        Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
        Float* const p_out_thread = p_out_thread_data;

#if 0
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
        {
            print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc");
            print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");
            print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
            print_ConstantTensorDescriptor(wei_c_k_block_desc, "wei_c_k_block_desc");
            printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
        }
#endif
        // set threadwise output tensor to 0
        threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);

        for(index_t y = 0; y < Y; ++y)
        {
            for(index_t x = 0; x < X; ++x)
            {
                const Float* p_in_global_block_offset =
                    p_in_global + in_n_c_h_w_global_desc.Get1dIndex(n_block_data_begin,
                                                                    0,
                                                                    hi_block_data_begin + y,
                                                                    wi_block_data_begin + x);

                const Float* p_wei_global_block_offset =
                    p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, k_block_data_begin);

                // LDS double buffer: preload data into LDS
                {
                    Float p_in_register_clipboard[blockwise_in_copy_reorder
                                                      .GetRegisterClipboardSize()];
                    Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

                    blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset,
                                                                       p_in_register_clipboard);
                    blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
                                                                p_wei_register_clipboard);

                    blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_register_clipboard,
                                                                        p_in_block_double);
                    blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
                                                                 p_wei_block_double);
                }

                // LDS double buffer: main body
                for(index_t c_block_data_begin = 0; c_block_data_begin + 2 * CPerBlock < C;
                    c_block_data_begin += 2 * CPerBlock)
                {
#pragma unroll
                    for(index_t iloop = 0; iloop < 2; ++iloop)
                    {
                        const bool even_loop = (iloop % 2 == 0);

                        Float* p_in_block_now =
                            even_loop ? p_in_block_double : p_in_block_double + in_block_space;
                        Float* p_wei_block_now =
                            even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;

                        Float* p_in_block_next =
                            even_loop ? p_in_block_double + in_block_space : p_in_block_double;
                        Float* p_wei_block_next =
                            even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;

                        Float p_in_register_clipboard[blockwise_in_copy_reorder
                                                          .GetRegisterClipboardSize()];
                        Float p_wei_register_clipboard[blockwise_wei_copy
                                                           .GetRegisterClipboardSize()];

                        p_in_global_block_offset +=
                            CPerBlock * in_n_c_h_w_global_desc.GetStride(I1);
                        p_wei_global_block_offset +=
                            CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);

                        __syncthreads();

                        // LDS double buffer: load next data from device mem
                        blockwise_in_copy_reorder.RunLoadRegisterClipboard(
                            p_in_global_block_offset, p_in_register_clipboard);
                        blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
                                                                    p_wei_register_clipboard);

                        // LDS double buffer: GEMM on current data
                        run_blockwise_batch_gemm(p_wei_block_now, p_in_block_now, p_out_thread);

                        // LDS double buffer: store next data to LDS
                        blockwise_in_copy_reorder.RunStoreRegisterClipboard(
                            p_in_register_clipboard, p_in_block_next);
                        blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
                                                                     p_wei_block_next);
                    }
                }

                // LDS double buffer: tail
                {
                    Float p_in_register_clipboard[blockwise_in_copy_reorder
                                                      .GetRegisterClipboardSize()];
                    Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

                    // even iteration
                    p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1);
                    p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);

                    __syncthreads();

                    // LDS double buffer: load next data from device mem
                    blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset,
                                                                       p_in_register_clipboard);
                    blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
                                                                p_wei_register_clipboard);

                    // LDS double buffer: GEMM on current data
                    run_blockwise_batch_gemm(p_wei_block_double, p_in_block_double, p_out_thread);

                    // LDS double buffer: store next data to LDS
                    blockwise_in_copy_reorder.RunStoreRegisterClipboard(
                        p_in_register_clipboard, p_in_block_double + in_block_space);
                    blockwise_wei_copy.RunStoreRegisterClipboard(
                        p_wei_register_clipboard, p_wei_block_double + wei_block_space);

                    // odd iteration
                    __syncthreads();

                    // LDS double buffer: GEMM on current data
                    run_blockwise_batch_gemm(p_wei_block_double + wei_block_space,
                                             p_in_block_double + in_block_space,
                                             p_out_thread);
                }
            }
        }
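The body above overlaps global loads with compute by ping-ponging between the two halves of each LDS buffer: while the GEMM consumes the resident chunk, the next chunk is staged through a register clipboard into the other half. A stripped-down sketch of that schedule, run on the host for clarity; load() and compute() are hypothetical stand-ins for the register-clipboard copies and the blockwise batched GEMM, barrier() for __syncthreads(), and nchunk (C / CPerBlock) is assumed even, matching the C % (2 * CPerBlock) == 0 assertion:

    #include <cstdio>

    static void barrier() { /* __syncthreads() in the real kernel */ }
    static void load(int chunk, int buf) { std::printf("load chunk %d -> buf %d\n", chunk, buf); }
    static void compute(int buf) { std::printf("compute on buf %d\n", buf); }

    int main()
    {
        const int nchunk = 6; // C / CPerBlock, assumed even

        load(0, 0); // preload first chunk into buffer 0

        // main body: alternate buffers, stopping two chunks early so the
        // tail can drain both halves (the real loop does two chunks per trip)
        int c = 0;
        for(; c + 2 < nchunk; ++c)
        {
            const int now = c % 2, next = 1 - now;
            barrier();         // the "now" buffer is fully written
            load(c + 1, next); // prefetch the next chunk
            compute(now);      // GEMM on the resident chunk
        }

        // tail: one more prefetch, then drain both buffers
        barrier();
        load(c + 1, 1);
        compute(0);
        barrier();
        compute(1);
    }

Every chunk is loaded exactly once and computed exactly once, and a barrier always separates the writers of a buffer from its readers; that is the invariant the even/odd pointer juggling in the kernel maintains.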
        // output: register to global mem
        const auto c_thread_mtx_begin =
            blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        const index_t k_thread_data_begin  = c_thread_mtx_begin.row;
        const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
        const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
        const index_t n_thread_data_begin  = c_thread_mtx_begin.col % NPerBlock;

        static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) {
            // f_dummy does nothing but perfect forwarding. Using this trick to
            // make this lambda a generic lambda, so it won't be compiled until
            // instantiated
            static_assert(
                (f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
                "wrong!");

            // output is a 10d tensor
            constexpr index_t N2 = GemmNPerThreadSubC;
            constexpr index_t N1 = NPerBlock / N2;

            constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) /
                                   f_dummy(NPerBlock / GemmNPerThreadSubC);
            constexpr index_t W1 = WoPerBlock / W2;

            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = KPerBlock / KPerThread;

            constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
                Sequence<N / f_dummy(N1 * N2), N1, N2, K / (K1 * K2), K1, K2, Ho,
                         Wo / (W1 * W2), W1, W2>{});

            constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
                Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});

#if 0
            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
            {
                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
                                               "out_k_h_w_n_thread_desc");
                print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
                print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
                                               "out_k_h_w_n_global_desc");
                print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
            }
#endif

            constexpr auto map_out_global2thread = Sequence<7, 8, 9, 0, 1, 2, 3, 4, 5, 6>{};

            threadwise_nd_tensor_copy_reorder_given_dst2src_v2(
                out_10d_thread_desc,
                p_out_thread,
                out_10d_global_desc,
                p_out_global +
                    out_n_k_h_w_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
                                                       k_block_data_begin + k_thread_data_begin,
                                                       ho_block_data_begin + ho_thread_data_begin,
                                                       wo_block_data_begin + wo_thread_data_begin),
                out_10d_thread_desc.GetLengths(),
                map_out_global2thread);
            // Number<OutThreadCopyDataPerWrite_W>{});
        }).else_([&](auto f_dummy) {
            static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                              GemmNPerThreadSubC % NPerThread == 0,
                          "wrong!");

            // output is a 10d tensor
            constexpr index_t N1 = NPerBlock;

            constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
            constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
            constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);

            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = KPerBlock / KPerThread;

            constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
                Sequence<N / N1, N1, K / (K1 * K2), K1, K2, Ho,
                         Wo / (W1 * W2 * W3), W1, W2, W3>{});

            constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
                Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});

#if 0
            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
            {
                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
                                               "out_k_h_w_n_thread_desc");
                print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
                print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
                                               "out_k_h_w_n_global_desc");
                print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
            }
#endif

            constexpr auto map_out_global2thread = Sequence<8, 9, 0, 1, 2, 3, 4, 5, 6, 7>{};

            threadwise_nd_tensor_copy_reorder_given_dst2src_v2(
                out_10d_thread_desc,
                p_out_thread,
                out_10d_global_desc,
                p_out_global +
                    out_n_k_h_w_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
                                                       k_block_data_begin + k_thread_data_begin,
                                                       ho_block_data_begin + ho_thread_data_begin,
                                                       wo_block_data_begin + wo_thread_data_begin),
                out_10d_thread_desc.GetLengths(),
                map_out_global2thread);
            // Number<OutThreadCopyDataPerWrite_W>{});
        });
    }
};
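A worked example of the 10d split in the else_ branch above, plugged with the WoPerBlock = 32 tuning block and the 3x3, 34x34 problem from the driver (N = 64, K = 128, Ho = Wo = 32). The arithmetic is mine, not printed anywhere in the diff: that configuration has GemmNPerThreadSubC(4) >= NPerBlock(1), so it takes the else_ branch.

    constexpr int NPerBlock = 1, WoPerBlock = 32, KPerBlock = 128, KPerThread = 8;
    constexpr int GemmMPerThreadSubC = 4, GemmNPerThreadSubC = 4;
    constexpr int GemmNLevel0Cluster = 2, GemmNLevel1Cluster = 2;
    constexpr int N = 64, K = 128, Ho = 32, Wo = 32;

    constexpr int N1 = NPerBlock;                               // 1
    constexpr int W3 = GemmNPerThreadSubC / NPerBlock;          // 4
    constexpr int W2 = GemmNLevel0Cluster * GemmNLevel1Cluster; // 4
    constexpr int W1 = WoPerBlock / (W2 * W3);                  // 2
    constexpr int K2 = GemmMPerThreadSubC;                      // 4
    constexpr int K1 = KPerBlock / KPerThread;                  // 16

    // out_10d_global_desc lengths:
    //   <N/N1, N1, K/(K1*K2), K1, K2, Ho, Wo/(W1*W2*W3), W1, W2, W3>
    //   = <64, 1, 2, 16, 4, 32, 1, 2, 4, 4>
    static_assert(N / N1 == 64 && K / (K1 * K2) == 2 && Ho == 32 &&
                      Wo / (W1 * W2 * W3) == 1,
                  "global split");

    // out_10d_thread_desc lengths: <KPerThread/K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>
    static_assert(KPerThread / K2 == 2, "thread split");

    int main() {}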
src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hip.hpp
@@ -193,7 +193,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
         // choose GEMM implementation here
         const auto run_blockwise_batch_gemm = [&](auto... Xs) {
-#if 1
+#if 0
             return blockwise_batch_gemm.Run(Xs...);
 #elif 0
             return blockwise_batch_gemm.Run_asm(Xs...);
@@ -340,39 +340,40 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
(formatting-only reflow of the static_if lambda; the removed and added lines are semantically identical — shown once in the new form)

         const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
         const index_t n_thread_data_begin  = c_thread_mtx_begin.col % NPerBlock;

         static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) {
             // f_dummy does nothing but perfect forwarding. Using this trick to
             // make this lambda a generic lambda, so it won't be compiled until
             // instantiated
             static_assert(
                 (f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
                 "wrong!");

             // output is a 10d tensor
             constexpr index_t N2 = GemmNPerThreadSubC;
             constexpr index_t N1 = NPerBlock / N2;

             constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) /
                                    f_dummy(NPerBlock / GemmNPerThreadSubC);
             constexpr index_t W1 = WoPerBlock / W2;

             constexpr index_t K2 = GemmMPerThreadSubC;
             constexpr index_t K1 = KPerBlock / KPerThread;

             constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
                 Sequence<N / f_dummy(N1 * N2), N1, N2, K / (K1 * K2), K1, K2, Ho,
                          Wo / (W1 * W2), W1, W2>{});

             constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
                 Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});

 #if 0
             if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
@@ -387,51 +388,40 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
(formatting-only reflow of the first-branch tail and the else_ head; shown once in the new form)

             }
 #endif

             constexpr auto map_out_global2thread = Sequence<7, 8, 9, 0, 1, 2, 3, 4, 5, 6>{};

             threadwise_nd_tensor_copy_reorder_given_dst2src_v2(
                 out_10d_thread_desc,
                 p_out_thread,
                 out_10d_global_desc,
                 p_out_global +
                     out_n_k_h_w_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
                                                        k_block_data_begin + k_thread_data_begin,
                                                        ho_block_data_begin + ho_thread_data_begin,
                                                        wo_block_data_begin + wo_thread_data_begin),
                 out_10d_thread_desc.GetLengths(),
                 map_out_global2thread);
             // Number<OutThreadCopyDataPerWrite_W>{});
         }).else_([&](auto f_dummy) {
             static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                               GemmNPerThreadSubC % NPerThread == 0,
                           "wrong!");

             // output is a 10d tensor
             constexpr index_t N1 = NPerBlock;

             constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
             constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
             constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);

             constexpr index_t K2 = GemmMPerThreadSubC;
             constexpr index_t K1 = KPerBlock / KPerThread;

             constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
                 Sequence<N / N1, N1, K / (K1 * K2), K1, K2, Ho,
                          Wo / (W1 * W2 * W3), W1, W2, W3>{});

             constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
                 Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});

 #if 0
             if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
@@ -447,21 +437,20 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
(formatting-only reflow of the else_-branch tail; shown once in the new form)

             }
 #endif

             constexpr auto map_out_global2thread = Sequence<8, 9, 0, 1, 2, 3, 4, 5, 6, 7>{};

             threadwise_nd_tensor_copy_reorder_given_dst2src_v2(
                 out_10d_thread_desc,
                 p_out_thread,
                 out_10d_global_desc,
                 p_out_global +
                     out_n_k_h_w_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
                                                        k_block_data_begin + k_thread_data_begin,
                                                        ho_block_data_begin + ho_thread_data_begin,
                                                        wo_block_data_begin + wo_thread_data_begin),
                 out_10d_thread_desc.GetLengths(),
                 map_out_global2thread);
             // Number<OutThreadCopyDataPerWrite_W>{});
         });
     }
 };