yangql / composable_kernel-1 / Commits

Commit 569ad66e, authored Apr 23, 2019 by Chao Liu

added implicit gemm v1r3 lds_double_buffer NCHW * CYXK = KNHW, reworked static functionals

parent 87d8740b
Showing 20 changed files with 2080 additions and 857 deletions (+2080 / -857)
driver/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp                                 +77   -54
driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp                                 +82   -6
driver/driver.hip.cpp                                                                         +15   -31
src/include/Array.hip.hpp                                                                     +15   -1
src/include/ConstantTensorDescriptor.hip.hpp                                                  +23   -65
src/include/Sequence.hip.hpp                                                                  +91   -4
src/include/blockwise_4d_tensor_op.hip.hpp                                                    +13   -6
src/include/blockwise_batched_gemm.hip.hpp                                                    +17   -3
src/include/common.hip.hpp                                                                    +1    -0
src/include/functional.hip.hpp                                                                +64   -2
src/include/functional2.hip.hpp                                                               +117  -0
src/include/gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hip.hpp                    +193  -131
src/include/gridwise_convolution_implicit_gemm_v1r1_lds_double_buffer_chwn_cyxk_khwn.hip.hpp  +0    -407
src/include/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp                    +144  -51
src/include/gridwise_convolution_implicit_gemm_v1r2_nchw_cyxk_khwn.hip.hpp                    +40   -14
src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp                    +132  -43
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp  +130  -39
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_khwn.hip.hpp  +472  -0
src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_khwn.hip.hpp                    +452  -0
src/include/threadwise_2d_tensor_op.hip.hpp                                                   +2    -0
driver/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp

@@ -3,7 +3,6 @@
#include "device.hpp"
#include "gridwise_convolution_wrapper.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r1_lds_double_buffer_chwn_cyxk_khwn.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp"

@@ -81,6 +80,8 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
#if 0
    // for 3x3, 34x34, v1r1, Pascal
    constexpr index_t BlockSize = 128;
    constexpr index_t NPerBlock = 16;
    constexpr index_t KPerBlock = 64;
    constexpr index_t CPerBlock = 4;

@@ -92,14 +93,6 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 2;
    constexpr index_t InBlockCopy_ThreadPerDimC = 4;
    constexpr index_t InBlockCopy_ThreadPerDimH = 4;
    constexpr index_t InBlockCopy_ThreadPerDimW = 2;
    constexpr index_t InBlockCopy_ThreadPerDimN = 4;
    constexpr index_t InBlockCopyDataPerRead_N = 4;
    constexpr index_t WeiBlockCopyDataPerRead_K = 4;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;

@@ -110,11 +103,16 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;
    constexpr index_t OutThreadCopyDataPerWrite_N = 2;
    using InBlockCopyClusterLengths_CHWN = Sequence<4, 4, 2, 4>;
    constexpr index_t InBlockCopyDataPerRead_N = 4;
    constexpr index_t BlockSize = 128;
    constexpr index_t WeiBlockCopyDataPerRead_K = 4;
    constexpr index_t OutThreadCopyDataPerWrite_N = 2;
#elif 0
    // for 3x3, 34x34, v1r2, Pascal, in-block-copy1
    constexpr index_t BlockSize = 128;
    constexpr index_t NPerBlock = 4;
    constexpr index_t KPerBlock = 64;
    constexpr index_t CPerBlock = 8;

@@ -126,27 +124,86 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 2;
    constexpr index_t InBlockCopy_ThreadPerDimC = 4;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t InBlockCopy_ThreadPerDimH = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t InBlockCopy_ThreadPerDimW = 2;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t InBlockCopy_ThreadPerDimN = 1;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t InBlockCopyDataPerRead_N = 4;
    constexpr index_t GemmMLevel1Cluster = 2;
    constexpr index_t GemmNLevel1Cluster = 2;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;
    using InBlockCopyClusterLengths_CHWN = Sequence<0, 0, 0, 0>; // not used
    constexpr index_t InBlockCopyDataPerRead_N = 4;
    constexpr index_t WeiBlockCopyDataPerRead_K = 4;
    constexpr index_t OutThreadCopyDataPerWrite_N = 2;
#elif 1
    // for 3x3, 34x34, v1r3, Pascal
    // for 3x3, 28x28, v1r3, Pascal
    // for 3x3, 14x14, v1r3, Pascal
    constexpr index_t BlockSize = 128;
    constexpr index_t NPerBlock = 16;
    constexpr index_t KPerBlock = 128;
    constexpr index_t CPerBlock = 8;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 2;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 2;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 2;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;
    constexpr index_t OutThreadCopyDataPerWrite_N = 2;
    using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 2, 4>;
    constexpr index_t InBlockCopyDataPerRead_N = 4;
    constexpr index_t WeiBlockCopyDataPerRead_K = 4;
    constexpr index_t OutThreadCopyDataPerWrite_N = 2;
#elif 0
    // for 3x3, 34x34, v1r3, Pascal, bad
    constexpr index_t BlockSize = 128;
    constexpr index_t NPerBlock = 1;
    constexpr index_t KPerBlock = 128;
    constexpr index_t CPerBlock = 8;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 32;
    constexpr index_t NPerThread = 1;
    constexpr index_t KPerThread = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 8;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 2;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;
    using InBlockCopyClusterLengths_CHWN = Sequence<2, 2, 32, 1>;
    constexpr index_t InBlockCopyDataPerRead_N = 1;
    constexpr index_t WeiBlockCopyDataPerRead_K = 2;
    constexpr index_t OutThreadCopyDataPerWrite_N = 1;
#elif 0
    // for 3x3, 34x34, v1r1, Vega 20
    constexpr index_t NPerBlock = 16;

@@ -309,38 +366,6 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
    constexpr index_t WeiBlockCopyDataPerRead_K = 4;
    constexpr index_t OutThreadCopyDataPerWrite_N = 2;
#elif 1
    // for 3x3, 28x28, v1r3, Pascal
    // for 3x3, 14x14, v1r3, Pascal
    constexpr index_t BlockSize = 128;
    constexpr index_t NPerBlock = 16;
    constexpr index_t KPerBlock = 128;
    constexpr index_t CPerBlock = 8;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 2;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 2;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 2;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;
    using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 2, 4>;
    constexpr index_t InBlockCopyDataPerRead_N = 4;
    constexpr index_t WeiBlockCopyDataPerRead_K = 4;
    constexpr index_t OutThreadCopyDataPerWrite_N = 2;
#elif 0
    // for 1x1, 28x28, v1r1, Pascal

@@ -419,13 +444,11 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
        constexpr auto gridwise_conv =
#if 0
            GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
#elif 0
            GridwiseConvolutionImplicitGemm_v1r1_lds_double_buffer_chwn_cyxk_khwn
#elif 0
            GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
#elif 0
            GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
#elif 1
            GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
#elif 0
            GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
#endif
            <GridSize,
driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp

@@ -3,6 +3,8 @@
#include "device.hpp"
#include "gridwise_convolution_wrapper.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r2_nchw_cyxk_khwn.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_khwn.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_khwn.hip.hpp"

template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,

@@ -62,7 +64,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
    wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
    out_khwn_device_buf.ToDevice(out_khwn.mData.data());

#if 1
#if 0
    // for 3x3, 28x28, v1r2, Pascal
    constexpr index_t BlockSize = 128;

@@ -93,8 +95,78 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
    constexpr index_t InBlockReorderDataPerRead_W = 2;
    constexpr index_t InBlockReorderDataPerWrite_N = 4;
    using WeiBlockCopyClusterLengths_CXK = Sequence<4, 1, 32>;
    using WeiBlockCopyClusterLengths = Sequence<4, 1, 32>;
    constexpr index_t WeiBlockCopyDataPerRead_C = 4;
    constexpr index_t WeiBlockCopyDataPerRead_K = 4;
    constexpr index_t OutThreadCopyDataPerWrite_N = 2;
#elif 0
    // for 3x3, 28x28, v1r3, Pascal, bad
    constexpr index_t BlockSize = 128;
    constexpr index_t NPerBlock = 16;
    constexpr index_t KPerBlock = 128;
    constexpr index_t CPerBlock = 8;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 2;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 2;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 2;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;
    using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 1>;
    using InBlockReorderSrcClusterLengths_NCHW = Sequence<4, 8, 2, 2>;
    using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
    constexpr index_t InBlockReorderDataPerRead_W = 1;  // v1r3 cannot do vector load input for NCHW
    constexpr index_t InBlockReorderDataPerWrite_N = 1; // not used yet
    using WeiBlockCopyClusterLengths = Sequence<0, 0>;  // not used
    constexpr index_t WeiBlockCopyDataPerRead_K = 4;
    constexpr index_t OutThreadCopyDataPerWrite_N = 2;
#elif 1
    // for 3x3, 34x34, v1r3, Pascal
    constexpr index_t BlockSize = 128;
    constexpr index_t NPerBlock = 2;
    constexpr index_t KPerBlock = 128;
    constexpr index_t CPerBlock = 8;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 16;
    constexpr index_t NPerThread = 2;
    constexpr index_t KPerThread = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 4;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 2;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;
    using InBlockReorderSrcSubLengths_NCHW = Sequence<2, 1, 2, 1>;
    using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 8, 1, 16>;
    using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
    constexpr index_t InBlockReorderDataPerRead_W = 1;  // v1r3 cannot do vector load input for NCHW
    constexpr index_t InBlockReorderDataPerWrite_N = 1; // not used yet
    using WeiBlockCopyClusterLengths = Sequence<0, 0>;  // not used
    constexpr index_t WeiBlockCopyDataPerRead_K = 4;
    constexpr index_t OutThreadCopyDataPerWrite_N = 2;
#endif

@@ -108,8 +180,12 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
    for(index_t i = 0; i < nrepeat; ++i)
    {
        constexpr auto gridwise_conv =
#if 1
#if 0
            GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
#elif 1
            GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
#elif 1
            GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
#endif
            <GridSize,
             BlockSize,

@@ -140,8 +216,8 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
             InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
             InBlockReorderDataPerRead_W,
             InBlockReorderDataPerWrite_N,
             WeiBlockCopyClusterLengths_CXK,
             WeiBlockCopyClusterLengths,
             WeiBlockCopyDataPerRead_C,
             WeiBlockCopyDataPerRead_K,
             OutThreadCopyDataPerWrite_N>{};

        float time = launch_kernel(run_gridwise_convolution<decltype(gridwise_conv), T>,
driver/driver.hip.cpp

@@ -46,7 +46,7 @@ struct GeneratorTensor_3
#if 0
    auto f_acc = std::plus<index_t>{};
#else
    auto f_acc = [](auto a, auto b) { return 10 * a + b; };
    auto f_acc = [](auto a, auto b) { return 100 * a + b; };
#endif
    return std::accumulate(dims.begin(), dims.end(), index_t(0), f_acc);

@@ -390,8 +390,6 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
template <class T>
void check_error(const Tensor<T>& ref, const Tensor<T>& result)
{
    // printf("\n");
    float error = 0;
    float max_diff = -1;
    float ref_value = 0, result_value = 0;

@@ -405,10 +403,7 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
            ref_value = ref.mData[i];
            result_value = result.mData[i];
        }
        // printf("{%f, %f}", double(ref.mData[i]), double(result.mData[i]));
    }
    // printf("\n");
    std::cout << "error: " << error << std::endl;
    std::cout << "max_diff: " << max_diff << ", " << ref_value << ", " << result_value << std::endl;

@@ -416,11 +411,12 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
int main(int argc, char* argv[])
{
#if 0
    constexpr index_t N = 128;
    constexpr index_t C = 8;
    constexpr index_t HI = 28;
    constexpr index_t WI = 28;
#if 1
    // 3x3, 34x34
    constexpr index_t N = 64;
    constexpr index_t C = 256;
    constexpr index_t HI = 34;
    constexpr index_t WI = 34;
    constexpr index_t K = 128;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;

@@ -428,27 +424,15 @@ int main(int argc, char* argv[])
    constexpr index_t HPad = 0;
    constexpr index_t WPad = 0;
#elif 0
    // 3x3, 34x34
    // 3x3, 56x56
    constexpr index_t N = 64;
    constexpr index_t C = 256;
    constexpr index_t C = 64;
    constexpr index_t HI = 34;
    constexpr index_t HI = 56;
    constexpr index_t WI = 34;
    constexpr index_t WI = 56;
    constexpr index_t K = 128;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;
    constexpr index_t HPad = 0;
    constexpr index_t WPad = 0;
#elif 0
    // 3x3, 56x56
    constexpr index_t N = 64;
    constexpr index_t C = 64;
    constexpr index_t HI = 56;
    constexpr index_t WI = 56;
    constexpr index_t K = 128;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;
    constexpr index_t HPad = 0;
    constexpr index_t WPad = 0;
#elif 0

@@ -499,7 +483,7 @@ int main(int argc, char* argv[])
    constexpr index_t HPad = 1;
    constexpr index_t WPad = 1;
#elif 1
#elif 0
    // 5x5 filter, 20x86 image
    constexpr index_t N = 16;
    constexpr index_t C = 256;

@@ -547,7 +531,7 @@ int main(int argc, char* argv[])
    constexpr index_t HPad = 0;
    constexpr index_t WPad = 0;
#elif 1
#elif 0
    // 1x1 filter, 14x14 image
    constexpr index_t N = 128;
    constexpr index_t C = 512;

@@ -619,9 +603,9 @@ int main(int argc, char* argv[])
    device_direct_convolution_2_nchw_kcyx_nkhw
#elif 0
    device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
#elif 1
    device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
#elif 0
    device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
#elif 1
    device_convolution_implicit_gemm_v1_nchw_cyxk_khwn
#elif 0
    device_convolution_implicit_gemm_v2_chwn_cyxk_khwn
src/include/Array.hip.hpp

@@ -19,6 +19,20 @@ struct Array
    __host__ __device__ const TData& operator[](index_t i) const { return mData[i]; }

    __host__ __device__ TData& operator[](index_t i) { return mData[i]; }

    __host__ __device__ auto PushBack(TData x) const
    {
        Array<TData, NSize + 1> new_array;

        static_for<0, NSize, 1>{}([=](auto I) {
            constexpr index_t i = I.Get();
            new_array[i] = mData[i];
        });

        new_array[NSize] = x;

        return new_array;
    }
};

template <class TData, index_t NSize, index_t... IRs>

@@ -51,4 +65,4 @@ __host__ __device__ auto reorder_array_given_old2new(const Array<TData, NSize>&
    });

    return new_array;
}
\ No newline at end of file
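Not part of the commit, only a minimal sketch of how the new PushBack might be used; the aggregate initialization mirrors the Array<index_t, 1>{i} style used by ford in the new functional2.hip.hpp:

// illustrative only
Array<index_t, 2> a{1, 2};
auto b = a.PushBack(3); // b is Array<index_t, 3> holding {1, 2, 3}; a is unchanged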
src/include/ConstantTensorDescriptor.hip.hpp

#pragma once
#include "common.hip.hpp"

// this is ugly, only for 2d
template <index_t L0, index_t L1>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1>)
{
    return Sequence<L1, 1>{};
}

// this is ugly, only for 3d
template <index_t L0, index_t L1, index_t L2>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2>)
{
    return Sequence<L1 * L2, L2, 1>{};
}

// this is ugly, only for 4d
template <index_t L0, index_t L1, index_t L2, index_t L3>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3>)
{
    return Sequence<L1 * L2 * L3, L2 * L3, L3, 1>{};
}

// this is ugly, only for 6d
template <index_t L0, index_t L1, index_t L2, index_t L3, index_t L4, index_t L5>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5>)
{
    return Sequence<L1 * L2 * L3 * L4 * L5, L2 * L3 * L4 * L5, L3 * L4 * L5, L4 * L5, L5, 1>{};
}

// this is ugly, only for 8d
template <index_t L0, index_t L1, index_t L2, index_t L3, index_t L4, index_t L5, index_t L6, index_t L7>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5, L6, L7>)
{
    return Sequence<L1 * L2 * L3 * L4 * L5 * L6 * L7,
                    L2 * L3 * L4 * L5 * L6 * L7,
                    L3 * L4 * L5 * L6 * L7,
                    L4 * L5 * L6 * L7,
                    L5 * L6 * L7,
                    L6 * L7,
                    L7,
                    1>{};
}

// this is ugly, only for 8d
template <index_t L0, index_t L1, index_t L2, index_t L3, index_t L4, index_t L5, index_t L6, index_t L7, index_t L8, index_t L9>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5, L6, L7, L8, L9>)
{
    return Sequence<L1 * L2 * L3 * L4 * L5 * L6 * L7 * L8 * L9,
                    L2 * L3 * L4 * L5 * L6 * L7 * L8 * L9,
                    L3 * L4 * L5 * L6 * L7 * L8 * L9,
                    L4 * L5 * L6 * L7 * L8 * L9,
                    L5 * L6 * L7 * L8 * L9,
                    L6 * L7 * L8 * L9,
                    L7 * L8 * L9,
                    L8 * L9,
                    L9,
                    1>{};
}

template <class PreviousStrides, class RemainLengths>
__host__ __device__ constexpr auto calculate_default_strides_impl(PreviousStrides, RemainLengths)
{
    constexpr index_t previous_stride = PreviousStrides{}.Front();
    constexpr index_t current_length = RemainLengths{}.Back();
    constexpr index_t current_stride = current_length * previous_stride;

    return calculate_default_strides_impl(PreviousStrides{}.PushFront(Number<current_stride>{}),
                                          RemainLengths{}.PopBack());
}

template <class PreviousStrides, index_t L0, index_t L1>
__host__ __device__ constexpr auto calculate_default_strides_impl(PreviousStrides, Sequence<L0, L1>)
{
    constexpr index_t previous_stride = PreviousStrides{}.Front();
    constexpr index_t current_stride = L1 * previous_stride;

    return PreviousStrides{}.PushFront(Number<current_stride>{});
}

template <class Lengths>
__host__ __device__ constexpr auto calculate_default_strides(Lengths)
{
    return calculate_default_strides_impl(Sequence<1>{}, Lengths{});
}

// this is ugly, only for 2d

@@ -186,6 +136,14 @@ struct ConstantTensorDescriptor
        return Get1dIndex(multi_id);
    }

    template <index_t... Is>
    __host__ __device__ static constexpr index_t Get1dIndex(Sequence<Is...> multi_id)
    {
        static_assert(sizeof...(Is) == nDim, "wrong! Dimension not consistent");

        return Get1dIndex(Is...);
    }

    __host__ __device__ static Array<index_t, nDim> GetMultiIndex(index_t id)
    {
        Array<index_t, nDim> multi_id;
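For orientation only (not part of the commit): with either spelling shown above, the packed strides follow the usual row-major rule, each stride being the product of the lengths to its right. A minimal illustration:

// illustrative only
constexpr auto strides = calculate_default_strides(Sequence<8, 3, 34, 34>{});
// yields Sequence<3 * 34 * 34, 34 * 34, 34, 1> = Sequence<3468, 1156, 34, 1>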
src/include/Sequence.hip.hpp

@@ -34,11 +34,15 @@ struct Sequence
    template <index_t... IRs>
    __host__ __device__ constexpr auto ReorderGivenOld2New(Sequence<IRs...> /*old2new*/) const
    {
        // don't know how to implement this
        // TODO: don't know how to implement this
        printf("Sequence::ReorderGivenOld2New not implemented");
        assert(false);
    }

    __host__ __device__ constexpr index_t Front() const { return mData[0]; }

    __host__ __device__ constexpr index_t Back() const { return mData[mSize - 1]; }

    template <index_t I>
    __host__ __device__ constexpr auto PushFront(Number<I>) const
    {

@@ -69,15 +73,98 @@ __host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
    return Sequence<Is...>{};
}

template <index_t... Is, index_t I>
#if 0
// TODO: for some reason, compiler cannot instantiate this template
template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<Is..., I>)
{
    static_assert(sizeof...(Is) > 0, "empty Sequence!");
    return Sequence<Is...>{};
}
#else
// TODO: delete these very ugly mess
template <index_t I0, index_t I1>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1>)
{
    return Sequence<I0>{};
}

template <index_t I0, index_t I1, index_t I2>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2>)
{
    return Sequence<I0, I1>{};
}

template <index_t I0, index_t I1, index_t I2, index_t I3>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3>)
{
    return Sequence<I0, I1, I2>{};
}

template <index_t I0, index_t I1, index_t I2, index_t I3, index_t I4>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4>)
{
    return Sequence<I0, I1, I2, I3>{};
}

template <index_t I0, index_t I1, index_t I2, index_t I3, index_t I4, index_t I5>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5>)
{
    return Sequence<I0, I1, I2, I3, I4>{};
}

template <index_t I0, index_t I1, index_t I2, index_t I3, index_t I4, index_t I5, index_t I6>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5, I6>)
{
    return Sequence<I0, I1, I2, I3, I4, I5>{};
}

template <index_t I0, index_t I1, index_t I2, index_t I3, index_t I4, index_t I5, index_t I6, index_t I7>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5, I6, I7>)
{
    return Sequence<I0, I1, I2, I3, I4, I5, I6>{};
}

template <index_t I0, index_t I1, index_t I2, index_t I3, index_t I4, index_t I5, index_t I6, index_t I7, index_t I8>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5, I6, I7, I8>)
{
    return Sequence<I0, I1, I2, I3, I4, I5, I6, I7>{};
}

template <index_t I0, index_t I1, index_t I2, index_t I3, index_t I4, index_t I5, index_t I6, index_t I7, index_t I8, index_t I9>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5, I6, I7, I8, I9>)
{
    return Sequence<I0, I1, I2, I3, I4, I5, I6, I7, I8>{};
}
#endif

#if 1
// this is ugly, only for 2 sequences
// TODO: fix these mess
template <class F, index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>)
{

@@ -86,7 +173,6 @@ __host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Seq
    return Sequence<f(Xs, Ys)...>{};
}

// this is ugly, only for 3 sequences
template <class F, index_t... Xs, index_t... Ys, index_t... Zs>
__host__ __device__ constexpr auto
transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)

@@ -98,6 +184,7 @@ transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
    return Sequence<f(Xs, Ys, Zs)...>{};
}
#else
// TODO:: these doesn't compile
template <index_t NRemain>
struct transform_sequences_impl
{
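Not from the commit, a small sketch of how the new Sequence helpers compose; a constexpr functor stands in for the lambda because the codebase targets C++14:

// illustrative only
struct add_f
{
    __host__ __device__ constexpr index_t operator()(index_t x, index_t y) const { return x + y; }
};

constexpr auto s = Sequence<2, 3, 4>{};
constexpr index_t front = s.Front();     // 2
constexpr index_t back  = s.Back();      // 4
constexpr auto t = sequence_pop_back(s); // Sequence<2, 3>
constexpr auto u = transform_sequences(add_f{}, Sequence<1, 2>{}, Sequence<10, 20>{}); // Sequence<11, 22>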
src/include/blockwise_4d_tensor_op.hip.hpp

#pragma once
#include "ConstantTensorDescriptor.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp"

template <index_t BlockSize, class Float, class DstDesc, class F>
__device__ void

@@ -957,6 +958,7 @@ struct Blockwise4dTensorCopyReorder3
        constexpr auto thread_sub_tensor_desc =
            make_ConstantTensorDescriptor(SrcClusterLengths{}, thread_tensor_desc.GetStrides());

#if 1
        for(index_t icluster_d0 = 0; icluster_d0 < cluster_per_dims.Get(I0); ++icluster_d0)
        {
            for(index_t icluster_d1 = 0; icluster_d1 < cluster_per_dims.Get(I1); ++icluster_d1)

@@ -978,16 +980,21 @@ struct Blockwise4dTensorCopyReorder3
                        icluster_d2 * thread_sub_tensor_lengths.Get(I2),
                        icluster_d3 * thread_sub_tensor_lengths.Get(I3));

                    threadwise_4d_tensor_copy_v2(SrcDesc{},
                    threadwise_nd_tensor_copy(SrcDesc{},
                                              p_src + src_offset + mSrcMyThreadOffset,
                                              thread_tensor_desc,
                                              p_clipboard + clipboard_offset,
                                              thread_sub_tensor_lengths,
                                              Number<SrcDataPerRead>{});
                }
            }
        }
    }
#else
        static_ford<decltype(cluster_per_dims)>{}([=](auto cluster_ids) {

        });
#endif

#if 0
        if(get_block_1d_id() == 0)
src/include/blockwise_batched_gemm.hip.hpp

@@ -253,9 +253,23 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
#if 0
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
        {
            printf("a: %f %f %f %f %f %f %f %f, b: %f %f %f %f %f %f %f %f\n",
                p_a_thread[0], p_a_thread[1], p_a_thread[2], p_a_thread[3], p_a_thread[4], p_a_thread[5], p_a_thread[6], p_a_thread[7],
                p_b_thread[0], p_b_thread[1], p_b_thread[2], p_b_thread[3], p_b_thread[4], p_b_thread[5], p_b_thread[6], p_b_thread[7]);
                p_a_thread[0],
                p_a_thread[1],
                p_a_thread[2],
                p_a_thread[3],
                p_a_thread[4],
                p_a_thread[5],
                p_a_thread[6],
                p_a_thread[7],
                p_b_thread[0],
                p_b_thread[1],
                p_b_thread[2],
                p_b_thread[3],
                p_b_thread[4],
                p_b_thread[5],
                p_b_thread[6],
                p_b_thread[7]);
        }
#endif
src/include/common.hip.hpp

@@ -4,6 +4,7 @@
#include "Sequence.hip.hpp"
#include "Array.hip.hpp"
#include "functional.hip.hpp"
#include "functional2.hip.hpp"

#if DEVICE_BACKEND_HIP
#include "amd_inline_asm.hip.hpp"
src/include/functional.hip.hpp

@@ -21,7 +21,7 @@ struct static_for_impl<Iter, 0, Increment>
    template <class F>
    __host__ __device__ void operator()(F) const
    {
        // do nothing
        // no work left, just return
        return;
    }
};

@@ -48,7 +48,7 @@ struct static_const_reduce_n
        static_assert(NLoop > 1, "out-of-range");
        constexpr auto a = f(Number<NLoop - 1>{});
        auto b = static_const_reduce_n<NLoop - 1>{}(f, r); // cannot use constexpr here, weird
        auto b = static_const_reduce_n<NLoop - 1>{}(f, r); // TODO: cannot use constexpr here, weird
        return r(a, b);
    }
};

@@ -70,3 +70,65 @@ __host__ __device__ constexpr auto unpacker(F f)
    return [=](auto xs_array){ f(xs...); };
}
#endif

struct forwarder
{
    template <typename T>
    __host__ __device__ constexpr T operator()(T&& x) const
    {
        return std::forward<T>(x);
    }
};

// Emulate compile time if statement for C++14
// Get the idea from
//   "https://baptiste-wicht.com/posts/2015/07/simulate-static_if-with-c11c14.html"
// TODO: use if constexpr, when C++17 is supported
template <bool Predicate>
struct static_if
{
};

template <>
struct static_if<true>
{
    using Type = static_if<true>;

    template <class F>
    __host__ __device__ constexpr auto operator()(F f) const
    {
        // This is a trick for compiler:
        //   Pass forwarder to lambda "f" as "auto" argument, and maks sure "f" will use it,
        //   this will make "f" a generic lambda, so that "f" won't be compiled until here
        f(forwarder{});
        return Type{};
    }

    template <class F>
    __host__ __device__ static constexpr auto else_(F)
    {
        return Type{};
    }
};

template <>
struct static_if<false>
{
    using Type = static_if<false>;

    template <class F>
    __host__ __device__ constexpr auto operator()(F) const
    {
        return Type{};
    }

    template <class F>
    __host__ __device__ static constexpr auto else_(F f)
    {
        // This is a trick for compiler:
        //   Pass forwarder to lambda "f" as "auto" argument, and maks sure "f" will use it,
        //   this will make "f" a generic lambda, so that "f" won't be compiled until here
        f(forwarder{});
        return Type{};
    }
};
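Not part of the diff, a minimal sketch of the intended call pattern for this static_if: exactly one of the two generic lambdas is instantiated, selected by the compile-time predicate, and branch-only names should be routed through the forwarder argument so they stay dependent:

// illustrative only
template <bool kCondition>
__host__ __device__ void static_if_example()
{
    static_if<kCondition>{}([&](auto fwd) {
        // instantiated only when kCondition == true; wrap branch-only calls in fwd(...)
    }).else_([&](auto fwd) {
        // instantiated only when kCondition == false
    });
}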
src/include/functional2.hip.hpp (new file, 0 → 100644)

#pragma once
#include "Sequence.hip.hpp"

template <index_t RemainDim>
struct static_ford_impl
{
    // F signature: F(Sequence<...> multi_id)
    // CurrentMultiIndex: Sequence<...>
    // RemainLengths: Sequence<...>
    template <class F, class CurrentMultiIndex, class RemainLengths>
    __host__ __device__ void operator()(F f, CurrentMultiIndex, RemainLengths) const
    {
        static_assert(RemainLengths::GetSize() == RemainDim, "wrong!");
        static_assert(RemainDim > 1, "wrong!");

        constexpr auto next_length = RemainLengths{}.Front();

        static_for<0, next_length, 1>{}([=](auto I) {
            static_ford_impl<RemainDim - 1>{}(
                f, CurrentMultiIndex{}.PushBack(I), RemainLengths{}.PopFront());
        });
    }
};

template <>
struct static_ford_impl<1>
{
    // F signature: F(Sequence<Is...> multi_id)
    // CurrentMultiIndex: Sequence<...>
    // RemainLengths: Sequence<...>
    template <class F, class CurrentMultiIndex, class RemainLengths>
    __host__ __device__ void operator()(F f, CurrentMultiIndex, RemainLengths) const
    {
        static_assert(RemainLengths::GetSize() == 1, "wrong!");

        constexpr index_t last_length = RemainLengths{}.Front();

        static_for<0, last_length, 1>{}([=](auto I) { f(CurrentMultiIndex{}.PushBack(I)); });
    }
};

// Lengths is Sequence<...>
template <class Lengths>
struct static_ford
{
    // F signature: F(Sequence<Is...> multi_id)
    template <class F>
    __host__ __device__ void operator()(F f) const
    {
        constexpr index_t first_length = Lengths{}.Front();

        static_for<0, first_length, 1>{}([=](auto I) {
            static_ford_impl<Lengths::GetSize() - 1>{}(
                f, Sequence<I.Get()>{}, Lengths{}.PopFront());
        });
    }
};

template <index_t RemainDim>
struct ford_impl
{
    // F signature: F(Array<...> multi_id)
    // CurrentMultiIndex: Array<...>
    // RemainLengths: Sequence<...>
    template <class F, class CurrentMultiIndex, class RemainLengths>
    __host__ __device__ void operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const
    {
        static_assert(RemainLengths::GetSize() == RemainDim, "wrong!");
        static_assert(RemainDim > 1, "wrong!");

        constexpr auto next_length = RemainLengths{}.Front();

        for(index_t i = 0; i < next_length; ++i)
        {
            ford_impl<RemainDim - 1>{}(f, current_multi_id.PushBack(i), RemainLengths{}.PopFront());
        }
    }
};

template <>
struct ford_impl<1>
{
    // F signature: F(Array<...> multi_id)
    // CurrentMultiIndex: Array<...>
    // RemainLengths: Sequence<...>
    template <class F, class CurrentMultiIndex, class RemainLengths>
    __host__ __device__ void operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const
    {
        static_assert(RemainLengths::GetSize() == 1, "wrong!");

        constexpr index_t last_length = RemainLengths{}.Front();

        for(index_t i = 0; i < last_length; ++i)
        {
            f(current_multi_id.PushBack(i));
        }
    }
};

// Lengths is Sequence<...>
template <class Lengths>
struct ford
{
    // F signature: F(Array<...> multi_id)
    template <class F>
    __host__ __device__ void operator()(F f) const
    {
        constexpr index_t first_length = Lengths{}.Front();

        for(index_t i = 0; i < first_length; ++i)
        {
            ford_impl<Lengths::GetSize() - 1>{}(f, Array<index_t, 1>{i}, Lengths{}.PopFront());
        }
    }
};
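Not from the commit, a minimal sketch of the call pattern these helpers expect, following the "F signature" comments above; the Get1dIndex(Sequence<...>) overload added to ConstantTensorDescriptor in this commit is one natural consumer of the Sequence-typed multi-index:

// illustrative only
__host__ __device__ void ford_example()
{
    // compile-time 2x3 loop nest: f receives Sequence<i, j> for every (i, j)
    static_ford<Sequence<2, 3>>{}([](auto multi_id) {
        // e.g. feed multi_id to SomeTensorDesc{}.Get1dIndex(multi_id)
    });

    // run-time counterpart: f receives Array<index_t, 2>
    ford<Sequence<2, 3>>{}([](auto multi_id) {
        index_t i = multi_id[0];
        index_t j = multi_id[1];
        (void)i;
        (void)j;
    });
}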
src/include/gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hip.hpp (diff collapsed)
src/include/gridwise_convolution_implicit_gemm_v1r1_lds_double_buffer_chwn_cyxk_khwn.hip.hpp
deleted
100644 → 0
View file @
87d8740b
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_4d_tensor_op.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_batched_gemm.hip.hpp"
template
<
index_t
GridSize
,
index_t
BlockSize
,
class
Float
,
class
InGlobalDesc
,
class
WeiGlobalDesc
,
class
OutGlobalDesc
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
CPerBlock
,
index_t
HoPerBlock
,
index_t
WoPerBlock
,
index_t
NPerThread
,
index_t
KPerThread
,
index_t
HoPerThread
,
index_t
WoPerThread
,
index_t
GemmMPerThreadSubC
,
index_t
GemmNPerThreadSubC
,
index_t
GemmMLevel0Cluster
,
index_t
GemmNLevel0Cluster
,
index_t
GemmMLevel1Cluster
,
index_t
GemmNLevel1Cluster
,
index_t
GemmKPerThreadLoop
,
index_t
GemmDataPerReadA
,
index_t
GemmDataPerReadB
,
class
InBlockCopyThreadPerDims
,
index_t
InBlockCopyDataPerRead
,
index_t
WeiBlockCopyDataPerRead
,
index_t
OutThreadCopyDataPerWrite
>
struct
GridwiseConvolutionImplicitGemm_v1r1_lds_double_buffer_chwn_cyxk_khwn
{
__device__
void
Run
(
const
Float
*
const
__restrict__
p_in_global
,
const
Float
*
const
__restrict__
p_wei_global
,
Float
*
const
__restrict__
p_out_global
)
const
{
// be careful of this assertion
static_assert
(
NPerThread
<=
NPerBlock
&&
NPerBlock
%
NPerThread
==
0
,
"wrong! should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
in_chwn_global_desc
=
InGlobalDesc
{};
constexpr
auto
wei_cyxk_global_desc
=
WeiGlobalDesc
{};
constexpr
auto
out_khwn_global_desc
=
OutGlobalDesc
{};
constexpr
index_t
C
=
in_chwn_global_desc
.
GetLength
(
I0
);
constexpr
index_t
K
=
out_khwn_global_desc
.
GetLength
(
I0
);
constexpr
index_t
Ho
=
out_khwn_global_desc
.
GetLength
(
I1
);
constexpr
index_t
Wo
=
out_khwn_global_desc
.
GetLength
(
I2
);
constexpr
index_t
N
=
out_khwn_global_desc
.
GetLength
(
I3
);
constexpr
index_t
Y
=
wei_cyxk_global_desc
.
GetLength
(
I1
);
constexpr
index_t
X
=
wei_cyxk_global_desc
.
GetLength
(
I2
);
constexpr
index_t
HiPerBlock
=
HoPerBlock
+
Y
-
1
;
constexpr
index_t
WiPerBlock
=
WoPerBlock
+
X
-
1
;
// assert for LDS double buffer
static_assert
(
C
%
(
2
*
CPerBlock
)
==
0
,
"C cannot be evenly divided"
);
// divide block work: [K, Ho, Wo, N]
static_assert
(
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
C
%
CPerBlock
==
0
&&
Ho
%
HoPerBlock
==
0
&&
Wo
%
WoPerBlock
==
0
,
"wrong! cannot evenly divide work for workgroup "
);
constexpr
index_t
KBlockWork
=
(
K
+
KPerBlock
-
1
)
/
KPerBlock
;
constexpr
index_t
HBlockWork
=
(
Ho
+
HoPerBlock
-
1
)
/
HoPerBlock
;
constexpr
index_t
WBlockWork
=
(
Wo
+
WoPerBlock
-
1
)
/
WoPerBlock
;
constexpr
index_t
NBlockWork
=
(
N
+
NPerBlock
-
1
)
/
NPerBlock
;
const
index_t
k_block_work_id
=
get_block_1d_id
()
/
(
HBlockWork
*
WBlockWork
*
NBlockWork
);
index_t
itmp
=
get_block_1d_id
()
-
k_block_work_id
*
(
HBlockWork
*
WBlockWork
*
NBlockWork
);
const
index_t
h_block_work_id
=
itmp
/
(
WBlockWork
*
NBlockWork
);
itmp
-=
h_block_work_id
*
(
WBlockWork
*
NBlockWork
);
const
index_t
w_block_work_id
=
itmp
/
NBlockWork
;
const
index_t
n_block_work_id
=
itmp
-
w_block_work_id
*
NBlockWork
;
const
index_t
k_block_data_begin
=
k_block_work_id
*
KPerBlock
;
const
index_t
ho_block_data_begin
=
h_block_work_id
*
HoPerBlock
;
const
index_t
wo_block_data_begin
=
w_block_work_id
*
WoPerBlock
;
const
index_t
n_block_data_begin
=
n_block_work_id
*
NPerBlock
;
const
index_t
hi_block_data_begin
=
ho_block_data_begin
;
const
index_t
wi_block_data_begin
=
wo_block_data_begin
;
// flattend (2d) tensor view of gridwise weight
constexpr
auto
wei_ek_global_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
C
*
Y
*
X
,
K
>
{});
// tensor view of blockwise input and weight in LDS
// be careful of alignment
constexpr
index_t
max_align
=
mod_conv
::
max
(
index_t
(
4
),
InBlockCopyDataPerRead
,
WeiBlockCopyDataPerRead
);
constexpr
auto
in_chwn_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
,
HiPerBlock
,
WiPerBlock
,
NPerBlock
>
{},
Number
<
max_align
>
{});
constexpr
auto
wei_ek_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
*
Y
*
X
,
KPerBlock
>
{},
Number
<
max_align
>
{});
constexpr
auto
wei_cyxk_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
,
Y
,
X
,
KPerBlock
>
{},
Number
<
max_align
>
{});
// tensor view of threadwise output in register
constexpr
auto
out_khwn_thread_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
KPerThread
,
HoPerThread
,
WoPerThread
,
NPerThread
>
{});
// blockwise copy
// input: format is [C, Hi, Wi, N]
const
auto
blockwise_in_copy
=
Blockwise4dTensorCopy3
<
BlockSize
,
Float
,
decltype
(
in_chwn_global_desc
),
decltype
(
in_chwn_block_desc
),
decltype
(
in_chwn_block_desc
.
GetLengths
()),
InBlockCopyThreadPerDims
,
InBlockCopyDataPerRead
>
{};
// blockwise wei copy
// format is [CPerBlock*Y*X,KPerBlock]
const
auto
blockwise_wei_copy
=
Blockwise2dTensorCopy3
<
BlockSize
,
Float
,
decltype
(
wei_ek_global_desc
),
decltype
(
wei_ek_block_desc
),
decltype
(
wei_ek_block_desc
.
GetLengths
()),
WeiBlockCopyDataPerRead
>
{};
// a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
// A_matrix[C,K] is a sub-matrix of wei_block[C,Y,X,K]
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
constexpr
auto
a_cxk_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
CPerBlock
>
{},
Number
<
KPerBlock
>
{},
Number
<
wei_cyxk_block_desc
.
GetStride
(
I0
)
>
{});
constexpr
auto
b_cxwn_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
CPerBlock
>
{},
Number
<
WoPerBlock
*
NPerBlock
>
{},
Number
<
in_chwn_block_desc
.
GetStride
(
I0
)
>
{});
constexpr
auto
c_kxwn_thread_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThread
>
{},
Number
<
WoPerThread
*
NPerThread
>
{},
Number
<
out_khwn_thread_desc
.
GetStride
(
I0
)
>
{});
const
auto
blockwise_batch_gemm
=
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
<
BlockSize
,
decltype
(
a_cxk_block_mtx_desc
),
decltype
(
b_cxwn_block_mtx_desc
),
decltype
(
c_kxwn_thread_mtx_desc
),
0
,
in_chwn_block_desc
.
GetStride
(
I1
),
out_khwn_thread_desc
.
GetStride
(
I1
),
HoPerBlock
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmKPerThreadLoop
,
HoPerThread
,
GemmDataPerReadA
,
GemmDataPerReadB
>
{};
// LDS: be careful of alignment
constexpr
index_t
in_block_space
=
in_chwn_block_desc
.
GetElementSpace
(
Number
<
max_align
>
{});
constexpr
index_t
wei_block_space
=
wei_cyxk_block_desc
.
GetElementSpace
(
Number
<
max_align
>
{});
// LDS double buffer
__shared__
Float
p_in_block_double
[
2
*
in_block_space
];
__shared__
Float
p_wei_block_double
[
2
*
wei_block_space
];
const
Float
*
p_in_global_block_offset
=
p_in_global
+
in_chwn_global_desc
.
Get1dIndex
(
0
,
hi_block_data_begin
,
wi_block_data_begin
,
n_block_data_begin
);
const
Float
*
p_wei_global_block_offset
=
p_wei_global
+
wei_cyxk_global_desc
.
Get1dIndex
(
0
,
0
,
0
,
k_block_data_begin
);
// preload data into LDS
{
Float
p_in_register_clipboard
[
blockwise_in_copy
.
GetRegisterClipboardSize
()];
Float
p_wei_register_clipboard
[
blockwise_wei_copy
.
GetRegisterClipboardSize
()];
blockwise_in_copy
.
RunLoadRegisterClipboard
(
p_in_global_block_offset
,
p_in_register_clipboard
);
blockwise_wei_copy
.
RunLoadRegisterClipboard
(
p_wei_global_block_offset
,
p_wei_register_clipboard
);
blockwise_in_copy
.
RunStoreRegisterClipboard
(
p_in_register_clipboard
,
p_in_block_double
);
blockwise_wei_copy
.
RunStoreRegisterClipboard
(
p_wei_register_clipboard
,
p_wei_block_double
);
}
// register
Float
p_out_thread
[
out_khwn_thread_desc
.
GetElementSpace
()];
// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero
(
out_khwn_thread_desc
,
p_out_thread
);
for
(
index_t
c_block_data_begin
=
0
;
c_block_data_begin
+
2
*
CPerBlock
<
C
;
c_block_data_begin
+=
2
*
CPerBlock
)
{
#pragma unroll
for
(
index_t
iloop
=
0
;
iloop
<
2
;
++
iloop
)
{
const
bool
even_loop
=
(
iloop
%
2
==
0
);
Float
*
p_in_block_now
=
even_loop
?
p_in_block_double
:
p_in_block_double
+
in_block_space
;
Float
*
p_wei_block_now
=
even_loop
?
p_wei_block_double
:
p_wei_block_double
+
wei_block_space
;
Float
*
p_in_block_next
=
even_loop
?
p_in_block_double
+
in_block_space
:
p_in_block_double
;
Float
*
p_wei_block_next
=
even_loop
?
p_wei_block_double
+
wei_block_space
:
p_wei_block_double
;
// load next data
Float
p_in_register_clipboard
[
blockwise_in_copy
.
GetRegisterClipboardSize
()];
Float
p_wei_register_clipboard
[
blockwise_wei_copy
.
GetRegisterClipboardSize
()];
p_in_global_block_offset
+=
CPerBlock
*
in_chwn_global_desc
.
GetStride
(
I0
);
p_wei_global_block_offset
+=
CPerBlock
*
wei_cyxk_global_desc
.
GetStride
(
I0
);
__syncthreads
();
blockwise_in_copy
.
RunLoadRegisterClipboard
(
p_in_global_block_offset
,
p_in_register_clipboard
);
blockwise_wei_copy
.
RunLoadRegisterClipboard
(
p_wei_global_block_offset
,
p_wei_register_clipboard
);
// a series of batched GEMM
for
(
index_t
y
=
0
;
y
<
Y
;
++
y
)
{
for
(
index_t
x
=
0
;
x
<
X
;
++
x
)
{
blockwise_batch_gemm
.
Run
(
p_wei_block_now
+
wei_cyxk_block_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
p_in_block_now
+
in_chwn_block_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
p_out_thread
);
}
}
blockwise_in_copy
.
RunStoreRegisterClipboard
(
p_in_register_clipboard
,
p_in_block_next
);
blockwise_wei_copy
.
RunStoreRegisterClipboard
(
p_wei_register_clipboard
,
p_wei_block_next
);
}
}
// tail
{
// even
p_in_global_block_offset
+=
CPerBlock
*
in_chwn_global_desc
.
GetStride
(
I0
);
p_wei_global_block_offset
+=
CPerBlock
*
wei_cyxk_global_desc
.
GetStride
(
I0
);
__syncthreads
();
Float
p_in_register_clipboard
[
blockwise_in_copy
.
GetRegisterClipboardSize
()];
Float
p_wei_register_clipboard
[
blockwise_wei_copy
.
GetRegisterClipboardSize
()];
blockwise_in_copy
.
RunLoadRegisterClipboard
(
p_in_global_block_offset
,
p_in_register_clipboard
);
blockwise_wei_copy
.
RunLoadRegisterClipboard
(
p_wei_global_block_offset
,
p_wei_register_clipboard
);
for
(
index_t
y
=
0
;
y
<
Y
;
++
y
)
{
for
(
index_t
x
=
0
;
x
<
X
;
++
x
)
{
blockwise_batch_gemm
.
Run
(
p_wei_block_double
+
wei_cyxk_block_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
p_in_block_double
+
in_chwn_block_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
p_out_thread
);
}
}
blockwise_in_copy
.
RunStoreRegisterClipboard
(
p_in_register_clipboard
,
p_in_block_double
+
in_block_space
);
blockwise_wei_copy
.
RunStoreRegisterClipboard
(
p_wei_register_clipboard
,
p_wei_block_double
+
wei_block_space
);
// odd
__syncthreads
();
for
(
index_t
y
=
0
;
y
<
Y
;
++
y
)
{
for
(
index_t
x
=
0
;
x
<
X
;
++
x
)
{
blockwise_batch_gemm
.
Run
(
p_wei_block_double
+
wei_block_space
+
wei_cyxk_block_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
p_in_block_double
+
in_block_space
+
in_chwn_block_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
p_out_thread
);
}
}
}
        // output: register to global mem,
#if 0
        const auto c_thread_mtx_begin =
            blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        for(index_t k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k)
        {
            for(index_t ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho)
            {
                for(index_t wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo)
                {
                    for(index_t n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n)
                    {
                        const index_t b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n);

                        const auto c_thread_mtx_distance =
                            blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b);

                        const index_t ho_thread =
                            c_thread_mtx_begin.batch + c_thread_mtx_distance.batch;
                        const index_t k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row;
                        const index_t b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col;

                        const index_t wo_thread = b_thread / NPerBlock;
                        const index_t n_thread  = b_thread % NPerBlock;

                        p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread,
                                                                     ho_block_data_begin + ho_thread,
                                                                     wo_block_data_begin + wo_thread,
                                                                     n_block_data_begin + n_thread)] =
                            p_out_thread[out_khwn_thread_desc.Get1dIndex(k, ho, wo, n)];
                    }
                }
            }
        }
#elif 1
        const auto c_thread_mtx_begin =
            blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        const index_t k_thread_data_begin  = c_thread_mtx_begin.row;
        const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
        const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
        const index_t n_thread_data_begin  = c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin;

        // output is a 10d tensor
        constexpr index_t N2 = GemmNPerThreadSubC;
        constexpr index_t N1 = NPerBlock / N2;

        constexpr index_t W2 =
            (GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC);
        constexpr index_t W1 = WoPerBlock / W2;

        constexpr index_t K2 = GemmMPerThreadSubC;
        constexpr index_t K1 = KPerBlock / KPerThread;

        constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
            Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2), W1, W2, N / (N1 * N2), N1, N2>{});

        constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
            Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});

#if 0
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
        {
            print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc");
            print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
            print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc");
            print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
        }
#endif

        threadwise_10d_tensor_copy(
            out_10d_thread_desc,
            p_out_thread,
            out_10d_global_desc,
            p_out_global +
                out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
                                                ho_block_data_begin + ho_thread_data_begin,
                                                wo_block_data_begin + wo_thread_data_begin,
                                                n_block_data_begin + n_thread_data_begin),
            out_10d_thread_desc.GetLengths(),
            Number<OutThreadCopyDataPerWrite>{});
#endif
    }
};
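The deleted v1r1 double-buffer kernel above, like the new v1r3 "lds_double_buffer" variants added later in this commit, follows the usual LDS ping-pong schedule: while the blockwise GEMM consumes one half of the doubled LDS buffer, the next C-slice is prefetched into a register clipboard and stored into the other half, and the tail drains whatever is still in flight. The following host-side sketch shows only that schedule; the names and the dot-product stand-in for the GEMM are illustrative, and the real kernel advances two slices per main-loop trip, so its tail has an even and an odd phase rather than the single drain shown here.

// Minimal host-side sketch of the double-buffer (ping-pong) schedule.
// "lds" is simulated with plain arrays; "consume" stands in for the GEMM.
#include <cstdio>

constexpr int KPerBlock = 4;   // assumed chunk size
constexpr int KTotal    = 16;  // assumed total reduction length

static void load_chunk(const float* src, int offset, float* dst)
{
    for(int i = 0; i < KPerBlock; ++i)
        dst[i] = src[offset + i];
}

static void consume_chunk(const float* a, const float* b, float* acc)
{
    for(int i = 0; i < KPerBlock; ++i)
        *acc += a[i] * b[i];
}

int main()
{
    float a[KTotal], b[KTotal];
    for(int i = 0; i < KTotal; ++i) { a[i] = 1.0f; b[i] = 2.0f; }

    float lds_a[2][KPerBlock], lds_b[2][KPerBlock];
    float acc = 0.0f;

    // preload buffer 0 (the kernel does this before its main loop)
    load_chunk(a, 0, lds_a[0]);
    load_chunk(b, 0, lds_b[0]);

    int now = 0;
    for(int k = KPerBlock; k < KTotal; k += KPerBlock)
    {
        const int next = 1 - now;
        load_chunk(a, k, lds_a[next]);                 // prefetch the next slice
        load_chunk(b, k, lds_b[next]);
        consume_chunk(lds_a[now], lds_b[now], &acc);   // GEMM on the slice loaded last trip
        now = next;
    }

    // tail: the last prefetched slice still has to be consumed
    consume_chunk(lds_a[now], lds_b[now], &acc);

    printf("acc = %f (expect %f)\n", acc, 2.0f * KTotal);
    return 0;
}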
src/include/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp

@@ -33,10 +33,10 @@ template <index_t GridSize,
          index_t GemmKPerThreadLoop,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB,
-         class InBlockCopyThreadPerDims,
-         index_t InBlockCopyDataPerRead,
-         index_t WeiBlockCopyDataPerRead,
-         index_t OutThreadCopyDataPerWrite>
+         class InBlockCopyClusterLengths_CHWN,
+         index_t InBlockCopyDataPerRead_N,
+         index_t WeiBlockCopyDataPerRead_K,
+         index_t OutThreadCopyDataPerWrite_N>
struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
{
    __device__ void Run(const Float* const __restrict__ p_in_global,

@@ -44,9 +44,11 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
                        Float* const __restrict__ p_out_global) const
    {
        // be careful of this assertion
        static_assert(
-           NPerThread <= NPerBlock && NPerBlock % NPerThread == 0,
-           "wrong! should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0");
+           NPerBlock % NPerThread == 0 &&
+               ((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
+                (GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
+                 GemmNPerThreadSubC % NPerThread == 0)),
+           "wrong!");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

@@ -101,14 +103,23 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
        // LDS tensor view
        // be careful of alignment
        constexpr index_t max_align = mod_conv::max(
-           InBlockCopyDataPerRead, WeiBlockCopyDataPerRead, GemmDataPerReadA, GemmDataPerReadB);
+           InBlockCopyDataPerRead_N, WeiBlockCopyDataPerRead_K, GemmDataPerReadA, GemmDataPerReadB);

        constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
-           Sequence<CPerBlock, HoPerBlock, WiPerBlock, NPerBlock>{}, Number<max_align>{});
+           Sequence<CPerBlock, HoPerBlock, WiPerBlock, NPerBlock>{},
+           Number<InBlockCopyDataPerRead_N>{});
+
+       // this check is ad-hoc
+       // TODO: need to properly implement tensor descriptor with alignment
+       static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
+                     "GemmDataPerReadB alignment requirement is not meet");

        constexpr auto wei_c_x_k_block_desc = make_ConstantTensorDescriptor_aligned(
-           Sequence<CPerBlock, X, KPerBlock>{}, Number<max_align>{});
+           Sequence<CPerBlock, X, KPerBlock>{},
+           Number<mod_conv::max(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});

        // tensor view of threadwise output in register
        constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(

@@ -116,14 +127,14 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
        // blockwise copy
        // input: format is [C, Hi, Wi, N]
-#if 0
+#if 1
        const auto blockwise_in_copy =
            Blockwise4dTensorCopy1<BlockSize,
                                   Float,
                                   decltype(in_c_h_w_n_global_desc),
                                   decltype(in_c_h_w_n_block_desc),
                                   decltype(in_c_h_w_n_block_desc.GetLengths()),
-                                  InBlockCopyDataPerRead>{};
+                                  InBlockCopyDataPerRead_N>{};
#else
        const auto blockwise_in_copy =
            Blockwise4dTensorCopy3<BlockSize,

@@ -131,8 +142,8 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
                                   decltype(in_c_h_w_n_global_desc),
                                   decltype(in_c_h_w_n_block_desc),
                                   decltype(in_c_h_w_n_block_desc.GetLengths()),
-                                  InBlockCopyThreadPerDims,
-                                  InBlockCopyDataPerRead>{};
+                                  InBlockCopyClusterLengths_CHWN,
+                                  InBlockCopyDataPerRead_N>{};
#endif

        // blockwise wei copy

@@ -143,7 +154,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
                                   decltype(wei_c_x_k_global_desc),
                                   decltype(wei_c_x_k_block_desc),
                                   decltype(wei_c_x_k_block_desc.GetLengths()),
-                                  WeiBlockCopyDataPerRead>{};
+                                  WeiBlockCopyDataPerRead_K>{};

        // a series of blockwise batched GEMM
        // C_matrix += transpose(A_matrix) * B_matrix

@@ -195,7 +206,9 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
        __shared__ Float p_wei_block[wei_block_space];

        // register
-       Float p_out_thread[out_k_h_w_n_thread_desc.GetElementSpace()];
+       // C++ lambda doesn't capture array, use pointer instead
+       Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
+       Float* const p_out_thread = p_out_thread_data;

#if 0
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)

@@ -293,46 +306,126 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
        const index_t k_thread_data_begin  = c_thread_mtx_begin.row;
        const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
        const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
-       const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
-
-       // output is a 10d tensor
-       constexpr index_t N2 = GemmNPerThreadSubC;
-       constexpr index_t N1 = NPerBlock / N2;
-
-       constexpr index_t W2 =
-           (GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC);
-       constexpr index_t W1 = WoPerBlock / W2;
-
-       constexpr index_t K2 = GemmMPerThreadSubC;
-       constexpr index_t K1 = KPerBlock / KPerThread;
-
-       constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
-           Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2), W1, W2, N / (N1 * N2), N1, N2>{});
-
-       constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
-           Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
-
-#if 0
-       if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
-       {
-           print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc");
-           print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
-           print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc");
-           print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
-       }
-#endif
-
-       threadwise_10d_tensor_copy(
-           out_10d_thread_desc,
-           p_out_thread,
-           out_10d_global_desc,
-           p_out_global +
-               out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
-                                                  ho_block_data_begin + ho_thread_data_begin,
-                                                  wo_block_data_begin + wo_thread_data_begin,
-                                                  n_block_data_begin + n_thread_data_begin),
-           out_10d_thread_desc.GetLengths(),
-           Number<OutThreadCopyDataPerWrite>{});
+       const index_t n_thread_data_begin =
+           c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin;
+
+       static_if<GemmNPerThreadSubC <= NPerBlock>{}(
+           [&](auto f_dummy) {
+               // f_dummy do nothing but perfect forwarding. Using this trick to
+               // make this lambda a generic lambda, so it won't be compiled until
+               // instantiated
+               static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock &&
+                              NPerBlock % GemmNPerThreadSubC == 0),
+                             "wrong!");
+
+               // output is a 10d tensor
+               constexpr index_t N2 = GemmNPerThreadSubC;
+               constexpr index_t N1 = NPerBlock / N2;
+
+               constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) /
+                                      f_dummy(NPerBlock / GemmNPerThreadSubC);
+               constexpr index_t W1 = WoPerBlock / W2;
+
+               constexpr index_t K2 = GemmMPerThreadSubC;
+               constexpr index_t K1 = KPerBlock / KPerThread;
+
+               constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
+                   Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2), W1, W2,
+                            N / f_dummy(N1 * N2), N1, N2>{});
+
+               constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
+                   Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
+
+#if 0
+               if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
+               {
+                   print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, "out_k_h_w_n_thread_desc");
+                   print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
+                   print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, "out_k_h_w_n_global_desc");
+                   print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
+               }
+#endif
+
+               threadwise_nd_tensor_copy(
+                   out_10d_thread_desc,
+                   p_out_thread,
+                   out_10d_global_desc,
+                   p_out_global +
+                       out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
+                                                          ho_block_data_begin + ho_thread_data_begin,
+                                                          wo_block_data_begin + wo_thread_data_begin,
+                                                          n_block_data_begin + n_thread_data_begin),
+                   out_10d_thread_desc.GetLengths(),
+                   Number<OutThreadCopyDataPerWrite_N>{});
+           })
+           .else_([&](auto f_dummy) {
+               static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
+                                 GemmNPerThreadSubC % NPerThread == 0,
+                             "wrong!");
+
+               // output is a 10d tensor
+               constexpr index_t N1 = NPerBlock;
+
+               constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
+               constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
+               constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
+
+               constexpr index_t K2 = GemmMPerThreadSubC;
+               constexpr index_t K1 = KPerBlock / KPerThread;
+
+               constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
+                   Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3,
+                            N / N1, N1>{});
+
+               constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
+                   Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
+
+#if 0
+               if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
+               {
+                   print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, "out_k_h_w_n_thread_desc");
+                   print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
+                   print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, "out_k_h_w_n_global_desc");
+                   print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
+
+                   for(index_t i = 0; i < 64; ++i)
+                   {
+                       printf("out %f, ", p_out_thread[i]);
+                   }
+               }
+#endif
+
+               threadwise_nd_tensor_copy(
+                   out_10d_thread_desc,
+                   p_out_thread,
+                   out_10d_global_desc,
+                   p_out_global +
+                       out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
+                                                          ho_block_data_begin + ho_thread_data_begin,
+                                                          wo_block_data_begin + wo_thread_data_begin,
+                                                          n_block_data_begin + n_thread_data_begin),
+                   out_10d_thread_desc.GetLengths(),
+                   Number<OutThreadCopyDataPerWrite_N>{});
+           });
    }
};
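The writeback above views the K x Ho x Wo x N output through a 10d descriptor so that a single threadwise copy can scatter each thread's registers to global memory with the right strides. The split only works when the block and thread tile sizes divide the problem sizes exactly. The standalone sketch below uses made-up tile sizes (not the driver's configuration) to make those divisibility constraints explicit; the factor pairs mirror the blocking hierarchy, with (K1, K2) coming from the per-thread GEMM tile, (W1, W2) from the GEMM thread clusters, and (N1, N2) from the per-thread sub-column width.

// Standalone check of the 10d split used above, with illustrative tile sizes.
#include <cstdint>

using index_t = std::int32_t;

int main()
{
    constexpr index_t K = 128, N = 64, Wo = 32;            // assumed problem sizes
    constexpr index_t KPerBlock = 64, KPerThread = 8;      // assumed tile sizes
    constexpr index_t NPerBlock = 16, WoPerBlock = 16;
    constexpr index_t GemmMPerThreadSubC = 4, GemmNPerThreadSubC = 4;
    constexpr index_t GemmNLevel0Cluster = 4, GemmNLevel1Cluster = 4;

    // same formulas as in the kernel's "then" branch
    constexpr index_t N2 = GemmNPerThreadSubC;             // 4
    constexpr index_t N1 = NPerBlock / N2;                 // 4
    constexpr index_t W2 =
        (GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC); // 4
    constexpr index_t W1 = WoPerBlock / W2;                // 4
    constexpr index_t K2 = GemmMPerThreadSubC;             // 4
    constexpr index_t K1 = KPerBlock / KPerThread;         // 8

    // the 10d view must tile the original K, Wo, N dimensions exactly
    static_assert(K % (K1 * K2) == 0, "K must factor into K0 * K1 * K2");
    static_assert(Wo % (W1 * W2) == 0, "Wo must factor into W0 * W1 * W2");
    static_assert(N % (N1 * N2) == 0, "N must factor into N0 * N1 * N2");

    return 0;
}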
src/include/gridwise_convolution_implicit_gemm_v1r2_nchw_cyxk_khwn.hip.hpp

@@ -39,7 +39,7 @@ template <index_t GridSize,
          index_t InBlockReorderDataPerRead_W,
          index_t InBlockReorderDataPerWrite_N,
          class WeiBlockCopyClusterLengths_CXK,
-         index_t WeiBlockCopyDataPerRead_C,
+         index_t WeiBlockCopyDataPerRead_K,
          index_t OutThreadCopyDataPerWrite_N>
struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
{

@@ -106,7 +106,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
        // LDS tensor view
        // be careful of alignment
        constexpr index_t max_align = mod_conv::max(InBlockReorderDataPerWrite_N,
-                                                    WeiBlockCopyDataPerRead_C,
+                                                    WeiBlockCopyDataPerRead_K,
                                                     GemmDataPerReadA,
                                                     GemmDataPerReadB);

@@ -146,7 +146,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
                                   decltype(wei_c_x_k_block_desc),
                                   decltype(wei_c_x_k_block_desc.GetLengths()),
                                   WeiBlockCopyClusterLengths_CXK,
-                                  WeiBlockCopyDataPerRead_C>{};
+                                  WeiBlockCopyDataPerRead_K>{};

        // a series of blockwise batched GEMM
        // C_matrix += transpose(A_matrix) * B_matrix

@@ -216,6 +216,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
        // set threadwise output tensor to 0
        threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);

+#if 0
        const Float* p_in_global_block_offset =
            p_in_global + in_n_c_h_w_global_desc.Get1dIndex(
                              n_block_data_begin, 0, hi_block_data_begin, wi_block_data_begin);

@@ -229,7 +230,6 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
        {
            for(index_t y = 0; y < Y; ++y)
            {
-#if 1
                blockwise_in_copy_reorder.Run(p_in_global_block_offset +
                                                  in_n_c_h_w_global_desc.Get1dIndex(0, 0, y, 0),
                                              p_in_block);

@@ -237,24 +237,49 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
                blockwise_wei_copy.Run(p_wei_global_block_offset +
                                           wei_c_y_x_k_global_desc.Get1dIndex(0, y, 0, 0),
                                       p_wei_block);
-
-               __syncthreads();
-
-               for(index_t x = 0; x < X; ++x)
-               {
-                   blockwise_batch_gemm.Run(p_wei_block + wei_c_x_k_block_desc.Get1dIndex(0, x, 0),
-                                            p_in_block +
-                                                in_c_h_w_n_block_desc.Get1dIndex(0, 0, x, 0),
-                                            p_out_thread);
-               }
-
-               __syncthreads();
-           }
-       }
#else
+       for(index_t y = 0; y < Y; ++y)
+       {
+           const Float* p_in_global_block_offset =
+               p_in_global + in_n_c_h_w_global_desc.Get1dIndex(
+                                 n_block_data_begin, 0, hi_block_data_begin + y, wi_block_data_begin);
+
+           const Float* p_wei_global_block_offset =
+               p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, 0, k_block_data_begin);
+
+           for(index_t c_block_data_begin = 0; c_block_data_begin < C;
+               c_block_data_begin += CPerBlock,
+                      p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
+                      p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
+           {
                Float p_in_clipboard[blockwise_in_copy_reorder.GetRegisterClipboardSize()];
                Float p_wei_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

                blockwise_in_copy_reorder.RunLoadRegisterClipboard(
-                   p_in_global_block_offset,
+                   p_in_global_block_offset + in_n_c_h_w_global_desc.Get1dIndex(0, 0, y, 0),
                    p_in_clipboard);
-               blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, p_wei_clipboard);
+               blockwise_wei_copy.RunLoadRegisterClipboard(
+                   p_wei_global_block_offset + wei_c_y_x_k_global_desc.Get1dIndex(0, y, 0, 0),
+                   p_wei_clipboard);

                blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_clipboard, p_wei_block);
                blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_clipboard, p_in_block);
-#endif

                __syncthreads();

                for(index_t x = 0; x < X; ++x)

@@ -268,6 +293,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
                __syncthreads();
            }
        }
+#endif

        // output: register to global mem,
        const auto c_thread_mtx_begin =
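The RunLoadRegisterClipboard / RunStoreRegisterClipboard pair used above splits a blockwise copy into two phases: the global-memory reads are issued early into a per-thread register "clipboard", and the LDS writes happen later, once the destination buffer is safe to overwrite. That split is what lets the double-buffered kernels overlap global loads with the GEMM. The host-side toy below mirrors only the three-call shape of that API; the class body is a stand-in, not the project's Blockwise copy implementation.

// Toy illustration of the load/store clipboard split.
#include <cassert>
#include <cstddef>

template <std::size_t PerThreadCount>
struct ToyBlockCopy
{
    static constexpr std::size_t GetRegisterClipboardSize() { return PerThreadCount; }

    // one-shot copy: source -> clipboard -> destination
    void Run(const float* src, float* dst) const
    {
        float clipboard[GetRegisterClipboardSize()];
        RunLoadRegisterClipboard(src, clipboard);
        RunStoreRegisterClipboard(clipboard, dst);
    }

    // phase 1: "global memory" reads, issued as early as possible
    void RunLoadRegisterClipboard(const float* src, float* clipboard) const
    {
        for(std::size_t i = 0; i < PerThreadCount; ++i)
            clipboard[i] = src[i];
    }

    // phase 2: "LDS" writes, issued once the destination buffer is free
    void RunStoreRegisterClipboard(const float* clipboard, float* dst) const
    {
        for(std::size_t i = 0; i < PerThreadCount; ++i)
            dst[i] = clipboard[i];
    }
};

int main()
{
    const float src[4] = {1, 2, 3, 4};
    float dst[4]       = {};

    ToyBlockCopy<4> copy;
    float clipboard[ToyBlockCopy<4>::GetRegisterClipboardSize()];

    copy.RunLoadRegisterClipboard(src, clipboard);  // e.g. while the previous GEMM is running
    copy.RunStoreRegisterClipboard(clipboard, dst); // e.g. after __syncthreads()

    for(int i = 0; i < 4; ++i)
        assert(dst[i] == src[i]);
    return 0;
}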
src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp

@@ -43,9 +43,11 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
                        Float* const __restrict__ p_out_global) const
    {
        // be careful of this assertion
        static_assert(
-           NPerThread <= NPerBlock && NPerBlock % NPerThread == 0,
-           "wrong! should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0");
+           NPerBlock % NPerThread == 0 &&
+               ((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
+                (GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
+                 GemmNPerThreadSubC % NPerThread == 0)),
+           "wrong!");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

@@ -66,9 +68,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
        constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
        constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);

-       constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
-       constexpr index_t WiPerBlock = WoPerBlock + X - 1;
-
        // divide block work: [K, Ho, Wo, N]
        static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
                          Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,

@@ -106,10 +105,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
            GemmDataPerReadB);

        constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
-           Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{}, Number<max_align>{});
+           Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{},
+           Number<InBlockCopyDataPerRead_N>{});
+
+       // this check is ad-hoc
+       // TODO: need to properly implement tensor descriptor with alignment
+       static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
+                     "GemmDataPerReadB alignment requirement is not meet");

        constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
-           Sequence<CPerBlock, KPerBlock>{}, Number<max_align>{});
+           Sequence<CPerBlock, KPerBlock>{},
+           Number<mod_conv::max(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});

        // tensor view of threadwise output in register
        constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(

@@ -177,6 +183,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
                                                  GemmDataPerReadB>{};

        // LDS: be careful of alignment
+       // TODO:: need to properly implement tensor descriptor with alignment
        constexpr index_t in_block_space =
            in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
        constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number<max_align>{});

@@ -185,7 +192,9 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
        __shared__ Float p_wei_block[wei_block_space];

        // register
-       Float p_out_thread[out_k_h_w_n_thread_desc.GetElementSpace()];
+       // C++ lambda doesn't capture array, use pointer instead
+       Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
+       Float* const p_out_thread = p_out_thread_data;

#if 0
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)

@@ -276,46 +285,126 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
        const index_t k_thread_data_begin  = c_thread_mtx_begin.row;
        const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
        const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
-       const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
-
-       // output is a 10d tensor
-       constexpr index_t N2 = GemmNPerThreadSubC;
-       constexpr index_t N1 = NPerBlock / N2;
-
-       constexpr index_t W2 =
-           (GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC);
-       constexpr index_t W1 = WoPerBlock / W2;
-
-       constexpr index_t K2 = GemmMPerThreadSubC;
-       constexpr index_t K1 = KPerBlock / KPerThread;
-
-       constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
-           Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2), W1, W2, N / (N1 * N2), N1, N2>{});
-
-       constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
-           Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
-
-#if 0
-       if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
-       {
-           print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc");
-           print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
-           print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc");
-           print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
-       }
-#endif
-
-       threadwise_10d_tensor_copy(
-           out_10d_thread_desc,
-           p_out_thread,
-           out_10d_global_desc,
-           p_out_global +
-               out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
-                                                  ho_block_data_begin + ho_thread_data_begin,
-                                                  wo_block_data_begin + wo_thread_data_begin,
-                                                  n_block_data_begin + n_thread_data_begin),
-           out_10d_thread_desc.GetLengths(),
-           Number<OutThreadCopyDataPerWrite_N>{});
+       const index_t n_thread_data_begin =
+           c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin;
+
+       static_if<GemmNPerThreadSubC <= NPerBlock>{}(
+           [&](auto f_dummy) {
+               // f_dummy do nothing but perfect forwarding. Using this trick to
+               // make this lambda a generic lambda, so it won't be compiled until
+               // instantiated
+               static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock &&
+                              NPerBlock % GemmNPerThreadSubC == 0),
+                             "wrong!");
+
+               // output is a 10d tensor
+               constexpr index_t N2 = GemmNPerThreadSubC;
+               constexpr index_t N1 = NPerBlock / N2;
+
+               constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) /
+                                      f_dummy(NPerBlock / GemmNPerThreadSubC);
+               constexpr index_t W1 = WoPerBlock / W2;
+
+               constexpr index_t K2 = GemmMPerThreadSubC;
+               constexpr index_t K1 = KPerBlock / KPerThread;
+
+               constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
+                   Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2), W1, W2,
+                            N / f_dummy(N1 * N2), N1, N2>{});
+
+               constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
+                   Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
+
+#if 0
+               if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
+               {
+                   print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, "out_k_h_w_n_thread_desc");
+                   print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
+                   print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, "out_k_h_w_n_global_desc");
+                   print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
+               }
+#endif
+
+               threadwise_nd_tensor_copy(
+                   out_10d_thread_desc,
+                   p_out_thread,
+                   out_10d_global_desc,
+                   p_out_global +
+                       out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
+                                                          ho_block_data_begin + ho_thread_data_begin,
+                                                          wo_block_data_begin + wo_thread_data_begin,
+                                                          n_block_data_begin + n_thread_data_begin),
+                   out_10d_thread_desc.GetLengths(),
+                   Number<OutThreadCopyDataPerWrite_N>{});
+           })
+           .else_([&](auto f_dummy) {
+               static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
+                                 GemmNPerThreadSubC % NPerThread == 0,
+                             "wrong!");
+
+               // output is a 10d tensor
+               constexpr index_t N1 = NPerBlock;
+
+               constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
+               constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
+               constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
+
+               constexpr index_t K2 = GemmMPerThreadSubC;
+               constexpr index_t K1 = KPerBlock / KPerThread;
+
+               constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
+                   Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3,
+                            N / N1, N1>{});
+
+               constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
+                   Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
+
+#if 0
+               if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
+               {
+                   print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, "out_k_h_w_n_thread_desc");
+                   print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
+                   print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, "out_k_h_w_n_global_desc");
+                   print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
+
+                   for(index_t i = 0; i < 64; ++i)
+                   {
+                       printf("out %f, ", p_out_thread[i]);
+                   }
+               }
+#endif
+
+               threadwise_nd_tensor_copy(
+                   out_10d_thread_desc,
+                   p_out_thread,
+                   out_10d_global_desc,
+                   p_out_global +
+                       out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
+                                                          ho_block_data_begin + ho_thread_data_begin,
+                                                          wo_block_data_begin + wo_thread_data_begin,
+                                                          n_block_data_begin + n_thread_data_begin),
+                   out_10d_thread_desc.GetLengths(),
+                   Number<OutThreadCopyDataPerWrite_N>{});
+           });
    }
};
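Both branches of the writeback above depend on compile-time conditions that would make the other branch ill-formed (for example, NPerBlock / GemmNPerThreadSubC is zero when GemmNPerThreadSubC exceeds NPerBlock), which is why each branch is written as a generic lambda taking an otherwise-unused `f_dummy` argument and dispatched through `static_if`: the untaken branch is never instantiated, so its static_asserts and divisions are never evaluated. The toy re-implementation below is for illustration only; the project's `static_if` lives in the functional headers and differs in detail.

// Toy sketch of the static_if / generic-lambda ("f_dummy") idiom.
#include <cstdio>

struct forwarder
{
    template <class T>
    constexpr T operator()(T x) const { return x; }   // no-op perfect forwarding
};

template <bool Cond>
struct static_if   // primary template: condition is true
{
    template <class F>
    const static_if& operator()(F f) const { f(forwarder{}); return *this; }

    template <class F>
    void else_(F) const {}
};

template <>
struct static_if<false>
{
    template <class F>
    const static_if& operator()(F) const { return *this; }   // skip "then"

    template <class F>
    void else_(F f) const { f(forwarder{}); }                 // run "else"
};

// Because each branch body takes `auto f_dummy`, it is a template and is only
// type-checked when the chosen specialization actually calls it, so a branch
// may contain expressions that would not compile in the other case.
template <int NPerBlock, int GemmNPerThreadSubC>
void describe()
{
    static_if<(GemmNPerThreadSubC <= NPerBlock)>{}([&](auto f_dummy) {
        static_assert(f_dummy(GemmNPerThreadSubC) <= NPerBlock, "wrong!");
        printf("N split as %d x %d\n", NPerBlock / GemmNPerThreadSubC, GemmNPerThreadSubC);
    }).else_([&](auto f_dummy) {
        static_assert(f_dummy(GemmNPerThreadSubC) % NPerBlock == 0, "wrong!");
        printf("one thread covers %d block columns\n", GemmNPerThreadSubC / NPerBlock);
    });
}

int main()
{
    describe<16, 4>();  // takes the "then" branch
    describe<2, 4>();   // takes the "else" branch
    return 0;
}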
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp

@@ -43,9 +43,11 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
                        Float* const __restrict__ p_out_global) const
    {
        // be careful of this assertion
        static_assert(
-           NPerThread <= NPerBlock && NPerBlock % NPerThread == 0,
-           "wrong! should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0");
+           NPerBlock % NPerThread == 0 &&
+               ((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
+                (GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
+                 GemmNPerThreadSubC % NPerThread == 0)),
+           "wrong!");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

@@ -109,10 +111,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
            GemmDataPerReadB);

        constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
-           Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{}, Number<max_align>{});
+           Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{},
+           Number<InBlockCopyDataPerRead_N>{});
+
+       // this check is ad-hoc
+       // TODO: need to properly implement tensor descriptor with alignment
+       static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
+                     "GemmDataPerReadB alignment requirement is not meet");

        constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
-           Sequence<CPerBlock, KPerBlock>{}, Number<max_align>{});
+           Sequence<CPerBlock, KPerBlock>{},
+           Number<mod_conv::max(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});

        // tensor view of threadwise output in register
        constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(

@@ -199,7 +208,9 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
        __shared__ Float p_wei_block_double[2 * wei_block_space];

        // register
-       Float p_out_thread[out_k_h_w_n_thread_desc.GetElementSpace()];
+       // C++ lambda doesn't capture array, use pointer instead
+       Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
+       Float* const p_out_thread = p_out_thread_data;

#if 0
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)

@@ -336,46 +347,126 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
        const index_t k_thread_data_begin  = c_thread_mtx_begin.row;
        const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
        const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
-       const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
-
-       // output is a 10d tensor
-       constexpr index_t N2 = GemmNPerThreadSubC;
-       constexpr index_t N1 = NPerBlock / N2;
-
-       constexpr index_t W2 =
-           (GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC);
-       constexpr index_t W1 = WoPerBlock / W2;
-
-       constexpr index_t K2 = GemmMPerThreadSubC;
-       constexpr index_t K1 = KPerBlock / KPerThread;
-
-       constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
-           Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2), W1, W2, N / (N1 * N2), N1, N2>{});
-
-       constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
-           Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
-
-#if 0
-       if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
-       {
-           print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc");
-           print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
-           print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc");
-           print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
-       }
-#endif
-
-       threadwise_10d_tensor_copy(
-           out_10d_thread_desc,
-           p_out_thread,
-           out_10d_global_desc,
-           p_out_global +
-               out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
-                                                  ho_block_data_begin + ho_thread_data_begin,
-                                                  wo_block_data_begin + wo_thread_data_begin,
-                                                  n_block_data_begin + n_thread_data_begin),
-           out_10d_thread_desc.GetLengths(),
-           Number<OutThreadCopyDataPerWrite_N>{});
+       const index_t n_thread_data_begin =
+           c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin;
+
+       static_if<GemmNPerThreadSubC <= NPerBlock>{}(
+           [&](auto f_dummy) {
+               // f_dummy do nothing but perfect forwarding. Using this trick to
+               // make this lambda a generic lambda, so it won't be compiled until
+               // instantiated
+               static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock &&
+                              NPerBlock % GemmNPerThreadSubC == 0),
+                             "wrong!");
+
+               // output is a 10d tensor
+               constexpr index_t N2 = GemmNPerThreadSubC;
+               constexpr index_t N1 = NPerBlock / N2;
+
+               constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) /
+                                      f_dummy(NPerBlock / GemmNPerThreadSubC);
+               constexpr index_t W1 = WoPerBlock / W2;
+
+               constexpr index_t K2 = GemmMPerThreadSubC;
+               constexpr index_t K1 = KPerBlock / KPerThread;
+
+               constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
+                   Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2), W1, W2,
+                            N / f_dummy(N1 * N2), N1, N2>{});
+
+               constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
+                   Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
+
+#if 0
+               if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
+               {
+                   print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, "out_k_h_w_n_thread_desc");
+                   print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
+                   print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, "out_k_h_w_n_global_desc");
+                   print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
+               }
+#endif
+
+               threadwise_nd_tensor_copy(
+                   out_10d_thread_desc,
+                   p_out_thread,
+                   out_10d_global_desc,
+                   p_out_global +
+                       out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
+                                                          ho_block_data_begin + ho_thread_data_begin,
+                                                          wo_block_data_begin + wo_thread_data_begin,
+                                                          n_block_data_begin + n_thread_data_begin),
+                   out_10d_thread_desc.GetLengths(),
+                   Number<OutThreadCopyDataPerWrite_N>{});
+           })
+           .else_([&](auto f_dummy) {
+               static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
+                                 GemmNPerThreadSubC % NPerThread == 0,
+                             "wrong!");
+
+               // output is a 10d tensor
+               constexpr index_t N1 = NPerBlock;
+
+               constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
+               constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
+               constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
+
+               constexpr index_t K2 = GemmMPerThreadSubC;
+               constexpr index_t K1 = KPerBlock / KPerThread;
+
+               constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
+                   Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3,
+                            N / N1, N1>{});
+
+               constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
+                   Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
+
+#if 0
+               if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
+               {
+                   print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, "out_k_h_w_n_thread_desc");
+                   print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
+                   print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, "out_k_h_w_n_global_desc");
+                   print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
+
+                   for(index_t i = 0; i < 64; ++i)
+                   {
+                       printf("out %f, ", p_out_thread[i]);
+                   }
+               }
+#endif
+
+               threadwise_nd_tensor_copy(
+                   out_10d_thread_desc,
+                   p_out_thread,
+                   out_10d_global_desc,
+                   p_out_global +
+                       out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
+                                                          ho_block_data_begin + ho_thread_data_begin,
+                                                          wo_block_data_begin + wo_thread_data_begin,
+                                                          n_block_data_begin + n_thread_data_begin),
+                   out_10d_thread_desc.GetLengths(),
+                   Number<OutThreadCopyDataPerWrite_N>{});
+           });
    }
};
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_khwn.hip.hpp
new file, mode 100644 (diff collapsed)

src/include/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_khwn.hip.hpp
new file, mode 100644 (diff collapsed)
src/include/threadwise_2d_tensor_op.hip.hpp

@@ -88,6 +88,7 @@ threadwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
}

+#if 0 // replaced threadwise_nd_tensor_copy
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
__device__ void threadwise_2d_tensor_copy(
    SrcDesc, Float* const __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, SrcOpLengths)

@@ -97,6 +98,7 @@ __device__ void threadwise_2d_tensor_copy(
    threadwise_2d_tensor_copy_reorder_by_get_dst_from_src(
        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, dst_from_src_reorder);
}
+#endif

template <class Float, class Desc, class IDim, class NShift>
__device__ void threadwise_2d_tensor_shift_down(Desc, Float* __restrict__ p, IDim, NShift)
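The old rank-2 copy is retired here in favor of threadwise_nd_tensor_copy, which walks an arbitrary-rank index space and maps each coordinate through the source and destination descriptors' strides. The plain C++ stand-in below captures that idea with a recursive loop nest over explicit stride arrays; it is not the project's implementation.

// Generic N-d elementwise copy over explicit lengths and strides (illustrative).
#include <array>
#include <cassert>
#include <cstddef>

template <std::size_t NDim>
void copy_nd(const float* src,
             const std::array<std::size_t, NDim>& src_strides,
             float* dst,
             const std::array<std::size_t, NDim>& dst_strides,
             const std::array<std::size_t, NDim>& lengths,
             std::size_t dim = 0,
             std::size_t src_off = 0,
             std::size_t dst_off = 0)
{
    if(dim == NDim)
    {
        dst[dst_off] = src[src_off];   // innermost: one scalar move
        return;
    }
    for(std::size_t i = 0; i < lengths[dim]; ++i)
        copy_nd(src, src_strides, dst, dst_strides, lengths,
                dim + 1, src_off + i * src_strides[dim], dst_off + i * dst_strides[dim]);
}

int main()
{
    // copy a 2x3 tile out of a 2x4 row-major source into a packed 2x3 destination
    const float src[8] = {0, 1, 2, 3, 4, 5, 6, 7};
    float dst[6]       = {};

    copy_nd<2>(src, {4, 1}, dst, {3, 1}, {2, 3});

    assert(dst[0] == 0 && dst[5] == 6);
    return 0;
}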