gaoqiong / composable_kernel

Commit 1e37e838, authored Apr 10, 2019 by Jing Zhang

    opt global load/store

Parents: 8f0b9710, 71434918

Showing 17 changed files with 1253 additions and 213 deletions (+1253 -213)
driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp                                +149  -56
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp                                  +7   -1
driver/driver.hip.cpp                                                                        +22   -4
script/compile-hip.sh                                                                         +2   -1
src/include/ConstantTensorDescriptor.hip.hpp                                                  +4   -2
src/include/amd_inline_asm.hip.hpp                                                           +26   -2
src/include/blockwise_2d_tensor_op.hip.hpp                                                    +5   -5
src/include/blockwise_4d_tensor_op.hip.hpp                                                  +163   -3
src/include/blockwise_batched_gemm.hip.hpp                                                   +11   -5
src/include/blockwise_gemm.hip.hpp                                                           +40  -44
src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp                     +26  -30
src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn_lds_double_buffer.hip.hpp  +409   -0
src/include/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp                  +335   -0
src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp                      +6   -8
src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp   +24  -45
src/include/threadwise_gemm.hip.hpp                                                          +15   -4
src/include/threadwise_nd_tensor_op.hip.hpp                                                   +9   -3
driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp

@@ -3,6 +3,8 @@
 #include "device.hpp"
 #include "gridwise_convolution_wrapper.hip.hpp"
 #include "gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp"
+#include "gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn_lds_double_buffer.hip.hpp"
+#include "gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp"

 template <class T, class InDesc, class WeiDesc, class OutDesc>
 void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
@@ -83,10 +85,10 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
     constexpr index_t HoPerBlock = 2;
     constexpr index_t WoPerBlock = 4;

-    constexpr index_t NPerThread  = 8;
+    constexpr index_t NPerThread  = 4;
     constexpr index_t KPerThread  = 8;
     constexpr index_t HoPerThread = 1;
-    constexpr index_t WoPerThread = 1;
+    constexpr index_t WoPerThread = 2;

     constexpr index_t InBlockCopy_ThreadPerDimC = 4;
     constexpr index_t InBlockCopy_ThreadPerDimH = 4;
@@ -103,6 +105,8 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
     constexpr index_t GemmMLevel1Cluster = 2;
     constexpr index_t GemmNLevel1Cluster = 4;
     constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmDataPerReadA   = 4;
+    constexpr index_t GemmDataPerReadB   = 4;

     constexpr index_t OutThreadCopyDataPerWrite = 2;
@@ -143,7 +147,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
     constexpr index_t BlockSize = 128;
 #elif 0
-    // 3x3 58x58
+    // 3x3 58x58, NKC = 64, 64, 256
     constexpr index_t NPerBlock  = 16;
     constexpr index_t KPerBlock  = 64;
     constexpr index_t CPerBlock  = 4;
@@ -164,43 +168,104 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
     constexpr index_t BlockSize = 128;
 #elif 0
-    // 3x3 58x58, NKC = 16,256,128
-    constexpr index_t NPerBlock  = 8;
-    constexpr index_t KPerBlock  = 64;
-    constexpr index_t CPerBlock  = 2;
-    constexpr index_t HoPerBlock = 4;
-    constexpr index_t WoPerBlock = 4;
+    // for 7x7, 38x38
+    constexpr index_t NPerBlock  = 16;
+    constexpr index_t KPerBlock  = 128;
+    constexpr index_t CPerBlock  = 8;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 2;

     constexpr index_t NPerThread  = 4;
-    constexpr index_t KPerThread  = 16;
-    constexpr index_t CPerThread  = 1;
+    constexpr index_t KPerThread  = 8;
     constexpr index_t HoPerThread = 1;
-    constexpr index_t WoPerThread = 1;
+    constexpr index_t WoPerThread = 2;
+
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 2;
+    constexpr index_t GemmKPerThreadLoop = 1;
+
+    constexpr index_t InBlockCopy_ThreadPerDimC = 2;
+    constexpr index_t InBlockCopy_ThreadPerDimH = 4;
+    constexpr index_t InBlockCopy_ThreadPerDimW = 4;
+    constexpr index_t InBlockCopy_ThreadPerDimN = 4;
+
+    constexpr index_t InBlockCopyDataPerRead    = 4;
+    constexpr index_t WeiBlockCopyDataPerRead   = 4;
+    constexpr index_t OutThreadCopyDataPerWrite = 4;

     constexpr index_t BlockSize = 128;
 #elif 0
-    // for 7x7, 38x38
-    constexpr index_t NPerBlock  = 8;
-    constexpr index_t KPerBlock  = 64;
-    constexpr index_t CPerBlock  = 1;
-    constexpr index_t HoPerBlock = 4;
-    constexpr index_t WoPerBlock = 4;
+    // for 3x3, 56x56, v1, Pascal
+    constexpr index_t NPerBlock  = 32;
+    constexpr index_t KPerBlock  = 64;
+    constexpr index_t CPerBlock  = 4;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 2;

     constexpr index_t NPerThread  = 4;
-    constexpr index_t KPerThread  = 16;
-    constexpr index_t CPerThread  = 1;
+    constexpr index_t KPerThread  = 8;
     constexpr index_t HoPerThread = 1;
-    constexpr index_t WoPerThread = 1;
+    constexpr index_t WoPerThread = 2;

-    constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
-    constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
+    constexpr index_t InBlockCopy_ThreadPerDimC = 1;
+    constexpr index_t InBlockCopy_ThreadPerDimH = 4;
+    constexpr index_t InBlockCopy_ThreadPerDimW = 4;
+    constexpr index_t InBlockCopy_ThreadPerDimN = 8;

-    constexpr index_t InBlockCopyDataPerRead  = 4; // not used, yet
+    constexpr index_t InBlockCopyDataPerRead  = 4;
     constexpr index_t WeiBlockCopyDataPerRead = 4;
+
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 2;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;
+
+    constexpr index_t OutThreadCopyDataPerWrite = 2;

     constexpr index_t BlockSize = 128;
 #elif 0
-    // for 3x3, 56x56
+    // for 3x3, 56x56, v1r2, Pascal
+    // for 3x3, 34x34, v1r2, Pascal
+    constexpr index_t NPerBlock  = 16;
+    constexpr index_t KPerBlock  = 128;
+    constexpr index_t CPerBlock  = 8;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 2;
+
+    constexpr index_t NPerThread  = 4;
+    constexpr index_t KPerThread  = 8;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 2;
+
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 2;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmDataPerReadA   = 1;
+    constexpr index_t GemmDataPerReadB   = 1;
+
+    constexpr index_t InBlockCopy_ThreadPerDimC = 2;
+    constexpr index_t InBlockCopy_ThreadPerDimH = 4;
+    constexpr index_t InBlockCopy_ThreadPerDimW = 4;
+    constexpr index_t InBlockCopy_ThreadPerDimN = 4;
+
+    constexpr index_t InBlockCopyDataPerRead    = 4;
+    constexpr index_t WeiBlockCopyDataPerRead   = 4;
+    constexpr index_t OutThreadCopyDataPerWrite = 4;
+
+    constexpr index_t BlockSize = 128;
+#elif 1
+    // for 3x3, 28x28, v1, Pascal
     constexpr index_t NPerBlock  = 32;
     constexpr index_t KPerBlock  = 64;
     constexpr index_t CPerBlock  = 4;
@@ -208,10 +273,29 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
     constexpr index_t WoPerBlock = 2;

     constexpr index_t NPerThread  = 4;
-    constexpr index_t KPerThread  = 16;
-    constexpr index_t CPerThread  = 1;
+    constexpr index_t KPerThread  = 8;
     constexpr index_t HoPerThread = 1;
-    constexpr index_t WoPerThread = 1;
+    constexpr index_t WoPerThread = 2;
+
+    constexpr index_t InBlockCopy_ThreadPerDimC = 1;
+    constexpr index_t InBlockCopy_ThreadPerDimH = 4;
+    constexpr index_t InBlockCopy_ThreadPerDimW = 4;
+    constexpr index_t InBlockCopy_ThreadPerDimN = 8;
+    constexpr index_t InBlockCopyDataPerRead    = 4;
+    constexpr index_t WeiBlockCopyDataPerRead   = 4;
+
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 2;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmDataPerReadA   = 4;
+    constexpr index_t GemmDataPerReadB   = 4;
+
+    constexpr index_t OutThreadCopyDataPerWrite = 2;

     constexpr index_t BlockSize = 128;
 #elif 0
@@ -289,35 +373,44 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
     for(index_t i = 0; i < nrepeat; ++i)
     {
         constexpr auto gridwise_conv =
-            GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
+#if 1
+            GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
+#elif 1
+            GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
+#elif 0
+            GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn_lds_double_buffer
+#endif
                 <GridSize,
                  BlockSize,
                  T,
                  decltype(in_chwn_desc),
                  decltype(wei_cyxk_desc),
                  decltype(out_khwn_desc),
                  NPerBlock,
                  KPerBlock,
                  CPerBlock,
                  HoPerBlock,
                  WoPerBlock,
                  NPerThread,
                  KPerThread,
                  HoPerThread,
                  WoPerThread,
                  GemmMPerThreadSubC,
                  GemmNPerThreadSubC,
                  GemmMLevel0Cluster,
                  GemmNLevel0Cluster,
                  GemmMLevel1Cluster,
                  GemmNLevel1Cluster,
                  GemmKPerThreadLoop,
+                 GemmDataPerReadA,
+                 GemmDataPerReadB,
                  Sequence<InBlockCopy_ThreadPerDimC,
                           InBlockCopy_ThreadPerDimH,
                           InBlockCopy_ThreadPerDimW,
                           InBlockCopy_ThreadPerDimN>,
                  InBlockCopyDataPerRead,
                  WeiBlockCopyDataPerRead,
                  OutThreadCopyDataPerWrite>{};

         float time = launch_kernel(run_gridwise_convolution<decltype(gridwise_conv), T>,
                                    dim3(GridSize),
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp

@@ -205,6 +205,8 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
     constexpr index_t GemmMLevel1Cluster = 4;
     constexpr index_t GemmNLevel1Cluster = 4;
     constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmDataPerReadA   = 4;
+    constexpr index_t GemmDataPerReadB   = 4;

     constexpr index_t InBlockCopyThreadPerDim0 = 4;
     constexpr index_t InBlockCopyThreadPerDim1 = 16;
@@ -233,6 +235,8 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
     constexpr index_t GemmMLevel1Cluster = 4;
     constexpr index_t GemmNLevel1Cluster = 4;
     constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmDataPerReadA   = 4;
+    constexpr index_t GemmDataPerReadB   = 4;

     constexpr index_t InBlockCopyThreadPerDim0 = 4;
     constexpr index_t InBlockCopyThreadPerDim1 = 16;
@@ -289,6 +293,8 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
                 GemmMLevel1Cluster,
                 GemmNLevel1Cluster,
                 GemmKPerThreadLoop,
+                GemmDataPerReadA,
+                GemmDataPerReadB,
                 InBlockCopyThreadPerDim0,
                 InBlockCopyThreadPerDim1,
                 WeiBlockCopyThreadPerDim0,
@@ -308,7 +314,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
         printf("Elapsed time : %f ms, %f TFlop/s\n",
                time,
                (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
-                   (std::size_t(1024) * 1024 * 1024 * 1024) / (time / 1000));
+                   (1e12) / (time / 1000));

         usleep(std::min(time * 1000, float(10000)));
     }
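The divisor change above is worth a second look: the old code divided by 1024^4 (a tebi-FLOP), while 1e12 is the decimal definition of a TFlop, so the new number is the conventional TFlop/s figure. Below is a minimal host-side sketch of the corrected formula with a hypothetical 1x1-filter, 14x14-image shape; the FLOP count 2*N*K*C*Ho*Wo*Y*X is the standard estimate for direct convolution, and calculate_convolution_flops is assumed to compute something equivalent:

    #include <cstdio>

    int main()
    {
        // hypothetical shape: N=128, K=128, C=128, Ho=Wo=14, Y=X=1
        const double flops   = 2.0 * 128 * 128 * 128 * 14 * 14 * 1 * 1;
        const float  time_ms = 0.25f; // example elapsed kernel time

        // TFlop/s = flops / 1e12 / seconds, matching the new printf above
        printf("%f TFlop/s\n", flops / 1e12 / (time_ms / 1000));
        return 0;
    }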
driver/driver.hip.cpp

@@ -427,9 +427,12 @@ int main(int argc, char* argv[])
     constexpr index_t C  = 64;
     constexpr index_t HI = 56;
     constexpr index_t WI = 56;
-    constexpr index_t K  = 64;
+    constexpr index_t K  = 128;
     constexpr index_t Y  = 3;
     constexpr index_t X  = 3;
+
+    constexpr index_t HPad = 0;
+    constexpr index_t WPad = 0;
 #elif 0
     // 3x3, 58x58
     constexpr index_t N = 64;
@@ -457,7 +460,7 @@ int main(int argc, char* argv[])
     constexpr index_t C  = 256;
     constexpr index_t HI = 38;
     constexpr index_t WI = 38;
-    constexpr index_t K  = 64;
+    constexpr index_t K  = 128;
     constexpr index_t Y  = 7;
     constexpr index_t X  = 7;
@@ -508,6 +511,18 @@ int main(int argc, char* argv[])
     constexpr index_t HPad = 1;
     constexpr index_t WPad = 1;
+#elif 0
+    // 3x3 filter, 28x28 image
+    constexpr index_t N  = 128;
+    constexpr index_t C  = 256;
+    constexpr index_t HI = 28;
+    constexpr index_t WI = 28;
+    constexpr index_t K  = 512;
+    constexpr index_t Y  = 3;
+    constexpr index_t X  = 3;
+
+    constexpr index_t HPad = 0;
+    constexpr index_t WPad = 0;
 #elif 0
     // 1x1 filter, 28x28 image
     constexpr index_t N = 16;
@@ -595,10 +610,10 @@ int main(int argc, char* argv[])
 #elif 1
     // 1x1 filter, 14x14 image, C = 512
     constexpr index_t N  = 128;
-    constexpr index_t C  = 512;
+    constexpr index_t C  = 128;
     constexpr index_t HI = 14;
     constexpr index_t WI = 14;
-    constexpr index_t K  = 512;
+    constexpr index_t K  = 128;
     constexpr index_t Y  = 1;
     constexpr index_t X  = 1;
@@ -641,6 +656,9 @@ int main(int argc, char* argv[])
 #if 0
     in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
     wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
+#elif 0
+    in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
+    wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
 #elif 1
     in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
     wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
script/compile-hip.sh

 #!/bin/bash

 export KMDUMPISA=1
 export KMDUMPLLVM=1
+export KMOPTLLC=-mattr=+enable-ds128

 make -j driver

-/opt/rocm/hcc/bin/llvm-objdump -mcpu=gfx906 -source -line-numbers driver/dump-gfx906.isabin > driver/dump-gfx906.isabin.isa
+/opt/rocm/hcc/bin/llvm-objdump -mcpu=gfx906 -source -line-numbers driver/dump-gfx906.isabin > driver/dump-gfx906.isabin.asm
src/include/ConstantTensorDescriptor.hip.hpp

@@ -381,7 +381,8 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
     constexpr auto I7 = Number<7>{};
     constexpr auto I8 = Number<8>{};

-    printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u %u}\n",
+    printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u "
+           "%u}\n",
            s,
            desc.GetDimension(),
            desc.GetLength(I0),
@@ -416,7 +417,8 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
     constexpr auto I8 = Number<8>{};
     constexpr auto I9 = Number<9>{};

-    printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u %u %u}\n",
+    printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u "
+           "%u %u %u}\n",
            s,
            desc.GetDimension(),
            desc.GetLength(I0),
src/include/amd_inline_asm.hip.hpp

@@ -382,6 +382,30 @@ inline __device__ void ds_read_b128(data4_t& r, void* lds, index_t offset = 0)
 #endif
 }

+inline __device__ void global_store(data4_t& r, const void* vptr, const void* sprt = 0)
+{
+#if !NO_GLB_READ
+    if(sprt == 0)
+    {
+        asm volatile("\n \
+                global_store_dwordx4 %0, %1, off \n \
+                " ::"v"(vptr),
+                     "v"(r));
+    }
+    else
+    {
+        asm volatile("\n \
+                global_store_dwordx4 %0, %1, %2 \n \
+                " ::"v"(vptr),
+                     "v"(r),
+                     "s"(sprt));
+    }
+#endif
+}
+
 inline __device__ void global_load(data4_t& r,
                                    const void* vptr,
                                    const void* sprt = 0)
@@ -392,8 +416,8 @@ inline __device__ void global_load(data4_t& r,
         asm volatile("\n \
                 global_load_dwordx4 %0, %1, off \n \
                 "
                     : "=v"(r)
                     : "v"(vptr));
     }
     else
     {
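For context, here is a minimal sketch (not from this commit) of the copy these wrappers optimize. With a 16-byte-aligned vector type, a plain assignment typically already compiles to one 128-bit global load plus one store; the asm wrappers above instead pin instruction selection to global_load_dwordx4/global_store_dwordx4 and let the caller split the address into a vector offset plus a scalar base:

    // data4 is a stand-in for the header's data4_t, assumed to hold 4 floats
    struct alignas(16) data4
    {
        float x, y, z, w;
    };

    __device__ void copy_dwordx4(const data4* __restrict__ src, data4* __restrict__ dst, int i)
    {
        // usually lowers to one global_load_dwordx4 + one global_store_dwordx4
        dst[i] = src[i];
    }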
src/include/blockwise_2d_tensor_op.hip.hpp

@@ -493,6 +493,7 @@ struct Blockwise2dTensorCopy3
     __device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
-                                             Float* __restrict__ p_clipboard) const
+                                             Float* p_clipboard,
+                                             const index_t voff = 0) const
     {
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
@@ -520,14 +521,13 @@ struct Blockwise2dTensorCopy3
         auto f_copy = [&](index_t iloop) {
 #if 1
-            *(reinterpret_cast<vector_t*>(&p_clipboard[iloop * DataPerRead])) =
-                *(reinterpret_cast<const vector_t*>(
-                    &p_src[mSrcMyThreadOffset + iloop * src_loop_stride + voff]));
-#else
-            void* sptr = (void*)&p_src[iloop * src_loop_stride];
-            void* vptr = (void*)(size_t)((mSrcMyThreadOffset + voff) * sizeof(Float));
-            global_load(*(vector_t*)(&p_clipboard[iloop * DataPerRead]), vptr, sptr);
+            data4_t* reg     = (data4_t*)&p_clipboard[iloop * DataPerRead];
+            const void* vptr = (void*)(uintptr_t)((mSrcMyThreadOffset + voff) * 4);
+            const void* sprt = (void*)&p_src[iloop * src_loop_stride];
+            global_load(*reg, vptr, sprt);
+#else
+            *(reinterpret_cast<vector_t*>(&p_clipboard[iloop * DataPerRead])) =
+                *(reinterpret_cast<const vector_t*>(
+                    &p_src[mSrcMyThreadOffset + iloop * src_loop_stride + voff]));
 #endif
         };
src/include/blockwise_4d_tensor_op.hip.hpp

@@ -576,9 +576,169 @@ struct Blockwise4dTensorCopy3
                                iloop_d2 * thread_per_d2,
                                iloop_d3 * thread_per_d3 * DataPerRead);

-                        *(reinterpret_cast<vector_t*>(p_dst + dst_offset + mDstMyThreadOffset)) =
-                            *(reinterpret_cast<const vector_t*>(p_src + src_offset +
-                                                                mSrcMyThreadOffset));
+                        *(reinterpret_cast<vector_t*>(&p_dst[dst_offset + mDstMyThreadOffset])) =
+                            *(reinterpret_cast<const vector_t*>(
+                                &p_src[src_offset + mSrcMyThreadOffset]));
                     }
                 }
             }
         }
     }
+
+    __device__ constexpr index_t GetRegisterClipboardSize() const
+    {
+        static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
+
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+        constexpr auto I3 = Number<3>{};
+
+        constexpr index_t L0 = CopyLengths{}.Get(I0);
+        constexpr index_t L1 = CopyLengths{}.Get(I1);
+        constexpr index_t L2 = CopyLengths{}.Get(I2);
+        constexpr index_t L3 = CopyLengths{}.Get(I3);
+
+        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
+        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
+        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
+        constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);
+
+        constexpr index_t nloop_d0 = L0 / thread_per_d0;
+        constexpr index_t nloop_d1 = L1 / thread_per_d1;
+        constexpr index_t nloop_d2 = L2 / thread_per_d2;
+        constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
+
+        return DataPerRead * nloop_d0 * nloop_d1 * nloop_d2 * nloop_d3;
+    }
+
+    __device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
+                                             Float* __restrict__ p_clipboard) const
+    {
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+        constexpr auto I3 = Number<3>{};
+
+        constexpr index_t L0 = CopyLengths{}.Get(I0);
+        constexpr index_t L1 = CopyLengths{}.Get(I1);
+        constexpr index_t L2 = CopyLengths{}.Get(I2);
+        constexpr index_t L3 = CopyLengths{}.Get(I3);
+
+        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
+        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
+        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
+        constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);
+
+        constexpr index_t num_active_thread =
+            thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;
+
+        if(BlockSize > num_active_thread)
+        {
+            if(get_thread_local_1d_id() >= num_active_thread)
+            {
+                return;
+            }
+        }
+
+        constexpr index_t nloop_d0 = L0 / thread_per_d0;
+        constexpr index_t nloop_d1 = L1 / thread_per_d1;
+        constexpr index_t nloop_d2 = L2 / thread_per_d2;
+        constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
+
+#pragma unroll
+        for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
+        {
+#pragma unroll
+            for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
+            {
+#pragma unroll
+                for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
+                {
+#pragma unroll
+                    for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
+                    {
+                        const index_t src_offset =
+                            SrcDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
+                                                 iloop_d1 * thread_per_d1,
+                                                 iloop_d2 * thread_per_d2,
+                                                 iloop_d3 * thread_per_d3 * DataPerRead);
+
+                        const index_t dst_offset =
+                            DstDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
+                                                 iloop_d1 * thread_per_d1,
+                                                 iloop_d2 * thread_per_d2,
+                                                 iloop_d3 * thread_per_d3 * DataPerRead);
+
+                        *(reinterpret_cast<vector_t*>(&p_clipboard[dst_offset])) =
+                            *(reinterpret_cast<const vector_t*>(
+                                &p_src[src_offset + mSrcMyThreadOffset]));
+                    }
+                }
+            }
+        }
+    }
+
+    __device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
+                                              Float* __restrict__ p_dst) const
+    {
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+        constexpr auto I3 = Number<3>{};
+
+        constexpr index_t L0 = CopyLengths{}.Get(I0);
+        constexpr index_t L1 = CopyLengths{}.Get(I1);
+        constexpr index_t L2 = CopyLengths{}.Get(I2);
+        constexpr index_t L3 = CopyLengths{}.Get(I3);
+
+        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
+        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
+        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
+        constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);
+
+        constexpr index_t num_active_thread =
+            thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;
+
+        if(BlockSize > num_active_thread)
+        {
+            if(get_thread_local_1d_id() >= num_active_thread)
+            {
+                return;
+            }
+        }
+
+        constexpr index_t nloop_d0 = L0 / thread_per_d0;
+        constexpr index_t nloop_d1 = L1 / thread_per_d1;
+        constexpr index_t nloop_d2 = L2 / thread_per_d2;
+        constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
+
+#pragma unroll
+        for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
+        {
+#pragma unroll
+            for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
+            {
+#pragma unroll
+                for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
+                {
+#pragma unroll
+                    for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
+                    {
+                        const index_t src_offset =
+                            SrcDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
+                                                 iloop_d1 * thread_per_d1,
+                                                 iloop_d2 * thread_per_d2,
+                                                 iloop_d3 * thread_per_d3 * DataPerRead);
+
+                        const index_t dst_offset =
+                            DstDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
+                                                 iloop_d1 * thread_per_d1,
+                                                 iloop_d2 * thread_per_d2,
+                                                 iloop_d3 * thread_per_d3 * DataPerRead);
+
+                        *(reinterpret_cast<vector_t*>(&p_dst[dst_offset + mDstMyThreadOffset])) =
+                            *(reinterpret_cast<const vector_t*>(&p_clipboard[src_offset]));
+                    }
+                }
+            }
+        }
+    }
 }
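The two new Run*RegisterClipboard methods split the former global-to-LDS Run into two stages so a caller can overlap the global read with compute. A condensed sketch of the pattern, with hypothetical names (not the repo's API):

    // stage 1: global memory -> per-thread registers (the "clipboard")
    // stage 2: registers -> LDS
    // The double-buffered gridwise kernel issues stage 1 for tile i+1,
    // runs the GEMM on tile i, then commits stage 2.
    template <int N>
    struct ClipboardCopy
    {
        __device__ void Load(const float* __restrict__ p_src, float* p_clipboard) const
        {
            for(int i = 0; i < N; ++i)
                p_clipboard[i] = p_src[i]; // global -> registers
        }

        __device__ void Store(const float* p_clipboard, float* __restrict__ p_lds) const
        {
            for(int i = 0; i < N; ++i)
                p_lds[i] = p_clipboard[i]; // registers -> LDS
        }
    };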
src/include/blockwise_batched_gemm.hip.hpp

@@ -16,7 +16,9 @@ template <index_t BlockSize,
           index_t MLevel1Cluster,
           index_t NLevel1Cluster,
           index_t KPerThreadLoop,
-          index_t BatchPerThread>
+          index_t BatchPerThread,
+          index_t DataPerReadA,
+          index_t DataPerReadB>
 struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
 {
     index_t mMyThreadOffsetA = 0;
@@ -220,7 +222,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
                     mMyThreadOffsetA,
                     a_thread_mtx,
                     p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
-                    a_thread_sub_mtx.GetLengths());
+                    a_thread_sub_mtx.GetLengths(),
+                    Number<DataPerReadA>{});
             }

             // copy B-sub to form B
@@ -233,7 +236,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
                     mMyThreadOffsetB,
                     b_thread_mtx,
                     p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
-                    b_thread_sub_mtx.GetLengths());
+                    b_thread_sub_mtx.GetLengths(),
+                    Number<DataPerReadB>{});
             }

             // loop over batch
@@ -264,7 +268,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
                     (ib + 1) * BlockMatrixStrideA + mMyThreadOffsetA,
                     a_thread_mtx,
                     p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
-                    a_thread_sub_mtx.GetLengths());
+                    a_thread_sub_mtx.GetLengths(),
+                    Number<DataPerReadA>{});
                 }
             }
@@ -280,7 +285,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
                     (ib + 1) * BlockMatrixStrideB + mMyThreadOffsetB,
                     b_thread_mtx,
                     p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
-                    b_thread_sub_mtx.GetLengths());
+                    b_thread_sub_mtx.GetLengths(),
+                    Number<DataPerReadB>{});
                 }
             }
         }
src/include/blockwise_gemm.hip.hpp

@@ -14,7 +14,9 @@ template <index_t BlockSize,
           index_t NLevel0Cluster,
           index_t MLevel1Cluster,
           index_t NLevel1Cluster,
-          index_t KPerThreadLoop>
+          index_t KPerThreadLoop,
+          index_t DataPerReadA,
+          index_t DataPerReadB>
 struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
 {
     struct MatrixIndex
@@ -130,10 +132,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                             const FloatB* __restrict__ p_b_block,
                             FloatC* __restrict__ p_c_thread) const
     {
-        static_assert(is_same<FloatA, float>::value && is_same<FloatB, float>::value &&
-                          is_same<FloatC, float>::value,
-                      "Run_asm only deal with float\n");
-
         constexpr auto True  = integral_constant<bool, true>{};
         constexpr auto False = integral_constant<bool, false>{};
@@ -162,56 +160,48 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
             Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

-        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
-        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
-
-        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
-        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
+        // assertion for inline asm
+        static_assert(is_same<FloatA, float>::value && is_same<FloatB, float>::value &&
+                          is_same<FloatC, float>::value,
+                      "Run_asm only deal with float\n");

         static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
                           MPerThread == 8 && NPerThread == 8,
                       "Run_asm cannot deal with this GEMM shape yet\n");

         using Float4 = vector_type<float, 4>::MemoryType;

+        float p_thread[a_thread_mtx.GetElementSpace() + b_thread_mtx.GetElementSpace()];
+
+        FloatA* p_a_thread = p_thread;
+        FloatB* p_b_thread = p_thread + a_thread_mtx.GetElementSpace();
+
+        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
+        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
+
+        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
+        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
+
         Float4* reg_a = (Float4*)(p_a_thread);
         Float4* reg_b = (Float4*)(p_b_thread);
         Float4* reg_c = (Float4*)(p_c_thread);

-        void* a_loc = (void*)(p_a_block + mMyThreadOffsetA);
-        void* b_loc = (void*)(p_b_block + mMyThreadOffsetB);
-
-        int lds_a_block_off = sizeof(Float) * M;
-        int lds_b_block_off = sizeof(Float) * N;
-
-        int lds_a_block_off_1 = MPerLevel1Cluster * sizeof(Float);
-        int lds_b_block_off_1 = NPerLevel1Cluster * sizeof(Float);
-
-        ds_read_b128(reg_a[0], a_loc, 0);
-        ds_read_b128(reg_b[0], b_loc, 0);
-        ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1);
-        ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1);
-
-        lgkmcnt(2);
+        reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA]);
+        reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB]);
+        reg_b[1] =
+            *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]);
+        reg_a[1] =
+            *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]);
+
         outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
-        lgkmcnt(1);
         outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
-        lgkmcnt(0);

 #pragma unroll
-        for(int k_i = 1; k_i < K; k_i++)
+        for(index_t k = 1; k < K; ++k)
         {
-            ds_read_b128(reg_a[0], a_loc, k_i * lds_a_block_off);
+            reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + k * M]);
             outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
-            ds_read_b128(reg_b[0], b_loc, k_i * lds_b_block_off);
+            reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + k * N]);
             outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
-            ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1 + k_i * lds_b_block_off);
-            ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1 + k_i * lds_a_block_off);
-            lgkmcnt(2);
+            reg_b[1] = *reinterpret_cast<const Float4*>(
+                &p_b_block[mMyThreadOffsetB + k * N + NPerLevel1Cluster]);
+            reg_a[1] = *reinterpret_cast<const Float4*>(
+                &p_a_block[mMyThreadOffsetA + k * M + MPerLevel1Cluster]);
             outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
-            lgkmcnt(1);
             outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
-            lgkmcnt(0);
         }

         outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
         outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
@@ -276,7 +266,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                     mMyThreadOffsetA,
                     a_thread_mtx,
                     p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
-                    a_thread_sub_mtx.GetLengths());
+                    a_thread_sub_mtx.GetLengths(),
+                    Number<DataPerReadA>{});
             }

 #pragma unroll
@@ -289,7 +280,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                     mMyThreadOffsetB,
                     b_thread_mtx,
                     p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
-                    b_thread_sub_mtx.GetLengths());
+                    b_thread_sub_mtx.GetLengths(),
+                    Number<DataPerReadB>{});
             }

             // C = A * B
@@ -359,7 +351,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                     p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster,
                     a_thread_sub_mtx,
                     p_a_thread_0 + m_repeat * MPerThreadSubC,
-                    a_thread_sub_mtx.GetLengths());
+                    a_thread_sub_mtx.GetLengths(),
+                    Number<DataPerReadA>{});
             }

 #pragma unroll
@@ -369,7 +362,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                     p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster,
                     b_thread_sub_mtx,
                     p_b_thread_0 + n_repeat * NPerThreadSubC,
-                    b_thread_sub_mtx.GetLengths());
+                    b_thread_sub_mtx.GetLengths(),
+                    Number<DataPerReadB>{});
             }

             bool even_loop = true;
@@ -394,7 +388,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                     m_repeat * MPerLevel1Cluster,
                     a_thread_sub_mtx,
                     p_a_thread_next + m_repeat * MPerThreadSubC,
-                    a_thread_sub_mtx.GetLengths());
+                    a_thread_sub_mtx.GetLengths(),
+                    Number<DataPerReadA>{});
             }

 #pragma unroll
@@ -406,7 +401,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                     n_repeat * NPerLevel1Cluster,
                     b_thread_sub_mtx,
                     p_b_thread_next + n_repeat * NPerThreadSubC,
-                    b_thread_sub_mtx.GetLengths());
+                    b_thread_sub_mtx.GetLengths(),
+                    Number<DataPerReadB>{});
             }

             // C = A * B
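Run_asm now reads the A/B fragments with plain Float4 vector loads instead of hand-placed ds_read_b128 calls and lgkmcnt waits, leaving LDS-load scheduling to the compiler. Each outerProduct4x4 call accumulates a rank-1 update of a 4x4 sub-tile of C; a scalar reference for one such call (a sketch, not the asm path):

    // c[i][j] += a[i] * b[j]; eight such calls per k step cover the
    // 8x8 per-thread C tile held in reg_c[0..15] as rows of four floats.
    __device__ void outer_product_4x4_ref(const float a[4], const float b[4], float c[4][4])
    {
        for(int i = 0; i < 4; ++i)
        {
            for(int j = 0; j < 4; ++j)
            {
                c[i][j] += a[i] * b[j];
            }
        }
    }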
...
src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp
View file @
1e37e838
...
@@ -30,6 +30,8 @@ template <index_t GridSize,
...
@@ -30,6 +30,8 @@ template <index_t GridSize,
index_t
GemmMLevel1Cluster
,
index_t
GemmMLevel1Cluster
,
index_t
GemmNLevel1Cluster
,
index_t
GemmNLevel1Cluster
,
index_t
GemmKPerThreadLoop
,
index_t
GemmKPerThreadLoop
,
index_t
GemmDataPerReadA
,
index_t
GemmDataPerReadB
,
class
InBlockCopyThreadPerDims
,
class
InBlockCopyThreadPerDims
,
index_t
InBlockCopyDataPerRead
,
index_t
InBlockCopyDataPerRead
,
index_t
WeiBlockCopyDataPerRead
,
index_t
WeiBlockCopyDataPerRead
,
...
@@ -41,8 +43,9 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
...
@@ -41,8 +43,9 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
Float
*
const
__restrict__
p_out_global
)
const
Float
*
const
__restrict__
p_out_global
)
const
{
{
// be careful of this assertion
// be careful of this assertion
static_assert
(
NPerThread
<=
NPerBlock
&&
NPerBlock
%
NPerThread
==
0
,
static_assert
(
"wrong! should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0"
);
NPerThread
<=
NPerBlock
&&
NPerBlock
%
NPerThread
==
0
,
"wrong! should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
@@ -67,8 +70,9 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
...
@@ -67,8 +70,9 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
constexpr
index_t
WiPerBlock
=
WoPerBlock
+
X
-
1
;
constexpr
index_t
WiPerBlock
=
WoPerBlock
+
X
-
1
;
// divide block work: [K, Ho, Wo, N]
// divide block work: [K, Ho, Wo, N]
static_assert
(
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
C
%
CPerBlock
==
0
&&
Ho
%
HoPerBlock
==
0
&&
Wo
%
WoPerBlock
==
0
,
static_assert
(
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
C
%
CPerBlock
==
0
&&
"wrong! cannot evenly divide work for workgroup "
);
Ho
%
HoPerBlock
==
0
&&
Wo
%
WoPerBlock
==
0
,
"wrong! cannot evenly divide work for workgroup "
);
constexpr
index_t
KBlockWork
=
(
K
+
KPerBlock
-
1
)
/
KPerBlock
;
constexpr
index_t
KBlockWork
=
(
K
+
KPerBlock
-
1
)
/
KPerBlock
;
constexpr
index_t
HBlockWork
=
(
Ho
+
HoPerBlock
-
1
)
/
HoPerBlock
;
constexpr
index_t
HBlockWork
=
(
Ho
+
HoPerBlock
-
1
)
/
HoPerBlock
;
...
@@ -95,15 +99,17 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
...
@@ -95,15 +99,17 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
// tensor view of blockwise input and weight in LDS
// tensor view of blockwise input and weight in LDS
// be careful of alignment
// be careful of alignment
constexpr
index_t
max_align
=
mod_conv
::
max
(
index_t
(
4
),
InBlockCopyDataPerRead
,
WeiBlockCopyDataPerRead
);
constexpr
auto
in_chwn_block_desc
=
make_ConstantTensorDescriptor_aligned
(
constexpr
auto
in_chwn_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
,
HiPerBlock
,
WiPerBlock
,
NPerBlock
>
{},
Sequence
<
CPerBlock
,
HiPerBlock
,
WiPerBlock
,
NPerBlock
>
{},
Number
<
max_align
>
{});
Number
<
InBlockCopyDataPerRead
>
{});
constexpr
auto
wei_ek_block_desc
=
make_ConstantTensorDescriptor_aligned
(
constexpr
auto
wei_ek_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
*
Y
*
X
,
KPerBlock
>
{},
Number
<
WeiBlockCopyDataPerRead
>
{});
Sequence
<
CPerBlock
*
Y
*
X
,
KPerBlock
>
{},
Number
<
max_align
>
{});
constexpr
auto
wei_cyxk_block_desc
=
make_ConstantTensorDescriptor_aligned
(
constexpr
auto
wei_cyxk_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
,
Y
,
X
,
KPerBlock
>
{},
Number
<
WeiBlockCopyDataPerRead
>
{});
Sequence
<
CPerBlock
,
Y
,
X
,
KPerBlock
>
{},
Number
<
max_align
>
{});
// tensor view of threadwise output in register
// tensor view of threadwise output in register
constexpr
auto
out_khwn_thread_desc
=
make_ConstantTensorDescriptor
(
constexpr
auto
out_khwn_thread_desc
=
make_ConstantTensorDescriptor
(
...
@@ -147,7 +153,7 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
...
@@ -147,7 +153,7 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
constexpr
auto
c_kxwn_thread_mtx_desc
=
constexpr
auto
c_kxwn_thread_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThread
>
{},
make_ConstantMatrixDescriptor
(
Number
<
KPerThread
>
{},
Number
<
WoPerThread
*
NPerThread
>
{},
Number
<
WoPerThread
*
NPerThread
>
{},
Number
<
out_khwn_thread_desc
.
GetStride
(
I
1
)
>
{});
Number
<
out_khwn_thread_desc
.
GetStride
(
I
0
)
>
{});
const
auto
blockwise_batch_gemm
=
const
auto
blockwise_batch_gemm
=
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
<
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
<
...
@@ -166,12 +172,11 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
...
@@ -166,12 +172,11 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
GemmMLevel1Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmNLevel1Cluster
,
GemmKPerThreadLoop
,
GemmKPerThreadLoop
,
HoPerThread
>
{};
HoPerThread
,
GemmDataPerReadA
,
GemmDataPerReadB
>
{};
// LDS: be careful of alignment
// LDS: be careful of alignment
constexpr
index_t
max_align
=
mod_conv
::
max
(
index_t
(
4
),
InBlockCopyDataPerRead
,
WeiBlockCopyDataPerRead
);
constexpr
index_t
in_block_space
=
in_chwn_block_desc
.
GetElementSpace
(
Number
<
max_align
>
{});
constexpr
index_t
in_block_space
=
in_chwn_block_desc
.
GetElementSpace
(
Number
<
max_align
>
{});
constexpr
index_t
wei_block_space
=
constexpr
index_t
wei_block_space
=
...
@@ -186,24 +191,24 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
...
@@ -186,24 +191,24 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
// set threadwise output tensor to 0
// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero
(
out_khwn_thread_desc
,
p_out_thread
);
threadwise_4d_tensor_set_zero
(
out_khwn_thread_desc
,
p_out_thread
);
const
Float
*
p_in_global_block_
begin
=
const
Float
*
p_in_global_block_
offset
=
p_in_global
+
p_in_global
+
in_chwn_global_desc
.
Get1dIndex
(
in_chwn_global_desc
.
Get1dIndex
(
0
,
hi_block_data_begin
,
wi_block_data_begin
,
n_block_data_begin
);
0
,
hi_block_data_begin
,
wi_block_data_begin
,
n_block_data_begin
);
const
Float
*
p_wei_global_block_
begin
=
const
Float
*
p_wei_global_block_
offset
=
p_wei_global
+
wei_cyxk_global_desc
.
Get1dIndex
(
0
,
0
,
0
,
k_block_data_begin
);
p_wei_global
+
wei_cyxk_global_desc
.
Get1dIndex
(
0
,
0
,
0
,
k_block_data_begin
);
for
(
index_t
c_block_data_begin
=
0
;
c_block_data_begin
<
C
;
c_block_data_begin
+=
CPerBlock
,
for
(
index_t
c_block_data_begin
=
0
;
c_block_data_begin
<
C
;
c_block_data_begin
+=
CPerBlock
,
p_in_global_block_
begin
+=
CPerBlock
*
in_chwn_global_desc
.
GetStride
(
I0
),
p_in_global_block_
offset
+=
CPerBlock
*
in_chwn_global_desc
.
GetStride
(
I0
),
p_wei_global_block_
begin
+=
CPerBlock
*
wei_cyxk_global_desc
.
GetStride
(
I0
),
p_wei_global_block_
offset
+=
CPerBlock
*
wei_cyxk_global_desc
.
GetStride
(
I0
),
__syncthreads
())
__syncthreads
())
{
{
// input: global mem to LDS
// input: global mem to LDS
blockwise_in_copy
.
Run
(
p_in_global_block_
begin
,
p_in_block
);
blockwise_in_copy
.
Run
(
p_in_global_block_
offset
,
p_in_block
);
// weight: global mem to LDS
// weight: global mem to LDS
blockwise_wei_copy
.
Run
(
p_wei_global_block_
begin
,
p_wei_block
);
blockwise_wei_copy
.
Run
(
p_wei_global_block_
offset
,
p_wei_block
);
__syncthreads
();
__syncthreads
();
...
@@ -276,17 +281,8 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
...
@@ -276,17 +281,8 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
constexpr
index_t
K2
=
GemmMPerThreadSubC
;
constexpr
index_t
K2
=
GemmMPerThreadSubC
;
constexpr
index_t
K1
=
KPerBlock
/
KPerThread
;
constexpr
index_t
K1
=
KPerBlock
/
KPerThread
;
constexpr
auto
out_10d_global_desc
=
constexpr
auto
out_10d_global_desc
=
make_ConstantTensorDescriptor
(
make_ConstantTensorDescriptor
(
Sequence
<
K
/
(
K1
*
K2
),
Sequence
<
K
/
(
K1
*
K2
),
K1
,
K2
,
Ho
,
Wo
/
(
W1
*
W2
),
W1
,
W2
,
N
/
(
N1
*
N2
),
N1
,
N2
>
{});
K1
,
K2
,
Ho
,
Wo
/
(
W1
*
W2
),
W1
,
W2
,
N
/
(
N1
*
N2
),
N1
,
N2
>
{});
constexpr
auto
out_10d_thread_desc
=
make_ConstantTensorDescriptor
(
constexpr
auto
out_10d_thread_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
KPerThread
/
K2
,
1
,
K2
,
HoPerThread
,
1
,
W1
,
1
,
1
,
1
,
N2
>
{});
Sequence
<
KPerThread
/
K2
,
1
,
K2
,
HoPerThread
,
1
,
W1
,
1
,
1
,
1
,
N2
>
{});
...
...
src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn_lds_double_buffer.hip.hpp
0 → 100644
View file @
1e37e838
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_4d_tensor_op.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_batched_gemm.hip.hpp"
template
<
index_t
GridSize
,
index_t
BlockSize
,
class
Float
,
class
InGlobalDesc
,
class
WeiGlobalDesc
,
class
OutGlobalDesc
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
CPerBlock
,
index_t
HoPerBlock
,
index_t
WoPerBlock
,
index_t
NPerThread
,
index_t
KPerThread
,
index_t
HoPerThread
,
index_t
WoPerThread
,
index_t
GemmMPerThreadSubC
,
index_t
GemmNPerThreadSubC
,
index_t
GemmMLevel0Cluster
,
index_t
GemmNLevel0Cluster
,
index_t
GemmMLevel1Cluster
,
index_t
GemmNLevel1Cluster
,
index_t
GemmKPerThreadLoop
,
index_t
GemmDataPerReadA
,
index_t
GemmDataPerReadB
,
class
InBlockCopyThreadPerDims
,
index_t
InBlockCopyDataPerRead
,
index_t
WeiBlockCopyDataPerRead
,
index_t
OutThreadCopyDataPerWrite
>
struct
GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn_lds_double_buffer
{
__device__
void
Run
(
const
Float
*
const
__restrict__
p_in_global
,
const
Float
*
const
__restrict__
p_wei_global
,
Float
*
const
__restrict__
p_out_global
)
const
{
// be careful of this assertion
static_assert
(
NPerThread
<=
NPerBlock
&&
NPerBlock
%
NPerThread
==
0
,
"wrong! should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
in_chwn_global_desc
=
InGlobalDesc
{};
constexpr
auto
wei_cyxk_global_desc
=
WeiGlobalDesc
{};
constexpr
auto
out_khwn_global_desc
=
OutGlobalDesc
{};
constexpr
index_t
C
=
in_chwn_global_desc
.
GetLength
(
I0
);
constexpr
index_t
K
=
out_khwn_global_desc
.
GetLength
(
I0
);
constexpr
index_t
Ho
=
out_khwn_global_desc
.
GetLength
(
I1
);
constexpr
index_t
Wo
=
out_khwn_global_desc
.
GetLength
(
I2
);
constexpr
index_t
N
=
out_khwn_global_desc
.
GetLength
(
I3
);
constexpr
index_t
Y
=
wei_cyxk_global_desc
.
GetLength
(
I1
);
constexpr
index_t
X
=
wei_cyxk_global_desc
.
GetLength
(
I2
);
constexpr
index_t
HiPerBlock
=
HoPerBlock
+
Y
-
1
;
constexpr
index_t
WiPerBlock
=
WoPerBlock
+
X
-
1
;
// assert for LDS double buffer
static_assert
(
C
%
(
2
*
CPerBlock
)
==
0
,
"C cannot be evenly divided"
);
// divide block work: [K, Ho, Wo, N]
static_assert
(
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
C
%
CPerBlock
==
0
&&
Ho
%
HoPerBlock
==
0
&&
Wo
%
WoPerBlock
==
0
,
"wrong! cannot evenly divide work for workgroup "
);
constexpr
index_t
KBlockWork
=
(
K
+
KPerBlock
-
1
)
/
KPerBlock
;
constexpr
index_t
HBlockWork
=
(
Ho
+
HoPerBlock
-
1
)
/
HoPerBlock
;
constexpr
index_t
WBlockWork
=
(
Wo
+
WoPerBlock
-
1
)
/
WoPerBlock
;
constexpr
index_t
NBlockWork
=
(
N
+
NPerBlock
-
1
)
/
NPerBlock
;
const
index_t
k_block_work_id
=
get_block_1d_id
()
/
(
HBlockWork
*
WBlockWork
*
NBlockWork
);
index_t
itmp
=
get_block_1d_id
()
-
k_block_work_id
*
(
HBlockWork
*
WBlockWork
*
NBlockWork
);
const
index_t
h_block_work_id
=
itmp
/
(
WBlockWork
*
NBlockWork
);
itmp
-=
h_block_work_id
*
(
WBlockWork
*
NBlockWork
);
const
index_t
w_block_work_id
=
itmp
/
NBlockWork
;
const
index_t
n_block_work_id
=
itmp
-
w_block_work_id
*
NBlockWork
;
const
index_t
k_block_data_begin
=
k_block_work_id
*
KPerBlock
;
const
index_t
ho_block_data_begin
=
h_block_work_id
*
HoPerBlock
;
const
index_t
wo_block_data_begin
=
w_block_work_id
*
WoPerBlock
;
const
index_t
n_block_data_begin
=
n_block_work_id
*
NPerBlock
;
const
index_t
hi_block_data_begin
=
ho_block_data_begin
;
const
index_t
wi_block_data_begin
=
wo_block_data_begin
;
// flattend (2d) tensor view of gridwise weight
constexpr
auto
wei_ek_global_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
C
*
Y
*
X
,
K
>
{});
// tensor view of blockwise input and weight in LDS
// be careful of alignment
constexpr
index_t
max_align
=
mod_conv
::
max
(
index_t
(
4
),
InBlockCopyDataPerRead
,
WeiBlockCopyDataPerRead
);
constexpr
auto
in_chwn_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
,
HiPerBlock
,
WiPerBlock
,
NPerBlock
>
{},
Number
<
max_align
>
{});
constexpr
auto
wei_ek_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
*
Y
*
X
,
KPerBlock
>
{},
Number
<
max_align
>
{});
constexpr
auto
wei_cyxk_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
,
Y
,
X
,
KPerBlock
>
{},
Number
<
max_align
>
{});
// tensor view of threadwise output in register
constexpr
auto
out_khwn_thread_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
KPerThread
,
HoPerThread
,
WoPerThread
,
NPerThread
>
{});
        // blockwise copy
        //   input: format is [C, Hi, Wi, N]
        const auto blockwise_in_copy =
            Blockwise4dTensorCopy3<BlockSize,
                                   Float,
                                   decltype(in_chwn_global_desc),
                                   decltype(in_chwn_block_desc),
                                   decltype(in_chwn_block_desc.GetLengths()),
                                   InBlockCopyThreadPerDims,
                                   InBlockCopyDataPerRead>{};

        // blockwise wei copy
        //   format is [CPerBlock*Y*X, KPerBlock]
        const auto blockwise_wei_copy =
            Blockwise2dTensorCopy3<BlockSize,
                                   Float,
                                   decltype(wei_ek_global_desc),
                                   decltype(wei_ek_block_desc),
                                   decltype(wei_ek_block_desc.GetLengths()),
                                   WeiBlockCopyDataPerRead>{};

        // a series of blockwise batched GEMM
        //   C_matrix += transpose(A_matrix) * B_matrix
        //   A_matrix and B_matrix saved in LDS, C_matrix saved in register
        //     A_matrix[C, K]    is a sub-matrix of wei_block[C, Y, X, K]
        //     B_matrix[C, Wo*N] is a sub-matrix of in_block[C, Hi, Wi, N]
        //     C_matrix[K, Wo*N] is a sub-matrix of out_block[K, Ho, Wo, N]
        constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
            Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_cyxk_block_desc.GetStride(I0)>{});

        constexpr auto b_cxwn_block_mtx_desc =
            make_ConstantMatrixDescriptor(Number<CPerBlock>{},
                                          Number<WoPerBlock * NPerBlock>{},
                                          Number<in_chwn_block_desc.GetStride(I0)>{});

        constexpr auto c_kxwn_thread_mtx_desc =
            make_ConstantMatrixDescriptor(Number<KPerThread>{},
                                          Number<WoPerThread * NPerThread>{},
                                          Number<out_khwn_thread_desc.GetStride(I0)>{});

        const auto blockwise_batch_gemm =
            BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
                BlockSize,
                decltype(a_cxk_block_mtx_desc),
                decltype(b_cxwn_block_mtx_desc),
                decltype(c_kxwn_thread_mtx_desc),
                0,
                in_chwn_block_desc.GetStride(I1),
                out_khwn_thread_desc.GetStride(I1),
                HoPerBlock,
                GemmMPerThreadSubC,
                GemmNPerThreadSubC,
                GemmMLevel0Cluster,
                GemmNLevel0Cluster,
                GemmMLevel1Cluster,
                GemmNLevel1Cluster,
                GemmKPerThreadLoop,
                HoPerThread,
                GemmDataPerReadA,
                GemmDataPerReadB>{};
        // LDS: be careful of alignment
        constexpr index_t in_block_space  = in_chwn_block_desc.GetElementSpace(Number<max_align>{});
        constexpr index_t wei_block_space = wei_cyxk_block_desc.GetElementSpace(Number<max_align>{});

        // LDS double buffer
        __shared__ Float p_in_block_double[2 * in_block_space];
        __shared__ Float p_wei_block_double[2 * wei_block_space];

        const Float* p_in_global_block_offset =
            p_in_global + in_chwn_global_desc.Get1dIndex(
                              0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);

        const Float* p_wei_global_block_offset =
            p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);

        // preload data into LDS
        {
            Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
            Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

            blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
                                                       p_in_register_clipboard);
            blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
                                                        p_wei_register_clipboard);

            blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_double);
            blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
                                                         p_wei_block_double);
        }
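        // For reference: the ping-pong schedule the loop below implements,
        // reduced to its skeleton (stand-in names, not part of this commit).
        // While the batched GEMM consumes the "now" half of LDS, the next
        // C-slice is fetched into registers, and is committed to the "next"
        // half only after the math has been issued.
#if 0
        for(index_t i = 0; i + 1 < num_slices; ++i)
        {
            Float* p_now  = select_lds_half(i);     // half being consumed
            Float* p_next = select_lds_half(i + 1); // half being refilled

            __syncthreads();         // previous store into p_now is visible
            load_clipboard(i + 1);   // global -> registers, latency hidden
            run_batched_gemm(p_now); // compute on the current slice
            store_clipboard(p_next); // registers -> the other LDS half
        }
#endif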
        // register
        Float p_out_thread[out_khwn_thread_desc.GetElementSpace()];

        // set threadwise output tensor to 0
        threadwise_4d_tensor_set_zero(out_khwn_thread_desc, p_out_thread);

        for(index_t c_block_data_begin = 0; c_block_data_begin + 2 * CPerBlock < C;
            c_block_data_begin += 2 * CPerBlock)
        {
#pragma unroll
            for(index_t iloop = 0; iloop < 2; ++iloop)
            {
                const bool even_loop = (iloop % 2 == 0);

                Float* p_in_block_now =
                    even_loop ? p_in_block_double : p_in_block_double + in_block_space;
                Float* p_wei_block_now =
                    even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;

                Float* p_in_block_next =
                    even_loop ? p_in_block_double + in_block_space : p_in_block_double;
                Float* p_wei_block_next =
                    even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;

                // load next data
                Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
                Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

                p_in_global_block_offset += CPerBlock * in_chwn_global_desc.GetStride(I0);
                p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0);

                __syncthreads();

                blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
                                                           p_in_register_clipboard);
                blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
                                                            p_wei_register_clipboard);

                // a series of batched GEMM
                for(index_t y = 0; y < Y; ++y)
                {
                    for(index_t x = 0; x < X; ++x)
                    {
                        blockwise_batch_gemm.Run(
                            p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
                            p_in_block_now + in_chwn_block_desc.Get1dIndex(0, y, x, 0),
                            p_out_thread);
                    }
                }

                blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
                                                            p_in_block_next);
                blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
                                                             p_wei_block_next);
            }
        }

        // tail
        {
            // even
            p_in_global_block_offset += CPerBlock * in_chwn_global_desc.GetStride(I0);
            p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0);

            __syncthreads();

            Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
            Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

            blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
                                                       p_in_register_clipboard);
            blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
                                                        p_wei_register_clipboard);

            for(index_t y = 0; y < Y; ++y)
            {
                for(index_t x = 0; x < X; ++x)
                {
                    blockwise_batch_gemm.Run(
                        p_wei_block_double + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
                        p_in_block_double + in_chwn_block_desc.Get1dIndex(0, y, x, 0),
                        p_out_thread);
                }
            }

            blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
                                                        p_in_block_double + in_block_space);
            blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
                                                         p_wei_block_double + wei_block_space);

            // odd
            __syncthreads();

            for(index_t y = 0; y < Y; ++y)
            {
                for(index_t x = 0; x < X; ++x)
                {
                    blockwise_batch_gemm.Run(
                        p_wei_block_double + wei_block_space +
                            wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
                        p_in_block_double + in_block_space +
                            in_chwn_block_desc.Get1dIndex(0, y, x, 0),
                        p_out_thread);
                }
            }
        }
        // output: register to global mem
#if 0
        const auto c_thread_mtx_begin =
            blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        for(index_t k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k)
        {
            for(index_t ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho)
            {
                for(index_t wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo)
                {
                    for(index_t n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n)
                    {
                        const index_t b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n);

                        const auto c_thread_mtx_distance =
                            blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b);

                        const index_t ho_thread =
                            c_thread_mtx_begin.batch + c_thread_mtx_distance.batch;
                        const index_t k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row;
                        const index_t b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col;

                        const index_t wo_thread = b_thread / NPerBlock;
                        const index_t n_thread  = b_thread % NPerBlock;

                        p_out_global[out_khwn_global_desc.Get1dIndex(
                            k_block_data_begin + k_thread,
                            ho_block_data_begin + ho_thread,
                            wo_block_data_begin + wo_thread,
                            n_block_data_begin + n_thread)] =
                            p_out_thread[out_khwn_thread_desc.Get1dIndex(k, ho, wo, n)];
                    }
                }
            }
        }
#elif 1
        const auto c_thread_mtx_begin =
            blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        const index_t k_thread_data_begin  = c_thread_mtx_begin.row;
        const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
        const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
        const index_t n_thread_data_begin =
            c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin;

        // output is a 10d tensor
        constexpr index_t N2 = GemmNPerThreadSubC;
        constexpr index_t N1 = NPerBlock / N2;

        constexpr index_t W2 =
            (GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC);
        constexpr index_t W1 = WoPerBlock / W2;

        constexpr index_t K2 = GemmMPerThreadSubC;
        constexpr index_t K1 = KPerBlock / KPerThread;

        constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
            Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2), W1, W2, N / (N1 * N2), N1, N2>{});

        constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
            Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});

#if 0
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
        {
            print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc");
            print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");

            print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc");
            print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
        }
#endif

        threadwise_10d_tensor_copy(
            out_10d_thread_desc,
            p_out_thread,
            out_10d_global_desc,
            p_out_global +
                out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
                                                ho_block_data_begin + ho_thread_data_begin,
                                                wo_block_data_begin + wo_thread_data_begin,
                                                n_block_data_begin + n_thread_data_begin),
            out_10d_thread_desc.GetLengths(),
            Number<OutThreadCopyDataPerWrite>{});
#endif
    }
};
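The 10-d output view used above only works when every tensor length factors exactly into its block, cluster, and per-thread components. A quick host-side check of that invariant; the numeric values here are illustrative assumptions, not the driver's actual configuration:

#include <cassert>

int main()
{
    // hypothetical example configuration
    const int K = 128, Wo = 16, N = 32;
    const int GemmMPerThreadSubC = 4, GemmNPerThreadSubC = 4;
    const int KPerBlock = 64, KPerThread = 8, NPerBlock = 16, WoPerBlock = 8;
    const int GemmNLevel0Cluster = 4, GemmNLevel1Cluster = 4;

    const int N2 = GemmNPerThreadSubC, N1 = NPerBlock / N2;
    const int W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC);
    const int W1 = WoPerBlock / W2;
    const int K2 = GemmMPerThreadSubC, K1 = KPerBlock / KPerThread;

    // each original length must be recovered by the product of its factors
    assert(K % (K1 * K2) == 0 && Wo % (W1 * W2) == 0 && N % (N1 * N2) == 0);
    return 0;
}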
src/include/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp
0 → 100644
View file @
1e37e838
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_4d_tensor_op.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_batched_gemm.hip.hpp"
template <index_t GridSize,
          index_t BlockSize,
          class Float,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t CPerBlock,
          index_t HoPerBlock,
          index_t WoPerBlock,
          index_t NPerThread,
          index_t KPerThread,
          index_t HoPerThread,
          index_t WoPerThread,
          index_t GemmMPerThreadSubC,
          index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB,
          class InBlockCopyThreadPerDims,
          index_t InBlockCopyDataPerRead,
          index_t WeiBlockCopyDataPerRead,
          index_t OutThreadCopyDataPerWrite>
struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        // be careful of this assertion
        static_assert(
            NPerThread <= NPerBlock && NPerBlock % NPerThread == 0,
            "wrong! should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto in_chwn_global_desc  = InGlobalDesc{};
        constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
        constexpr auto out_khwn_global_desc = OutGlobalDesc{};

        constexpr index_t C = in_chwn_global_desc.GetLength(I0);

        constexpr index_t K  = out_khwn_global_desc.GetLength(I0);
        constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
        constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);
        constexpr index_t N  = out_khwn_global_desc.GetLength(I3);

        constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
        constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);

        constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
        constexpr index_t WiPerBlock = WoPerBlock + X - 1;
        // divide block work: [K, Ho, Wo, N]
        static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
                          Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
                      "wrong! cannot evenly divide work for workgroup");

        constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
        constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
        constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
        constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;

        const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
        index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
        const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
        itmp -= h_block_work_id * (WBlockWork * NBlockWork);
        const index_t w_block_work_id = itmp / NBlockWork;
        const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;

        const index_t k_block_data_begin  = k_block_work_id * KPerBlock;
        const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
        const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
        const index_t n_block_data_begin  = n_block_work_id * NPerBlock;

        const index_t hi_block_data_begin = ho_block_data_begin;
        const index_t wi_block_data_begin = wo_block_data_begin;
        // 2d tensor view of gridwise weight
        constexpr auto wei_ck_global_desc =
            make_ConstantTensorDescriptor(Sequence<C, K>{}, Sequence<Y * X * K, 1>{});

        // tensor view of blockwise input and weight in LDS
        //   be careful of alignment
        constexpr index_t max_align =
            mod_conv::max(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);

        constexpr auto in_chwn_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{}, Number<max_align>{});

        constexpr auto wei_ck_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, KPerBlock>{}, Number<max_align>{});

        // tensor view of threadwise output in register
        constexpr auto out_khwn_thread_desc = make_ConstantTensorDescriptor(
            Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});

        // blockwise copy
        //   input: format is [C, Hi, Wi, N]
        const auto blockwise_in_copy =
            Blockwise4dTensorCopy3<BlockSize,
                                   Float,
                                   decltype(in_chwn_global_desc),
                                   decltype(in_chwn_block_desc),
                                   decltype(in_chwn_block_desc.GetLengths()),
                                   InBlockCopyThreadPerDims,
                                   InBlockCopyDataPerRead>{};

        // blockwise wei copy
        //   format is [CPerBlock, KPerBlock]
        const auto blockwise_wei_copy =
#if 0 // debug
            Blockwise2dTensorCopy1<BlockSize,
                                   Float,
                                   decltype(wei_ck_global_desc),
                                   decltype(wei_ck_block_desc),
                                   decltype(wei_ck_block_desc.GetLengths())>{};
#else
            Blockwise2dTensorCopy3<BlockSize,
                                   Float,
                                   decltype(wei_ck_global_desc),
                                   decltype(wei_ck_block_desc),
                                   decltype(wei_ck_block_desc.GetLengths()),
                                   WeiBlockCopyDataPerRead>{};
#endif
        // a series of blockwise batched GEMM
        //   C_matrix += transpose(A_matrix) * B_matrix
        //   A_matrix and B_matrix saved in LDS, C_matrix saved in register
        //     A_matrix[C, K]    is a sub-matrix of wei_block[C, K]
        //     B_matrix[C, Wo*N] is a sub-matrix of in_block[C, Hi, Wi, N]
        //     C_matrix[K, Wo*N] is a sub-matrix of out_block[K, Ho, Wo, N]
        constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
            Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_ck_block_desc.GetStride(I0)>{});

        constexpr auto b_cxwn_block_mtx_desc =
            make_ConstantMatrixDescriptor(Number<CPerBlock>{},
                                          Number<WoPerBlock * NPerBlock>{},
                                          Number<in_chwn_block_desc.GetStride(I0)>{});

        constexpr auto c_kxwn_thread_mtx_desc =
            make_ConstantMatrixDescriptor(Number<KPerThread>{},
                                          Number<WoPerThread * NPerThread>{},
                                          Number<out_khwn_thread_desc.GetStride(I0)>{});

        const auto blockwise_batch_gemm =
            BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
                BlockSize,
                decltype(a_cxk_block_mtx_desc),
                decltype(b_cxwn_block_mtx_desc),
                decltype(c_kxwn_thread_mtx_desc),
                0,
                in_chwn_block_desc.GetStride(I1),
                out_khwn_thread_desc.GetStride(I1),
                HoPerBlock,
                GemmMPerThreadSubC,
                GemmNPerThreadSubC,
                GemmMLevel0Cluster,
                GemmNLevel0Cluster,
                GemmMLevel1Cluster,
                GemmNLevel1Cluster,
                GemmKPerThreadLoop,
                HoPerThread,
                GemmDataPerReadA,
                GemmDataPerReadB>{};
        // LDS: be careful of alignment
        constexpr index_t in_block_space  = in_chwn_block_desc.GetElementSpace(Number<max_align>{});
        constexpr index_t wei_block_space = wei_ck_block_desc.GetElementSpace(Number<max_align>{});

        __shared__ Float p_in_block[in_block_space];
        __shared__ Float p_wei_block[wei_block_space];

        // register
        Float p_out_thread[out_khwn_thread_desc.GetElementSpace()];

#if 0
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
        {
            print_ConstantTensorDescriptor(in_chwn_global_desc, "in_chwn_global_desc");
            print_ConstantTensorDescriptor(wei_cyxk_global_desc, "wei_cyxk_global_desc");
            print_ConstantTensorDescriptor(wei_ck_global_desc, "wei_ck_global_desc");

            print_ConstantTensorDescriptor(in_chwn_block_desc, "in_chwn_block_desc");
            print_ConstantTensorDescriptor(wei_ck_block_desc, "wei_ck_block_desc");

            printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
        }
#endif

        // set threadwise output tensor to 0
        threadwise_4d_tensor_set_zero(out_khwn_thread_desc, p_out_thread);

        const Float* p_in_global_block_offset =
            p_in_global + in_chwn_global_desc.Get1dIndex(
                              0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);

        const Float* p_wei_global_block_offset =
            p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
        for(index_t c_block_data_begin = 0; c_block_data_begin < C;
            c_block_data_begin += CPerBlock,
                    p_in_global_block_offset += CPerBlock * in_chwn_global_desc.GetStride(I0),
                    p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0))
        {
            // input: global mem to LDS
            blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);

            for(index_t y = 0; y < Y; ++y)
            {
                for(index_t x = 0; x < X; ++x)
                {
                    // weight: global mem to LDS
                    blockwise_wei_copy.Run(
                        p_wei_global_block_offset + wei_cyxk_global_desc.Get1dIndex(0, y, x, 0),
                        p_wei_block);

                    __syncthreads();

                    blockwise_batch_gemm.Run(p_wei_block,
                                             p_in_block + in_chwn_block_desc.Get1dIndex(0, y, x, 0),
                                             p_out_thread);

                    __syncthreads();
                }
            }
        }
        // output: register to global mem
#if 0
        const auto c_thread_mtx_begin =
            blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        for(index_t k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k)
        {
            for(index_t ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho)
            {
                for(index_t wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo)
                {
                    for(index_t n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n)
                    {
                        const index_t b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n);

                        const auto c_thread_mtx_distance =
                            blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b);

                        const index_t ho_thread =
                            c_thread_mtx_begin.batch + c_thread_mtx_distance.batch;
                        const index_t k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row;
                        const index_t b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col;

                        const index_t wo_thread = b_thread / NPerBlock;
                        const index_t n_thread  = b_thread % NPerBlock;

                        p_out_global[out_khwn_global_desc.Get1dIndex(
                            k_block_data_begin + k_thread,
                            ho_block_data_begin + ho_thread,
                            wo_block_data_begin + wo_thread,
                            n_block_data_begin + n_thread)] =
                            p_out_thread[out_khwn_thread_desc.Get1dIndex(k, ho, wo, n)];
                    }
                }
            }
        }
#elif 1
        const auto c_thread_mtx_begin =
            blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        const index_t k_thread_data_begin  = c_thread_mtx_begin.row;
        const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
        const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
        const index_t n_thread_data_begin =
            c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin;

        // output is a 10d tensor
        constexpr index_t N2 = GemmNPerThreadSubC;
        constexpr index_t N1 = NPerBlock / N2;

        constexpr index_t W2 =
            (GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC);
        constexpr index_t W1 = WoPerBlock / W2;

        constexpr index_t K2 = GemmMPerThreadSubC;
        constexpr index_t K1 = KPerBlock / KPerThread;

        constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
            Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2), W1, W2, N / (N1 * N2), N1, N2>{});

        constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
            Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});

#if 0
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
        {
            print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc");
            print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");

            print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc");
            print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
        }
#endif

        threadwise_10d_tensor_copy(
            out_10d_thread_desc,
            p_out_thread,
            out_10d_global_desc,
            p_out_global +
                out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
                                                ho_block_data_begin + ho_thread_data_begin,
                                                wo_block_data_begin + wo_thread_data_begin,
                                                n_block_data_begin + n_thread_data_begin),
            out_10d_thread_desc.GetLengths(),
            Number<OutThreadCopyDataPerWrite>{});
#endif
    }
};
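Compared with the v1 kernels, v1r2 keeps only a [CPerBlock, KPerBlock] weight slice in LDS and re-reads it from global memory once per filter tap (y, x), paying two barriers per GEMM instead of staging the whole [CPerBlock, Y, X, KPerBlock] tile. A skeleton of that schedule with stand-in lambdas (not the real copy/GEMM functors):

void v1r2_schedule(int C, int CPerBlock, int Y, int X)
{
    auto copy_input_tile  = [](int c) {};               // once per C-slice
    auto copy_weight_tile = [](int c, int y, int x) {}; // once per filter tap
    auto barrier          = []() {};                    // __syncthreads() stand-in
    auto batched_gemm     = [](int y, int x) {};

    for(int c = 0; c < C; c += CPerBlock)
    {
        copy_input_tile(c);

        for(int y = 0; y < Y; ++y)
        {
            for(int x = 0; x < X; ++x)
            {
                copy_weight_tile(c, y, x);
                barrier();          // weight slice visible to all threads
                batched_gemm(y, x);
                barrier();          // LDS free for the next slice
            }
        }
    }
}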
src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp
View file @
1e37e838
...
@@ -26,6 +26,8 @@ template <index_t GridSize,
           index_t GemmMLevel1Cluster,
           index_t GemmNLevel1Cluster,
           index_t GemmKPerThreadLoop,
+          index_t GemmDataPerReadA,
+          index_t GemmDataPerReadB,
           index_t InBlockCopyThreadPerDim0,
           index_t InBlockCopyThreadPerDim1,
           index_t WeiBlockCopyThreadPerDim0,
...
@@ -174,7 +176,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
                            GemmNLevel0Cluster,
                            GemmMLevel1Cluster,
                            GemmNLevel1Cluster,
-                           GemmKPerThreadLoop>{};
+                           GemmKPerThreadLoop,
+                           GemmDataPerReadA,
+                           GemmDataPerReadB>{};

         // LDS: be careful of alignment
         constexpr index_t max_align =
...
@@ -211,17 +215,11 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
         blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
                                                    p_in_register_clipboard);
         blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
                                                     p_wei_register_clipboard);

-#if 1
         blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block);
         blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block);
-#else
-        vmcnt(0);
-        blockwise_in_copy.RunStoreRegisterClipboard_asm(p_in_register_clipboard, p_in_block);
-        blockwise_wei_copy.RunStoreRegisterClipboard_asm(p_wei_register_clipboard, p_wei_block);
-#endif

         __syncthreads();
...
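The experimental store path deleted above called vmcnt(0) before the inline-asm LDS stores. That helper presumably wraps the GCN "s_waitcnt vmcnt(0)" instruction, which stalls until every outstanding vector-memory load has landed in registers. A hedged sketch of such a helper (an assumption, not taken from this commit; the real inline-asm definitions live in amd_inline_asm.hip.hpp, which this commit also extends):

// assumed shape of a vmcnt-style wait helper, for illustration only
__device__ void vmcnt0()
{
    asm volatile("s_waitcnt vmcnt(0)" : : : "memory");
}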
src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp
View file @
1e37e838
...
@@ -7,7 +7,6 @@
 #include "threadwise_2d_tensor_op.hip.hpp"
 #include "threadwise_nd_tensor_op.hip.hpp"
 #include "blockwise_gemm.hip.hpp"
-#include "gridwise_ops.hip.hpp"

 // define B = flatten(N, Hi, Wi)
 template <index_t GridSize,
...
@@ -28,6 +27,8 @@ template <index_t GridSize,
           index_t GemmMLevel1Cluster,
           index_t GemmNLevel1Cluster,
           index_t GemmKPerThreadLoop,
+          index_t GemmDataPerReadA,
+          index_t GemmDataPerReadB,
           index_t InBlockCopyThreadPerDim0,
           index_t InBlockCopyThreadPerDim1,
           index_t WeiBlockCopyThreadPerDim0,
...
@@ -65,6 +66,7 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
         constexpr index_t B = N * Hi * Wi;
         constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1);

+        // assert for LDS double buffer
         static_assert(C % (2 * CPerBlock) == 0, "C cannot be evenly divided");

         // divide block work by 2d: [K, B]
...
@@ -178,7 +180,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
                            GemmNLevel0Cluster,
                            GemmMLevel1Cluster,
                            GemmNLevel1Cluster,
-                           GemmKPerThreadLoop>{};
+                           GemmKPerThreadLoop,
+                           GemmDataPerReadA,
+                           GemmDataPerReadB>{};

         // LDS: be careful of alignment
         constexpr index_t max_align =
...
@@ -213,24 +217,16 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
                                                          p_wei_register_clipboard,
                                                          p_wei_global_block_voffset);

-#if 0
             blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
                                                         p_in_block_double);
             blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
                                                          p_wei_block_double);
-#else
-            global_load_wait_all();
-            blockwise_in_copy.RunStoreRegisterClipboard_asm(p_in_register_clipboard,
-                                                            p_in_block_double);
-            blockwise_wei_copy.RunStoreRegisterClipboard_asm(p_wei_register_clipboard,
-                                                             p_wei_block_double);
-#endif
         }

         // register
-        Float p_out_thread[out_kb_thread_desc.GetElementSpace()];
+        Float p_out_thread[out_kb_thread_desc.GetElementSpace()] = {0};

         // set threadwise output tensor to 0
-        threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread);
+        // threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread);

         for(index_t c_block_data_begin = 0; c_block_data_begin + 2 * CPerBlock < C;
             c_block_data_begin += 2 * CPerBlock)
...
@@ -280,24 +276,17 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
 #elif 1
                 blockwise_gemm.Run_asm
 #endif
                     (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
                      p_in_block_now + y * Wi + x,
                      p_out_thread);
                 }
             }

-#if 0
             blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
                                                         p_in_block_next);
             blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
                                                          p_wei_block_next);
-#else
-            global_load_wait_all();
-            blockwise_in_copy.RunStoreRegisterClipboard_asm(p_in_register_clipboard,
-                                                            p_in_block_next);
-            blockwise_wei_copy.RunStoreRegisterClipboard_asm(p_wei_register_clipboard,
-                                                             p_wei_block_next);
-#endif
         }
     }
...
@@ -331,25 +320,16 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
 #elif 1
                 blockwise_gemm.Run_asm
 #endif
                     (p_wei_block_double + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
                      p_in_block_double + y * Wi + x,
                      p_out_thread);
                 }
             }

-#if 0
             blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
                                                         p_in_block_double + in_block_space);
             blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
                                                          p_wei_block_double + wei_block_space);
-#else
-            global_load_wait_all();
-            blockwise_in_copy.RunStoreRegisterClipboard_asm(p_in_register_clipboard,
-                                                            p_in_block_double + in_block_space);
-            blockwise_wei_copy.RunStoreRegisterClipboard_asm(p_wei_register_clipboard,
-                                                             p_wei_block_double + wei_block_space);
-#endif

             // odd
             __syncthreads();
...
@@ -365,10 +345,10 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
 #elif 1
                 blockwise_gemm.Run_asm
 #endif
                     (p_wei_block_double + wei_block_space +
                         wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
                      p_in_block_double + in_block_space + y * Wi + x,
                      p_out_thread);
                 }
             }
         }
...
@@ -380,9 +360,8 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
         const index_t k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row;
         const index_t b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col;

 #if 1
         if(Y == 1 && X == 1)
-        { // pure 1x1 conv
+        { // pure 1x1 conv (non padding, 1x1 stride)
             constexpr index_t K2_ = GemmMPerThreadSubC;
             constexpr index_t K1_ = KPerBlock / KPerThread;
             constexpr index_t B2_ = GemmNPerThreadSubC;
...
@@ -400,13 +379,13 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
             threadwise_6d_tensor_copy(out_6d_thread_desc,
                                       p_out_thread,
                                       out_6d_global_desc,
-                                      p_out_global +
-                                          out_kb_global_desc.Get1dIndex(k_thread_data_begin,
-                                                                        b_thread_data_begin),
+                                      p_out_global,
                                       out_6d_thread_desc.GetLengths(),
-                                      Number<OutThreadCopyDataPerWrite>{});
+                                      Number<OutThreadCopyDataPerWrite>{},
+                                      out_kb_global_desc.Get1dIndex(k_thread_data_begin,
+                                                                    b_thread_data_begin));
         }
         else
 #endif
         {
             for(index_t k = 0; k < out_kb_thread_desc.GetLength(I0); ++k)
             {
...
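One easy-to-miss change in the hunks above: p_out_thread is now aggregate-initialized with = {0}, which value-initializes every element of the array, so the explicit threadwise_2d_tensor_set_zero call can be commented out. A minimal stand-alone check of that C++ rule:

#include <cassert>

int main()
{
    float acc[8] = {0}; // elements 1..7 are value-initialized to 0.0f as well
    for(float v : acc)
        assert(v == 0.0f);
    return 0;
}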
src/include/threadwise_gemm.hip.hpp
View file @
1e37e838
 #pragma once

-template <class Float, class SrcMatrix, class DstMatrix, index_t NRow, index_t NCol>
+template <class Float,
+          class SrcMatrix,
+          class DstMatrix,
+          index_t NRow,
+          index_t NCol,
+          index_t DataPerRead>
 __device__ void threadwise_matrix_copy(SrcMatrix,
                                        const Float* __restrict__ p_src,
                                        DstMatrix,
                                        Float* __restrict__ p_dst,
-                                       Sequence<NRow, NCol>)
+                                       Sequence<NRow, NCol>,
+                                       Number<DataPerRead>)
 {
+    static_assert(NCol % DataPerRead == 0, "wrong! should be NCol % DataPerRead == 0");
+
+    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
+
     constexpr auto src_mtx = SrcMatrix{};
     constexpr auto dst_mtx = DstMatrix{};

     for(index_t i = 0; i < NRow; ++i)
     {
-        for(index_t j = 0; j < NCol; ++j)
+        for(index_t j = 0; j < NCol; j += DataPerRead)
         {
             const index_t src_index = src_mtx.Get1dIndex(i, j);
             const index_t dst_index = dst_mtx.Get1dIndex(i, j);

-            p_dst[dst_index] = p_src[src_index];
+            *reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
+                *reinterpret_cast<const vector_t*>(&p_src[src_index]);
         }
     }
 }
...
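The new threadwise_matrix_copy moves DataPerRead elements per iteration through a vector type instead of one scalar at a time. A stand-alone host sketch of the same pattern, assuming a 4-wide vector; float4_t here is a stand-in for vector_type<Float, 4>::MemoryType:

#include <cassert>

struct float4_t { float x, y, z, w; }; // stand-in for the kernel's vector type

void copy_row_vec4(const float* src, float* dst, int ncol)
{
    assert(ncol % 4 == 0); // mirrors the kernel's NCol % DataPerRead == 0 static_assert
    for(int j = 0; j < ncol; j += 4)
        *reinterpret_cast<float4_t*>(dst + j) =
            *reinterpret_cast<const float4_t*>(src + j);
}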
src/include/threadwise_nd_tensor_op.hip.hpp
View file @
1e37e838
...
@@ -8,7 +8,8 @@ __device__ void threadwise_6d_tensor_copy(SrcDesc,
                                           DstDesc,
                                           Float* __restrict__ p_dst,
                                           SrcOpLengths,
-                                          Number<DataPerRead>)
+                                          Number<DataPerRead>,
+                                          index_t voffset = 0)
 {
     using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
...
@@ -60,9 +61,14 @@ __device__ void threadwise_6d_tensor_copy(SrcDesc,
                     const index_t dst_index = dst_desc.Get1dIndex(
                         did0, did1, did2, did3, did4, iloop_d5 * DataPerRead);

-                    *(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
+#if 1
+                    void* sptr = p_dst + dst_index;
+                    void* vptr = (void*)(size_t)(voffset * sizeof(Float));
+
+                    global_store(*(vector_t*)(p_src + src_index), vptr, sptr);
+#else
+                    *(reinterpret_cast<vector_t*>(p_dst + dst_index + voffset)) =
                         *(reinterpret_cast<const vector_t*>(p_src + src_index));
+#endif
                 }
             }
         }
...
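Callers of threadwise_6d_tensor_copy migrate by dropping the offset from the destination pointer and passing it as the new trailing voffset argument instead, exactly as the v2 double-buffer kernel above now does:

// before: offset folded into the destination pointer
threadwise_6d_tensor_copy(out_6d_thread_desc, p_out_thread, out_6d_global_desc,
                          p_out_global + out_kb_global_desc.Get1dIndex(
                              k_thread_data_begin, b_thread_data_begin),
                          out_6d_thread_desc.GetLengths(),
                          Number<OutThreadCopyDataPerWrite>{});

// after: offset passed separately so global_store can fold it into the addressing
threadwise_6d_tensor_copy(out_6d_thread_desc, p_out_thread, out_6d_global_desc,
                          p_out_global,
                          out_6d_thread_desc.GetLengths(),
                          Number<OutThreadCopyDataPerWrite>{},
                          out_kb_global_desc.Get1dIndex(k_thread_data_begin,
                                                        b_thread_data_begin));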