gaoqiong / composable_kernel / Commits

Commit 19f17df4, authored Apr 18, 2019 by Chao Liu (parent 17f3d2d4)

implicit gemm v1r2: adding support for nchw
Showing 16 changed files with 1621 additions and 217 deletions (+1621, -217)
driver/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp   (+2, -2)
driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp   (+433, -0)
driver/driver.hip.cpp   (+24, -13)
src/include/Array.hip.hpp   (+34, -0)
src/include/ConstantTensorDescriptor.hip.hpp   (+15, -7)
src/include/Sequence.hip.hpp   (+84, -29)
src/include/blockwise_2d_tensor_op.hip.hpp   (+9, -9)
src/include/blockwise_3d_tensor_op.hip.hpp   (+270, -2)
src/include/blockwise_4d_tensor_op.hip.hpp   (+315, -57)
src/include/common.hip.hpp   (+28, -2)
src/include/functional.hip.hpp   (+0, -15)
src/include/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp   (+1, -37)
src/include/gridwise_convolution_implicit_gemm_v1r2_nchw_cyxk_khwn.hip.hpp   (+362, -0)
src/include/threadwise_2d_tensor_op.hip.hpp   (+7, -12)
src/include/threadwise_4d_tensor_op.hip.hpp   (+36, -31)
src/include/threadwise_nd_tensor_op.hip.hpp   (+1, -1)
driver/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp
@@ -243,7 +243,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
     constexpr index_t OutThreadCopyDataPerWrite = 4;
     constexpr index_t BlockSize = 128;
-#elif 0
+#elif 1
     // for 3x3, 28x28, v1r1, Pacal
     constexpr index_t NPerBlock = 32;
     constexpr index_t KPerBlock = 64;
@@ -386,7 +386,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
     for(index_t i = 0; i < nrepeat; ++i)
     {
         constexpr auto gridwise_conv =
-#if 0
+#if 1
             GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
 #elif 0
             GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn_lds_double_buffer
driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp
0 → 100644
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "gridwise_convolution_wrapper.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r2_nchw_cyxk_khwn.hip.hpp"
template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
                                                        const Tensor<T>& in_nchw,
                                                        WeiDesc,
                                                        const Tensor<T>& wei_kcyx,
                                                        OutDesc,
                                                        Tensor<T>& out_nkhw,
                                                        index_t nrepeat)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_nchw_desc  = InDesc{};
    constexpr auto wei_kcyx_desc = WeiDesc{};
    constexpr auto out_nkhw_desc = OutDesc{};

    constexpr index_t Hi = in_nchw_desc.GetLength(I2);
    constexpr index_t Wi = in_nchw_desc.GetLength(I3);

    constexpr index_t N  = out_nkhw_desc.GetLength(I0);
    constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
    constexpr index_t Wo = out_nkhw_desc.GetLength(I3);

    constexpr index_t K = wei_kcyx_desc.GetLength(I0);
    constexpr index_t C = wei_kcyx_desc.GetLength(I1);
    constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
    constexpr index_t X = wei_kcyx_desc.GetLength(I3);

    // reorder weight
    auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
    ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: ");

    Tensor<T> wei_cyxk(make_TensorDescriptor(wei_cyxk_desc));

    auto f_reorder_kcyx2cyxk = [&](auto k, auto c, auto y, auto x) {
        wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x);
    };

    make_ParallelTensorFunctor(f_reorder_kcyx2cyxk, K, C, Y, X)(
        std::thread::hardware_concurrency());
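    // note (editorial illustration, not part of the commit): the host-side reorder above maps
    // wei_kcyx(k, c, y, x) -> wei_cyxk(c, y, x, k); e.g. element (k=1, c=0, y=0, x=0) of the
    // KCYX tensor lands at (0, 0, 0, 1) of the CYXK tensor, so K becomes the contiguous,
    // fastest-varying dimension that the device kernel reads.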
    // output
    auto out_khwn_desc = make_ConstantTensorDescriptor(Sequence<K, Ho, Wo, N>{});
    ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: ");

    Tensor<T> out_khwn(make_TensorDescriptor(out_khwn_desc));

    std::size_t data_sz = sizeof(T);
    DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
    DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace());
    DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace());

    in_nchw_device_buf.ToDevice(in_nchw.mData.data());
    wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
    out_khwn_device_buf.ToDevice(out_khwn.mData.data());
#if 0
// for 3x3, 34x34, v1r1, Pascal
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopy_ThreadPerDimC = 4;
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t BlockSize = 128;
#elif 0
    // for 3x3, 34x34, v1r2, Pascal, in-block-copy1
    constexpr index_t NPerBlock = 4;
    constexpr index_t KPerBlock = 64;
    constexpr index_t CPerBlock = 8;
    constexpr index_t HoPerBlock = 4;
    constexpr index_t WoPerBlock = 8;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 2;
    constexpr index_t InBlockCopy_ThreadPerDimC = 4;
    constexpr index_t InBlockCopy_ThreadPerDimH = 4;
    constexpr index_t InBlockCopy_ThreadPerDimW = 2;
    constexpr index_t InBlockCopy_ThreadPerDimN = 1;
    constexpr index_t InBlockCopyDataPerRead = 4;
    constexpr index_t WeiBlockCopyDataPerRead = 4;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 2;
    constexpr index_t GemmNLevel1Cluster = 2;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;
    constexpr index_t OutThreadCopyDataPerWrite = 2;
    constexpr index_t BlockSize = 128;
#elif 0
    // for 3x3, 34x34, v1r1, Vega 20
    constexpr index_t NPerBlock = 16;
    constexpr index_t KPerBlock = 128;
    constexpr index_t CPerBlock = 4;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 4;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 2;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 2;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;
    constexpr index_t InBlockCopy_ThreadPerDimC = 4;
    constexpr index_t InBlockCopy_ThreadPerDimH = 4;
    constexpr index_t InBlockCopy_ThreadPerDimW = 2;
    constexpr index_t InBlockCopy_ThreadPerDimN = 8;
    constexpr index_t InBlockCopyDataPerRead = 2;
    constexpr index_t WeiBlockCopyDataPerRead = 2;
    constexpr index_t OutThreadCopyDataPerWrite = 4;
    constexpr index_t BlockSize = 256;
#elif 0
    // for 3x3, 56x56, v1, Pascal
    constexpr index_t NPerBlock = 32;
    constexpr index_t KPerBlock = 64;
    constexpr index_t CPerBlock = 4;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 2;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 2;
    constexpr index_t InBlockCopy_ThreadPerDimC = 1;
    constexpr index_t InBlockCopy_ThreadPerDimH = 4;
    constexpr index_t InBlockCopy_ThreadPerDimW = 4;
    constexpr index_t InBlockCopy_ThreadPerDimN = 8;
    constexpr index_t InBlockCopyDataPerRead = 4;
    constexpr index_t WeiBlockCopyDataPerRead = 4;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 2;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t OutThreadCopyDataPerWrite = 2;
    constexpr index_t BlockSize = 128;
#elif 0
    // for 3x3, 56x56, v1r2, Pascal
    constexpr index_t NPerBlock = 16;
    constexpr index_t KPerBlock = 128;
    constexpr index_t CPerBlock = 8;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 2;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 2;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 2;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA = 1;
    constexpr index_t GemmDataPerReadB = 1;
    constexpr index_t InBlockCopy_ThreadPerDimC = 1;
    constexpr index_t InBlockCopy_ThreadPerDimH = 2;
    constexpr index_t InBlockCopy_ThreadPerDimW = 4;
    constexpr index_t InBlockCopy_ThreadPerDimN = 4;
    constexpr index_t InBlockCopyDataPerRead = 4;
    constexpr index_t WeiBlockCopyDataPerRead = 4;
    constexpr index_t OutThreadCopyDataPerWrite = 4;
    constexpr index_t BlockSize = 128;
#elif 0
    // for 3x3, 28x28, v1r1, Pacal
    constexpr index_t NPerBlock = 32;
    constexpr index_t KPerBlock = 64;
    constexpr index_t CPerBlock = 4;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 2;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 2;
    constexpr index_t InBlockCopy_ThreadPerDimC = 1;
    constexpr index_t InBlockCopy_ThreadPerDimH = 4;
    constexpr index_t InBlockCopy_ThreadPerDimW = 4;
    constexpr index_t InBlockCopy_ThreadPerDimN = 8;
    constexpr index_t InBlockCopyDataPerRead = 4;
    constexpr index_t WeiBlockCopyDataPerRead = 4;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 2;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;
    constexpr index_t OutThreadCopyDataPerWrite = 2;
    constexpr index_t BlockSize = 128;
#elif 1
    // for 3x3, 28x28, v1r2, Pascal
    constexpr index_t NPerBlock = 16;
    constexpr index_t KPerBlock = 128;
    constexpr index_t CPerBlock = 8;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 2;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 2;
    constexpr index_t InBlockCopy_ThreadPerDimN = 4;
    constexpr index_t InBlockCopy_ThreadPerDimC = 8;
    constexpr index_t InBlockCopy_ThreadPerDimH = 2;
    constexpr index_t InBlockCopy_ThreadPerDimW = 2;
    constexpr index_t InBlockCopyDataPerRead = 2;
    constexpr index_t WeiBlockCopyDataPerRead = 4;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 2;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;
    constexpr index_t OutThreadCopyDataPerWrite = 2;
    constexpr index_t BlockSize = 128;
#elif 0
    // for 1x1, 28x28
    constexpr index_t NPerBlock = 16;
    constexpr index_t KPerBlock = 128;
    constexpr index_t CPerBlock = 8;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 2;
    constexpr index_t NPerThread = 4;
    constexpr index_t KPerThread = 16;
    constexpr index_t CPerThread = 1;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;
    constexpr index_t InBlockCopy_ThreadPerDimC = 8;
    constexpr index_t InBlockCopy_ThreadPerDimH = 2;
    constexpr index_t InBlockCopy_ThreadPerDimW = 2;
    constexpr index_t InBlockCopy_ThreadPerDimN = 4;
    constexpr index_t InBlockCopyDataPerRead = 4;
    constexpr index_t WeiBlockCopyDataPerRead = 4;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 2;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t OutThreadCopyDataPerWrite = 2;
    constexpr index_t BlockSize = 128;
#elif 1
    // for 1x1, 14x14, Pascal
    constexpr index_t NPerBlock = 16;
    constexpr index_t KPerBlock = 128;
    constexpr index_t CPerBlock = 8;
    constexpr index_t HoPerBlock = 2;
    constexpr index_t WoPerBlock = 2;
    constexpr index_t NPerThread = 8;
    constexpr index_t KPerThread = 8;
    constexpr index_t HoPerThread = 1;
    constexpr index_t WoPerThread = 1;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 2;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t InBlockCopy_ThreadPerDimC = 8;
    constexpr index_t InBlockCopy_ThreadPerDimH = 2;
    constexpr index_t InBlockCopy_ThreadPerDimW = 2;
    constexpr index_t InBlockCopy_ThreadPerDimN = 4;
    constexpr index_t InBlockCopyDataPerRead = 4;
    constexpr index_t WeiBlockCopyDataPerRead = 4;
    constexpr index_t OutThreadCopyDataPerWrite = 2;
    constexpr index_t BlockSize = 128;
#endif
    constexpr index_t GridSize = ((N + NPerBlock - 1) / NPerBlock) *
                                 ((K + KPerBlock - 1) / KPerBlock) *
                                 ((Ho + HoPerBlock - 1) / HoPerBlock) *
                                 ((Wo + WoPerBlock - 1) / WoPerBlock);
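    // note (hypothetical shapes, not in the original source): with N = 64, K = 256, Ho = Wo = 28
    // and the selected NPerBlock = 16, KPerBlock = 128, HoPerBlock = WoPerBlock = 2, this gives
    // GridSize = (64/16) * (256/128) * (28/2) * (28/2) = 4 * 2 * 14 * 14 = 1568 workgroups.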
    printf("%s: BlockSize %u, GridSize %u\n", __func__, BlockSize, GridSize);
    for(index_t i = 0; i < nrepeat; ++i)
    {
        constexpr auto gridwise_conv =
#if 1
            GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
#endif
            <GridSize, BlockSize, T,
             decltype(in_nchw_desc), decltype(wei_cyxk_desc), decltype(out_khwn_desc),
             NPerBlock, KPerBlock, CPerBlock, HoPerBlock, WoPerBlock,
             NPerThread, KPerThread, HoPerThread, WoPerThread,
             GemmMPerThreadSubC, GemmNPerThreadSubC,
             GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster,
             GemmKPerThreadLoop, GemmDataPerReadA, GemmDataPerReadB,
             Sequence<InBlockCopy_ThreadPerDimN, InBlockCopy_ThreadPerDimC, InBlockCopy_ThreadPerDimH, InBlockCopy_ThreadPerDimW>,
             InBlockCopyDataPerRead, WeiBlockCopyDataPerRead, OutThreadCopyDataPerWrite>{};
        float time = launch_kernel(run_gridwise_convolution<decltype(gridwise_conv), T>,
                                   dim3(GridSize),
                                   dim3(BlockSize),
                                   0,
                                   static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
                                   static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
                                   static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));
        printf("Elapsed time : %f ms, %f TFlop/s\n",
               time,
               (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
                   (std::size_t(1000) * 1000 * 1000) / time);
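        // note (editorial, not in the original source): flops / 1e9 / time_in_ms is numerically
        // equal to TFlop/s, assuming calculate_convolution_flops() returns the usual
        // 2 * N * K * C * Ho * Wo * Y * X multiply-accumulate count for this convolution.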
        usleep(std::min(time * 1000, float(10000)));
}
    out_khwn_device_buf.FromDevice(out_khwn.mData.data());
// reorder output
    auto f_reorder_khwn2nkhw = [&](auto k, auto ho, auto wo, auto n) {
        out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n);
    };

    make_ParallelTensorFunctor(f_reorder_khwn2nkhw, K, Ho, Wo, N)(
        std::thread::hardware_concurrency());
}
driver/driver.hip.cpp
@@ -38,9 +38,6 @@ struct GeneratorTensor_2
 struct GeneratorTensor_3
 {
-    int min_value = 0;
-    int max_value = 9;
-
     template <class... Is>
     double operator()(Is... is)
     {
@@ -420,6 +417,17 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
 int main(int argc, char* argv[])
 {
 #if 0
+    constexpr index_t N = 128;
+    constexpr index_t C = 8;
+    constexpr index_t HI = 28;
+    constexpr index_t WI = 28;
+    constexpr index_t K = 128;
+    constexpr index_t Y = 3;
+    constexpr index_t X = 3;
+    constexpr index_t HPad = 0;
+    constexpr index_t WPad = 0;
+#elif 0
     // 3x3, 34x34
     constexpr index_t N = 64;
     constexpr index_t C = 256;
@@ -642,6 +650,9 @@ int main(int argc, char* argv[])
 #if 0
     in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
     wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
+#elif 0
+    in_nchw.GenerateTensorValue(GeneratorTensor_3{}, num_thread);
+    wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
 #elif 1
     in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
     wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
@@ -664,7 +675,7 @@ int main(int argc, char* argv[])
         device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
 #elif 1
         device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
-#elif 0
+#elif 1
         device_convolution_implicit_gemm_v1_nchw_cyxk_khwn
 #elif 0
         device_convolution_implicit_gemm_v2_chwn_cyxk_khwn
src/include/Array.hip.hpp
 #pragma once
+#include "Sequence.hip.hpp"
+#include "functional.hip.hpp"

 template <class TData, index_t NSize>
 struct Array
@@ -18,3 +20,35 @@ struct Array
     __host__ __device__ TData& operator[](index_t i) { return mData[i]; }
 };
+
+template <class TData, index_t NSize, index_t... IRs>
+__host__ __device__ auto reorder_array_given_new2old(const Array<TData, NSize>& old_array,
+                                                     Sequence<IRs...> new2old)
+{
+    Array<TData, NSize> new_array;
+
+    static_assert(NSize == sizeof...(IRs), "NSize not consistent");
+
+    static_for<0, NSize, 1>{}([&](auto IDim) {
+        constexpr index_t idim = IDim.Get();
+        new_array[idim] = old_array[new2old.Get(IDim)];
+    });
+
+    return new_array;
+}
+
+template <class TData, index_t NSize, index_t... IRs>
+__host__ __device__ auto reorder_array_given_old2new(const Array<TData, NSize>& old_array,
+                                                     Sequence<IRs...> old2new)
+{
+    Array<TData, NSize> new_array;
+
+    static_assert(NSize == sizeof...(IRs), "NSize not consistent");
+
+    static_for<0, NSize, 1>{}([&](auto IDim) {
+        constexpr index_t idim = IDim.Get();
+        new_array[old2new.Get(IDim)] = old_array[idim];
+    });
+
+    return new_array;
+}
\ No newline at end of file
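A quick hedged illustration (not part of the diff) of how the two new reorder helpers differ,
for a 4-element array and the same index map:

// hypothetical values, for illustration only
// old = {10, 20, 30, 40}, map = Sequence<3, 0, 1, 2>
// reorder_array_given_new2old(old, map) -> {40, 10, 20, 30}   // new[i]      = old[map[i]]
// reorder_array_given_old2new(old, map) -> {20, 30, 40, 10}   // new[map[i]] = old[i]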
src/include/ConstantTensorDescriptor.hip.hpp
@@ -108,11 +108,11 @@ template <class Lengths, class Strides>
 struct ConstantTensorDescriptor
 {
     using Type = ConstantTensorDescriptor<Lengths, Strides>;

-    static constexpr index_t nDim = Lengths::nDim;
+    static constexpr index_t nDim = Lengths::GetSize();

     __host__ __device__ constexpr ConstantTensorDescriptor()
     {
-        static_assert(Lengths::nDim == Strides::nDim, "nDim not consistent");
+        static_assert(Lengths::GetSize() == Strides::GetSize(), "nDim not consistent");
     }

     __host__ __device__ static constexpr index_t GetDimension() { return nDim; }
@@ -157,12 +157,10 @@ struct ConstantTensorDescriptor
         return align.Get() * ((element_space_unaligned + align.Get() - 1) / align.Get());
     }

-    template <class... Is>
-    __host__ __device__ static index_t Get1dIndex(Is... is)
+    template <index_t NSize>
+    __host__ __device__ static index_t Get1dIndex(Array<index_t, NSize> multi_id)
     {
-        static_assert(sizeof...(Is) == nDim, "number of multi-index is wrong");
-
-        const auto multi_id = Array<index_t, nDim>(is...);
+        static_assert(NSize == nDim, "wrong! Dimension not consistent");

         index_t id = 0;
@@ -178,6 +176,16 @@ struct ConstantTensorDescriptor
         return id;
     }

+    template <class... Is>
+    __host__ __device__ static index_t Get1dIndex(Is... is)
+    {
+        static_assert(sizeof...(Is) == nDim, "number of multi-index is wrong");
+
+        const auto multi_id = Array<index_t, nDim>(is...);
+
+        return Get1dIndex(multi_id);
+    }
+
     __host__ __device__ static Array<index_t, nDim> GetMultiIndex(index_t id)
     {
         Array<index_t, nDim> multi_id;
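A hedged illustration of the index math behind Get1dIndex/GetMultiIndex, assuming a packed,
row-major descriptor (which is what make_ConstantTensorDescriptor produces when no strides are
given); the numbers are hypothetical:

// lengths = Sequence<2, 3, 4>  =>  strides = {12, 4, 1}
// Get1dIndex(1, 2, 3) = 1*12 + 2*4 + 3 = 23
// GetMultiIndex(23)   = {1, 2, 3}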
src/include/Sequence.hip.hpp
@@ -7,9 +7,11 @@ struct Sequence
 {
     using Type = Sequence<Is...>;

-    static constexpr index_t nDim = sizeof...(Is);
+    static constexpr index_t mSize = sizeof...(Is);

-    const index_t mData[nDim] = {Is...};
+    const index_t mData[mSize] = {Is...};
+
+    __host__ __device__ static constexpr index_t GetSize() { return mSize; }

     template <index_t I>
     __host__ __device__ constexpr index_t Get(Number<I>) const
@@ -19,36 +21,38 @@ struct Sequence
     __host__ __device__ index_t operator[](index_t i) const { return mData[i]; }

-    // this is ugly, only for nDIm = 4
-    template <index_t I0, index_t I1, index_t I2, index_t I3>
-    __host__ __device__ constexpr auto ReorderByGetNewFromOld(Sequence<I0, I1, I2, I3>) const
-    {
-        static_assert(nDim == 4, "nDim != 4");
-
-        constexpr auto old_sequence = Type{};
-
-        constexpr index_t NR0 = old_sequence.mData[I0];
-        constexpr index_t NR1 = old_sequence.mData[I1];
-        constexpr index_t NR2 = old_sequence.mData[I2];
-        constexpr index_t NR3 = old_sequence.mData[I3];
-
-        return Sequence<NR0, NR1, NR2, NR3>{};
-    }
+    template <index_t... IRs>
+    __host__ __device__ constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/) const
+    {
+        static_assert(mSize == sizeof...(IRs), "mSize not consistent");
+
+        constexpr auto old = Type{};
+
+        return Sequence<old.Get(Number<IRs>{})...>{};
+    }

-    template <index_t I0, index_t I1, index_t I2, index_t I3>
-    __host__ __device__ constexpr auto ReorderByPutOldToNew(Sequence<I0, I1, I2, I3>) const
+    template <index_t... IRs>
+    __host__ __device__ constexpr auto ReorderGivenOld2New(Sequence<IRs...> /*old2new*/) const
     {
         // don't know how to implement this
-        printf("Sequence::ReorderByPutOldToNew not implemented");
+        printf("Sequence::ReorderGivenOld2New not implemented");
         assert(false);
     }

+    template <index_t I>
+    __host__ __device__ constexpr auto PushFront(Number<I>) const
+    {
+        return Sequence<I, Is...>{};
+    }
+
     template <index_t I>
     __host__ __device__ constexpr auto PushBack(Number<I>) const
     {
         return Sequence<Is..., I>{};
     }

+    __host__ __device__ constexpr auto PopFront() const;
+
     __host__ __device__ constexpr auto PopBack() const;

     template <class F>
@@ -58,33 +62,84 @@ struct Sequence
     }
 };

+template <index_t I, index_t... Is>
+__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
+{
+    static_assert(sizeof...(Is) > 0, "empty Sequence!");
+    return Sequence<Is...>{};
+}
+
 template <index_t... Is, index_t I>
 __host__ __device__ constexpr auto sequence_pop_back(Sequence<Is..., I>)
 {
-    static_assert(sizeof...(Is) >= 1, "empty Sequence!");
+    static_assert(sizeof...(Is) > 0, "empty Sequence!");
     return Sequence<Is...>{};
 }

+#if 1
+// this is ugly, only for 2 sequences
 template <class F, index_t... Xs, index_t... Ys>
-__host__ __device__ constexpr auto sequence_sequence_op(Sequence<Xs...>, Sequence<Ys...>, F f)
+__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>)
 {
-    static_assert(Sequence<Xs...>::nDim == Sequence<Ys...>::nDim, "Dim not the same");
+    static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize, "Dim not the same");

     return Sequence<f(Xs, Ys)...>{};
 }

-template <index_t... Xs, index_t... Ys>
-__host__ __device__ constexpr auto sequence_sequence_add(Sequence<Xs...>, Sequence<Ys...>)
-{
-    struct add
-    {
-        __host__ __device__ constexpr index_t operator()(index_t x, index_t y) const
-        {
-            return x + y;
-        }
-    };
-
-    return sequence_sequence_op(Sequence<Xs...>{}, Sequence<Ys...>{}, add{});
-}
+// this is ugly, only for 3 sequences
+template <class F, index_t... Xs, index_t... Ys, index_t... Zs>
+__host__ __device__ constexpr auto
+transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
+{
+    static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize &&
+                      Sequence<Xs...>::mSize == Sequence<Zs...>::mSize,
+                  "Dim not the same");
+
+    return Sequence<f(Xs, Ys, Zs)...>{};
+}
+#else
+template <index_t NRemain>
+struct transform_sequences_impl
+{
+    template <class F, class Y, class... Xs>
+    __host__ __device__ constexpr auto operator()(F f, Y y, Xs... xs) const
+    {
+        static_assert(NRemain > 1, "wrong! should have NRemain > 1");
+
+        constexpr index_t N = f(Xs{}.Get(Number<0>{})...);
+
+        constexpr auto y_new = y.PushBack(Number<N>{});
+
+        return transform_sequences_impl<NRemain - 1>{}(f, y_new, xs.PopFront()...);
+    }
+};
+
+template <>
+struct transform_sequences_impl<1>
+{
+    template <class F, class Y, class... Xs>
+    __host__ __device__ constexpr auto operator()(F f, Y, Xs...) const
+    {
+        constexpr index_t N = f(Xs{}.Get(Number<0>{})...);
+
+        return Y{}.PushBack(Number<N>{});
+    }
+};
+
+template <class F, class X, class... Xs>
+__host__ __device__ constexpr auto transform_sequences(F f, X x, Xs... xs)
+{
+    constexpr index_t nSize = X::GetSize();
+
+    constexpr auto I0 = Number<0>{};
+
+    constexpr auto y0 = Sequence<f(X{}.Get(I0), Xs{}.Get(I0)...)>{};
+
+    return transform_sequences_impl<nSize - 1>{}(f, y0, x.PopFront(), xs.PopFront()...);
+}
+#endif
+
+template <index_t... Is>
+__host__ __device__ constexpr auto Sequence<Is...>::PopFront() const
+{
+    return sequence_pop_front(Type{});
+}
+
 template <index_t... Is>
@@ -107,6 +162,6 @@ template <class Seq, class Reduce, index_t I>
 __host__ __device__ constexpr index_t accumulate_on_sequence(Seq, Reduce, Number<I>)
 {
     constexpr index_t a =
-        static_const_reduce_n<Seq::nDim>{}(accumulate_on_sequence_f<Seq>{}, Reduce{});
+        static_const_reduce_n<Seq::mSize>{}(accumulate_on_sequence_f<Seq>{}, Reduce{});

     return Reduce{}(a, I);
 }
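A minimal hedged usage sketch for the new transform_sequences helper (using the functors this
commit adds to common.hip.hpp; the values are hypothetical):

// transform_sequences(mod_conv::plus<index_t>{},       Sequence<1, 2, 3>{}, Sequence<10, 20, 30>{}) -> Sequence<11, 22, 33>
// transform_sequences(mod_conv::multiplies<index_t>{}, Sequence<1, 2, 3>{}, Sequence<10, 20, 30>{}) -> Sequence<10, 40, 90>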
src/include/blockwise_2d_tensor_op.hip.hpp
@@ -67,7 +67,7 @@ template <index_t BlockSize,
           class SrcDesc,
           class DstDesc,
           class SrcOpLengths,
-          class DstFromSrcReorder,
+          class MapDst2Src,
           class F>
 __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
     SrcDesc,
@@ -75,14 +75,14 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds
     DstDesc,
     Float* __restrict__ p_dst,
     SrcOpLengths,
-    DstFromSrcReorder,
+    MapDst2Src,
     F f)
 {
     constexpr auto I0 = Number<0>{};
     constexpr auto I1 = Number<1>{};

-    constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0);
-    constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1);
+    constexpr index_t IR0 = MapDst2Src{}.Get(I0);
+    constexpr index_t IR1 = MapDst2Src{}.Get(I1);

     constexpr auto src_desc = SrcDesc{};
     constexpr auto dst_desc = DstDesc{};
@@ -147,19 +147,19 @@ template <index_t BlockSize,
           class SrcDesc,
           class DstDesc,
           class SrcOpLengths,
-          class DstFromSrcReorder>
+          class MapDst2Src>
 __device__ void
 blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
                                                      const Float* __restrict__ p_src,
                                                      DstDesc,
                                                      Float* __restrict__ p_dst,
                                                      SrcOpLengths,
-                                                     DstFromSrcReorder)
+                                                     MapDst2Src)
 {
     auto f_copy = [](const Float& src, Float& dst) { dst = src; };

     blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
-        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
+        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
 }

 template <index_t BlockSize,
@@ -192,7 +192,7 @@ struct Blockwise2dTensorCopy1
     // but we need to make sure dst stride0 is big enough,
     // so that the out-of-bound write won't contaminate next line in dst
     constexpr index_t L1 = CopyLengths{}.Get(I1);
-    constexpr index_t read_per_d1 = integer_divide_ceil(L1, DataPerRead);
+    constexpr index_t read_per_d1 = mod_conv::integer_divide_ceil(L1, DataPerRead);

     static_assert(read_per_d1 * DataPerRead <= DstDesc{}.GetStride(I0),
                   "wrong! out-of-bound write will contaminate next line!\n");
@@ -209,7 +209,7 @@ struct Blockwise2dTensorCopy1
     constexpr index_t L0 = CopyLengths{}.Get(I0);
     constexpr index_t L1 = CopyLengths{}.Get(I1);
-    constexpr index_t read_per_d1 = integer_divide_ceil(L1, DataPerRead);
+    constexpr index_t read_per_d1 = mod_conv::integer_divide_ceil(L1, DataPerRead);

     constexpr auto ref_desc = make_ConstantTensorDescriptor(Sequence<L0, read_per_d1>{});
src/include/blockwise_3d_tensor_op.hip.hpp
@@ -33,7 +33,7 @@ struct Blockwise3dTensorCopy1
     // but we need to make sure dst stride2 is big enough,
     // so that the out-of-bound write won't contaminate next line in dst
     constexpr index_t L2 = CopyLengths{}.Get(I2);
-    constexpr index_t read_per_d2 = integer_divide_ceil(L2, DataPerRead);
+    constexpr index_t read_per_d2 = mod_conv::integer_divide_ceil(L2, DataPerRead);

     static_assert(read_per_d2 * DataPerRead <= DstDesc{}.GetStride(I1),
                   "wrong! out-of-bound write will contaminate next line!\n");
@@ -52,7 +52,7 @@ struct Blockwise3dTensorCopy1
     constexpr index_t L1 = CopyLengths{}.Get(I1);
     constexpr index_t L2 = CopyLengths{}.Get(I2);
-    constexpr index_t read_per_d2 = integer_divide_ceil(L2, DataPerRead);
+    constexpr index_t read_per_d2 = mod_conv::integer_divide_ceil(L2, DataPerRead);

     constexpr auto ref_desc = make_ConstantTensorDescriptor(Sequence<L0, L1, read_per_d2>{});
@@ -98,3 +98,271 @@ struct Blockwise3dTensorCopy1
         }
     }
 };
+
+// starting point need to be aligned to float4 or float2 or float
+// stride3 need to be 1 for both source and destination
+template <index_t BlockSize,
+          class Float,
+          class SrcDesc,
+          class DstDesc,
+          class CopyLengths,
+          class ThreadPerDims,
+          index_t DataPerRead>
+struct Blockwise3dTensorCopy3
+{
+    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
+
+    index_t mSrcMyThreadOffset;
+    index_t mDstMyThreadOffset;
+
+    __device__ Blockwise3dTensorCopy3()
+    {
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+
+        static_assert(DataPerRead == 1 ||
+                          (SrcDesc{}.GetStride(I2) == 1 && DstDesc{}.GetStride(I2) == 1),
+                      "wrong! only support stride3 == 1 if DataPerRead > 1!\n");
+
+        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
+                      "wrong! only support DataPerRead == 1, 2 or 4!\n");
+
+        static_assert(SrcDesc{}.GetStride(I1) % DataPerRead == 0 &&
+                          DstDesc{}.GetStride(I1) % DataPerRead == 0,
+                      "wrong! src and dst stride1 should be multiple of DataPerRead to keep alignment");
+
+        constexpr index_t L0 = CopyLengths{}.Get(I0);
+        constexpr index_t L1 = CopyLengths{}.Get(I1);
+        constexpr index_t L2 = CopyLengths{}.Get(I2);
+
+        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
+        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
+        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
+
+        // we allow out-of-bound read from src in D2 dimension,
+        // but we need to make sure dst stride is big enough,
+        // so that the out-of-bound write won't contaminate next line in dst
+        constexpr index_t nloop_d2 = mod_conv::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);
+
+        static_assert(nloop_d2 * thread_per_d2 * DataPerRead <= DstDesc{}.GetStride(I1),
+                      "wrong! out-of-bound write will contaminate next line!\n");
+
+        static_assert(L0 % thread_per_d0 == 0 && L1 % thread_per_d1 == 0,
+                      "wrong! L0, L1, L2 should be divided evenly!\n");
+
+        static_assert(BlockSize >= thread_per_d0 * thread_per_d1 * thread_per_d2,
+                      "wrrong! BlockSize is not big enough for ThreadPerDims!");
+
+        constexpr index_t num_active_thread =
+            accumulate_on_sequence(ThreadPerDims{}, mod_conv::multiplies<index_t>{}, Number<1>{});
+
+        if(BlockSize > num_active_thread)
+        {
+            if(get_thread_local_1d_id() >= num_active_thread)
+            {
+                return;
+            }
+        }
+
+        constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor(ThreadPerDims{});
+        const auto thread_multi_id = thread_cluster_desc.GetMultiIndex(get_thread_local_1d_id());
+
+        mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(
+            thread_multi_id[0], thread_multi_id[1], thread_multi_id[2] * DataPerRead);
+
+        mDstMyThreadOffset = DstDesc{}.Get1dIndex(
+            thread_multi_id[0], thread_multi_id[1], thread_multi_id[2] * DataPerRead);
+    }
+
+    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
+    {
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+
+        constexpr index_t L0 = CopyLengths{}.Get(I0);
+        constexpr index_t L1 = CopyLengths{}.Get(I1);
+        constexpr index_t L2 = CopyLengths{}.Get(I2);
+
+        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
+        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
+        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
+
+        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1 * thread_per_d2;
+
+        if(BlockSize > num_active_thread)
+        {
+            if(get_thread_local_1d_id() >= num_active_thread)
+            {
+                return;
+            }
+        }
+
+        constexpr index_t nloop_d0 = L0 / thread_per_d0;
+        constexpr index_t nloop_d1 = L1 / thread_per_d1;
+        constexpr index_t nloop_d2 = mod_conv::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);
+
+#pragma unroll
+        for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
+        {
+#pragma unroll
+            for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
+            {
+#pragma unroll
+                for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
+                {
+                    const index_t src_offset =
+                        SrcDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
+                                             iloop_d1 * thread_per_d1,
+                                             iloop_d2 * thread_per_d2 * DataPerRead);
+
+                    const index_t dst_offset =
+                        DstDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
+                                             iloop_d1 * thread_per_d1,
+                                             iloop_d2 * thread_per_d2 * DataPerRead);
+
+                    *(reinterpret_cast<vector_t*>(&p_dst[dst_offset + mDstMyThreadOffset])) =
+                        *(reinterpret_cast<const vector_t*>(&p_src[src_offset + mSrcMyThreadOffset]));
+                }
+            }
+        }
+    }
+
+    __device__ constexpr index_t GetRegisterClipboardSize() const
+    {
+        static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
+
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+
+        constexpr index_t L0 = CopyLengths{}.Get(I0);
+        constexpr index_t L1 = CopyLengths{}.Get(I1);
+        constexpr index_t L2 = CopyLengths{}.Get(I2);
+
+        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
+        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
+        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
+
+        constexpr index_t nloop_d0 = L0 / thread_per_d0;
+        constexpr index_t nloop_d1 = L1 / thread_per_d1;
+        constexpr index_t nloop_d2 = mod_conv::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);
+
+        return DataPerRead * nloop_d0 * nloop_d1 * nloop_d2;
+    }
+
+    __device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
+                                             Float* __restrict__ p_clipboard) const
+    {
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+
+        constexpr index_t L0 = CopyLengths{}.Get(I0);
+        constexpr index_t L1 = CopyLengths{}.Get(I1);
+        constexpr index_t L2 = CopyLengths{}.Get(I2);
+
+        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
+        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
+        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
+
+        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1 * thread_per_d2;
+
+        if(BlockSize > num_active_thread)
+        {
+            if(get_thread_local_1d_id() >= num_active_thread)
+            {
+                return;
+            }
+        }
+
+        constexpr index_t nloop_d0 = L0 / thread_per_d0;
+        constexpr index_t nloop_d1 = L1 / thread_per_d1;
+        constexpr index_t nloop_d2 = mod_conv::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);
+
+        constexpr auto clipboard_desc =
+            make_ConstantTensorDescriptor(Sequence<nloop_d0, nloop_d1, nloop_d2 * DataPerRead>{});
+
+#pragma unroll
+        for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
+        {
+#pragma unroll
+            for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
+            {
+#pragma unroll
+                for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
+                {
+                    const index_t src_offset =
+                        SrcDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
+                                             iloop_d1 * thread_per_d1,
+                                             iloop_d2 * thread_per_d2 * DataPerRead);
+
+                    const index_t clipboard_offset =
+                        clipboard_desc.Get1dIndex(iloop_d0, iloop_d1, iloop_d2 * DataPerRead);
+
+                    *(reinterpret_cast<vector_t*>(&p_clipboard[clipboard_offset])) =
+                        *(reinterpret_cast<const vector_t*>(&p_src[src_offset + mSrcMyThreadOffset]));
+                }
+            }
+        }
+    }
+
+    __device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
+                                              Float* __restrict__ p_dst) const
+    {
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+
+        constexpr index_t L0 = CopyLengths{}.Get(I0);
+        constexpr index_t L1 = CopyLengths{}.Get(I1);
+        constexpr index_t L2 = CopyLengths{}.Get(I2);
+
+        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
+        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
+        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
+
+        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1 * thread_per_d2;
+
+        if(BlockSize > num_active_thread)
+        {
+            if(get_thread_local_1d_id() >= num_active_thread)
+            {
+                return;
+            }
+        }
+
+        constexpr index_t nloop_d0 = L0 / thread_per_d0;
+        constexpr index_t nloop_d1 = L1 / thread_per_d1;
+        constexpr index_t nloop_d2 = mod_conv::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);
+
+        constexpr auto clipboard_desc =
+            make_ConstantTensorDescriptor(Sequence<nloop_d0, nloop_d1, nloop_d2 * DataPerRead>{});
+
+#pragma unroll
+        for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
+        {
+#pragma unroll
+            for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
+            {
+#pragma unroll
+                for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
+                {
+                    const index_t clipboard_offset =
+                        clipboard_desc.Get1dIndex(iloop_d0, iloop_d1, iloop_d2 * DataPerRead);
+
+                    const index_t dst_offset =
+                        DstDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
+                                             iloop_d1 * thread_per_d1,
+                                             iloop_d2 * thread_per_d2 * DataPerRead);
+
+                    *(reinterpret_cast<vector_t*>(&p_dst[dst_offset + mDstMyThreadOffset])) =
+                        *(reinterpret_cast<const vector_t*>(&p_clipboard[clipboard_offset]));
+                }
+            }
+        }
+    }
+};
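A hedged arithmetic sketch of GetRegisterClipboardSize() with hypothetical template arguments
(not taken from the commit):

// CopyLengths = Sequence<8, 3, 128>, ThreadPerDims = Sequence<4, 1, 32>, DataPerRead = 4
// nloop_d0 = 8/4 = 2, nloop_d1 = 3/1 = 3, nloop_d2 = ceil(128 / (32*4)) = 1
// clipboard size = 4 * 2 * 3 * 1 = 24 floats of per-thread register storage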
src/include/blockwise_4d_tensor_op.hip.hpp
(diff collapsed in the original view; contents not shown)
src/include/common.hip.hpp
@@ -25,12 +25,38 @@ struct is_same<T, T>
     static const bool value = true;
 };

-__host__ __device__ constexpr index_t integer_divide_ceil(index_t a, index_t b)
-{
-    return (a + b - 1) / b;
-}
+namespace mod_conv { // namespace mod_conv
+
+template <class T>
+struct multiplies
+{
+    __host__ __device__ constexpr T operator()(T a, T b) const { return a * b; }
+};
+
+template <class T>
+struct plus
+{
+    __host__ __device__ constexpr T operator()(T a, T b) const { return a + b; }
+};
+
+template <class T>
+struct integer_divide_ceiler
+{
+    __host__ __device__ constexpr T operator()(T a, T b) const
+    {
+        static_assert(is_same<T, index_t>::value || is_same<T, int>::value, "wrong type");
+
+        return (a + b - 1) / b;
+    }
+};
+
+template <class T>
+__host__ __device__ constexpr T integer_divide_ceil(T a, T b)
+{
+    static_assert(is_same<T, index_t>::value || is_same<T, int>::value, "wrong type");
+
+    return (a + b - 1) / b;
+}

-namespace mod_conv { // namespace mod_conv
-
 template <class T>
 __host__ __device__ constexpr T max(T x, T y)
 {
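A small hedged example of the helper this file now provides (values hypothetical):

// mod_conv::integer_divide_ceil(30, 8) == 4
// i.e. four 8-wide vector reads cover a row of 30 elements, which is how the
// blockwise copy classes in this commit compute read_per_d1 / read_per_d2.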
src/include/functional.hip.hpp
@@ -70,18 +70,3 @@ __host__ __device__ constexpr auto unpacker(F f)
     return [=](auto xs_array){ f(xs...); };
 }
 #endif
-
-namespace mod_conv {
-
-template <class T>
-struct multiplies
-{
-    __host__ __device__ constexpr T operator()(T a, T b) const { return a * b; }
-};
-
-template <class T>
-struct plus
-{
-    __host__ __device__ constexpr T operator()(T a, T b) const { return a + b; }
-};
-
-} // namespace mod_conv
src/include/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp
@@ -248,42 +248,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
             }
         }

         // output: register to global mem,
-#if 0
-        const auto c_thread_mtx_begin =
-            blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
-
-        for(index_t k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k)
-        {
-            for(index_t ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho)
-            {
-                for(index_t wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo)
-                {
-                    for(index_t n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n)
-                    {
-                        const index_t b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n);
-
-                        const auto c_thread_mtx_distance =
-                            blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b);
-
-                        const index_t ho_thread =
-                            c_thread_mtx_begin.batch + c_thread_mtx_distance.batch;
-                        const index_t k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row;
-                        const index_t b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col;
-
-                        const index_t wo_thread = b_thread / NPerBlock;
-                        const index_t n_thread = b_thread % NPerBlock;
-
-                        p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread,
-                                                                     ho_block_data_begin + ho_thread,
-                                                                     wo_block_data_begin + wo_thread,
-                                                                     n_block_data_begin + n_thread)] =
-                            p_out_thread[out_khwn_thread_desc.Get1dIndex(k, ho, wo, n)];
-                    }
-                }
-            }
-        }
-#elif 1
         const auto c_thread_mtx_begin =
             blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
@@ -331,6 +296,5 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
                                              n_block_data_begin + n_thread_data_begin),
             out_10d_thread_desc.GetLengths(),
             Number<OutThreadCopyDataPerWrite>{});
-#endif
     }
 };
src/include/gridwise_convolution_implicit_gemm_v1r2_nchw_cyxk_khwn.hip.hpp
0 → 100644
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "blockwise_3d_tensor_op.hip.hpp"
#include "blockwise_4d_tensor_op.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_batched_gemm.hip.hpp"
template <index_t GridSize, index_t BlockSize, class Float,
          class InGlobalDesc, class WeiGlobalDesc, class OutGlobalDesc,
          index_t NPerBlock, index_t KPerBlock, index_t CPerBlock, index_t HoPerBlock, index_t WoPerBlock,
          index_t NPerThread, index_t KPerThread, index_t HoPerThread, index_t WoPerThread,
          index_t GemmMPerThreadSubC, index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster, index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster, index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop, index_t GemmDataPerReadA, index_t GemmDataPerReadB,
          class InBlockCopyThreadPerDims,
          index_t InBlockCopyDataPerRead, index_t WeiBlockCopyDataPerRead, index_t OutThreadCopyDataPerWrite>
struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        // be careful of this assertion
        static_assert(NPerThread <= NPerBlock && NPerBlock % NPerThread == 0,
                      "wrong! should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto in_n_c_h_w_global_desc  = InGlobalDesc{};
        constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
        constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{};

        constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);

        constexpr index_t K  = out_k_h_w_n_global_desc.GetLength(I0);
        constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1);
        constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2);
        constexpr index_t N  = out_k_h_w_n_global_desc.GetLength(I3);

        constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
        constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);

        constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
        constexpr index_t WiPerBlock = WoPerBlock + X - 1;

        // divide block work: [K, Ho, Wo, N]
        static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
                          Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
                      "wrong! cannot evenly divide work for workgroup ");

        constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
        constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
        constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
        constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;

        const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
        index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
        const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
        itmp -= h_block_work_id * (WBlockWork * NBlockWork);
        const index_t w_block_work_id = itmp / NBlockWork;
        const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;

        const index_t k_block_data_begin  = k_block_work_id * KPerBlock;
        const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
        const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
        const index_t n_block_data_begin  = n_block_work_id * NPerBlock;

        const index_t hi_block_data_begin = ho_block_data_begin;
        const index_t wi_block_data_begin = wo_block_data_begin;
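        // note (hypothetical numbers, not in the original source): with KBlockWork = 2,
        // HBlockWork = WBlockWork = 14 and NBlockWork = 4, block id 789 decomposes as
        // k_block_work_id = 789 / (14*14*4) = 1, itmp = 789 - 784 = 5,
        // h_block_work_id = 0, w_block_work_id = 1, n_block_work_id = 1.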
// global tensor view
        constexpr auto wei_c_x_k_global_desc =
            make_ConstantTensorDescriptor(Sequence<C, X, K>{}, Sequence<Y * X * K, K, 1>{});

        // LDS tensor view
        // be careful of alignment
        constexpr index_t max_align = mod_conv::max(
            InBlockCopyDataPerRead, WeiBlockCopyDataPerRead, GemmDataPerReadA, GemmDataPerReadB);

        constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, HoPerBlock, WiPerBlock, NPerBlock>{}, Number<max_align>{});

        constexpr auto wei_c_x_k_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, X, KPerBlock>{}, Number<max_align>{});

        // tensor view of threadwise output in register
        constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(
            Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});

        // blockwise copy
        // input: format is [N, C, Hi, Wi] to [C, Hi, Wi, N]
        auto map_chwn2nchw = Sequence<1, 2, 3, 0>{};

#if 0
        const auto blockwise_in_copy_reorder =
            Blockwise4dTensorCopyReorder1<BlockSize,
                                          Float,
                                          decltype(in_n_c_h_w_global_desc),
                                          decltype(in_c_h_w_n_block_desc),
                                          Sequence<NPerBlock, CPerBlock, HoPerBlock, WiPerBlock>,
                                          decltype(map_chwn2nchw)>{};
#else
        auto map_thread_cluster_2_src_cluster = Sequence<1, 2, 0, 3>{};

        const auto blockwise_in_copy_reorder =
            Blockwise4dTensorCopyReorder3<BlockSize,
                                          Float,
                                          decltype(in_n_c_h_w_global_desc),
                                          decltype(in_c_h_w_n_block_desc),
                                          Sequence<NPerBlock, CPerBlock, HoPerBlock, WiPerBlock>,
                                          Sequence<4, 1, 1, 2>,
                                          Sequence<4, 8, 2, 2>,
                                          decltype(map_chwn2nchw),
                                          decltype(map_thread_cluster_2_src_cluster),
                                          2,
                                          4>{};

#if 0
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
        {
            printf("size %u\n", blockwise_in_copy_reorder.GetRegisterClipboardSize());
        }
#endif
#endif

        // blockwise wei copy
        // format is [CPerBlock, X * KPerBlock]
        const auto blockwise_wei_copy =
#if 0
            Blockwise3dTensorCopy1<BlockSize,
                                   Float,
                                   decltype(wei_c_x_k_global_desc),
                                   decltype(wei_c_x_k_block_desc),
                                   decltype(wei_c_x_k_block_desc.GetLengths()),
                                   WeiBlockCopyDataPerRead>{};
#else
            Blockwise3dTensorCopy3<BlockSize,
                                   Float,
                                   decltype(wei_c_x_k_global_desc),
                                   decltype(wei_c_x_k_block_desc),
                                   decltype(wei_c_x_k_block_desc.GetLengths()),
                                   Sequence<4, 1, 32>,
                                   WeiBlockCopyDataPerRead>{};
#endif
// a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
// A_matrix[C,K] is a sub-matrix of wei_block[C,K]
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
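        // note (hedged reading of the comment above, sizes for illustration only): per x-position
        // each workgroup performs HoPerBlock batched GEMMs of shape
        // [KPerBlock x CPerBlock]^T times [CPerBlock x WoPerBlock*NPerBlock]; e.g. the 3x3/28x28
        // v1r2 driver configuration (KPerBlock=128, CPerBlock=8, HoPerBlock=2, WoPerBlock=2,
        // NPerBlock=16) would give 2 GEMMs of 128x8 times 8x32 accumulated into registers.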
constexpr
auto
a_c_k_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
CPerBlock
>
{},
Number
<
KPerBlock
>
{},
Number
<
wei_c_x_k_block_desc
.
GetStride
(
I0
)
>
{});
constexpr
auto
b_c_wn_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
CPerBlock
>
{},
Number
<
WoPerBlock
*
NPerBlock
>
{},
Number
<
in_c_h_w_n_block_desc
.
GetStride
(
I0
)
>
{});
constexpr
auto
c_k_wn_thread_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThread
>
{},
Number
<
WoPerThread
*
NPerThread
>
{},
Number
<
out_k_h_w_n_thread_desc
.
GetStride
(
I0
)
>
{});
        const auto blockwise_batch_gemm =
            BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
                BlockSize,
                decltype(a_c_k_block_mtx_desc),
                decltype(b_c_wn_block_mtx_desc),
                decltype(c_k_wn_thread_mtx_desc),
                0,
                in_c_h_w_n_block_desc.GetStride(I1),
                out_k_h_w_n_thread_desc.GetStride(I1),
                HoPerBlock,
                GemmMPerThreadSubC,
                GemmNPerThreadSubC,
                GemmMLevel0Cluster,
                GemmNLevel0Cluster,
                GemmMLevel1Cluster,
                GemmNLevel1Cluster,
                GemmKPerThreadLoop,
                HoPerThread,
                GemmDataPerReadA,
                GemmDataPerReadB>{};
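        // Note: as a rough reference only (not the actual per-thread decomposition), each Run()
        // of this batched GEMM accumulates, with Ho acting as the batch dimension:
        //
        //   for(ho = 0; ho < HoPerBlock; ++ho)
        //       for(c = 0; c < CPerBlock; ++c)
        //           for(k = 0; k < KPerBlock; ++k)
        //               for(b = 0; b < WoPerBlock * NPerBlock; ++b)
        //                   C[ho][k][b] += A[c][k] * B[ho][c][b];
        //
        // where b indexes the merged Wo*N column described in the comment above.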
        // LDS: be careful of alignment
        constexpr index_t in_block_space =
            in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
        constexpr index_t wei_block_space =
            wei_c_x_k_block_desc.GetElementSpace(Number<max_align>{});

        __shared__ Float p_in_block[in_block_space];
        __shared__ Float p_wei_block[wei_block_space];

        // register
        Float p_out_thread[out_k_h_w_n_thread_desc.GetElementSpace()];
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc");
print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");
print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
print_ConstantTensorDescriptor(wei_c_x_k_block_desc, "wei_c_x_k_block_desc");
printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
}
#endif
// set threadwise output tensor to 0
        threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);

        const Float* p_in_global_block_offset =
            p_in_global + in_n_c_h_w_global_desc.Get1dIndex(
                              n_block_data_begin, 0, hi_block_data_begin, wi_block_data_begin);

        const Float* p_wei_global_block_offset =
            p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);

        for(index_t c_block_data_begin = 0; c_block_data_begin < C;
            c_block_data_begin += CPerBlock,
            p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
            p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
        {
            for(index_t y = 0; y < Y; ++y)
            {
                blockwise_in_copy_reorder.Run(
                    p_in_global_block_offset + in_n_c_h_w_global_desc.Get1dIndex(0, 0, y, 0),
                    p_in_block);

                blockwise_wei_copy.Run(
                    p_wei_global_block_offset + wei_c_y_x_k_global_desc.Get1dIndex(0, y, 0, 0),
                    p_wei_block);

                __syncthreads();

                for(index_t x = 0; x < X; ++x)
                {
                    blockwise_batch_gemm.Run(
                        p_wei_block + wei_c_x_k_block_desc.Get1dIndex(0, x, 0),
                        p_in_block + in_c_h_w_n_block_desc.Get1dIndex(0, 0, x, 0),
                        p_out_thread);
                }

                __syncthreads();
            }
        }
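        // Note on the loop above: each (c_block, y) iteration stages one input row window and
        // one weight row into LDS, and the inner x loop reuses those LDS tiles by offsetting
        // into them, so the convolution becomes X batched GEMMs per staged tile; the
        // __syncthreads() calls fence LDS reuse between iterations.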
        // output: copy from register to global memory
#if 0
const auto c_thread_mtx_begin =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
for(index_t k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k)
{
for(index_t ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho)
{
for(index_t wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo)
{
for(index_t n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n)
{
const index_t b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n);
const auto c_thread_mtx_distance =
blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b);
const index_t ho_thread =
c_thread_mtx_begin.batch + c_thread_mtx_distance.batch;
const index_t k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row;
const index_t b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col;
const index_t wo_thread = b_thread / NPerBlock;
const index_t n_thread = b_thread % NPerBlock;
p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread,
ho_block_data_begin + ho_thread,
wo_block_data_begin + wo_thread,
n_block_data_begin + n_thread)] =
p_out_thread[out_khwn_thread_desc.Get1dIndex(k, ho, wo, n)];
}
}
}
}
#elif 1
        const auto c_thread_mtx_begin =
            blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        const index_t k_thread_data_begin  = c_thread_mtx_begin.row;
        const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
        const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
        const index_t n_thread_data_begin =
            c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin;

        // output is a 10d tensor
        constexpr index_t N2 = GemmNPerThreadSubC;
        constexpr index_t N1 = NPerBlock / N2;

        constexpr index_t W2 =
            (GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC);
        constexpr index_t W1 = WoPerBlock / W2;

        constexpr index_t K2 = GemmMPerThreadSubC;
        constexpr index_t K1 = KPerBlock / KPerThread;
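        // Worked example (hypothetical parameter values, for illustration only): with
        // NPerBlock = 32, GemmNPerThreadSubC = 4, GemmNLevel0Cluster * GemmNLevel1Cluster = 16,
        // WoPerBlock = 4, KPerBlock = 64, KPerThread = 8 and GemmMPerThreadSubC = 4:
        //   N2 = 4, N1 = 32 / 4 = 8,
        //   W2 = 16 / (32 / 4) = 2, W1 = 4 / 2 = 2,
        //   K2 = 4, K1 = 64 / 8 = 8,
        // i.e. K, Wo and N are each split into smaller factors so the 10d views below can
        // address each thread's slice of the output directly.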
        constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
            Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2), W1, W2, N / (N1 * N2), N1, N2>{});

        constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
            Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
}
#endif
        threadwise_10d_tensor_copy(
            out_10d_thread_desc,
            p_out_thread,
            out_10d_global_desc,
            p_out_global +
                out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
                                                   ho_block_data_begin + ho_thread_data_begin,
                                                   wo_block_data_begin + wo_thread_data_begin,
                                                   n_block_data_begin + n_thread_data_begin),
            out_10d_thread_desc.GetLengths(),
            Number<OutThreadCopyDataPerWrite>{});
#endif
    }
};
src/include/threadwise_2d_tensor_op.hip.hpp  View file @ 19f17df4
@@ -29,26 +29,21 @@ __device__ void threadwise_2d_tensor_pointwise_operation_unary(Desc, Float* __re
 // TODO: in order to optimize mem access for different mem type,
 // need to write specialized version
-template <class Float,
-          class SrcDesc, class DstDesc, class SrcOpLengths, class DstFromSrcReorder, class F>
+template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, class MapDst2Src, class F>
 __device__ void threadwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
     SrcDesc,
     Float* const __restrict__ p_src,
     DstDesc,
     Float* __restrict__ p_dst,
     SrcOpLengths,
-    DstFromSrcReorder,
+    MapDst2Src,
     F f)
 {
     constexpr auto I0 = Number<0>{};
     constexpr auto I1 = Number<1>{};

-    constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0);
-    constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1);
+    constexpr index_t IR0 = MapDst2Src{}.Get(I0);
+    constexpr index_t IR1 = MapDst2Src{}.Get(I1);

     constexpr auto src_desc = SrcDesc{};
     constexpr auto dst_desc = DstDesc{};
@@ -78,19 +73,19 @@ __device__ void threadwise_2d_tensor_set_zero(Desc, Float* __restrict__ p)
         Desc{}, p, f_set_zero);
 }

-template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, class DstFromSrcReorder>
+template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, class MapDst2Src>
 __device__ void
 threadwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
                                                       Float* const __restrict__ p_src,
                                                       DstDesc,
                                                       Float* __restrict__ p_dst,
                                                       SrcOpLengths,
-                                                      DstFromSrcReorder)
+                                                      MapDst2Src)
 {
     auto f_copy = [](const Float& src, Float& dst) { dst = src; };

     threadwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
-        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
+        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
 }

 template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
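The renaming of DstFromSrcReorder to MapDst2Src throughout these helpers matches how the
sequence is used: entry i names the source dimension whose loop index becomes destination
dimension i (see the dst_desc.Get1dIndex(did[IR0], ...) call in the 4d hunk below). A minimal
host-side sketch of that convention, using hypothetical names rather than this repo's API:

#include <cstdio>

// dst dimension i takes its index from source dimension Map_i
template <int Map0, int Map1>
void copy_2d_reorder(const float* src, int len0, int len1, float* dst)
{
    const int dlen1 = (Map1 == 0) ? len0 : len1; // length of dst dimension 1

    for(int i0 = 0; i0 < len0; ++i0)
        for(int i1 = 0; i1 < len1; ++i1)
        {
            const int did[2] = {i0, i1};
            dst[did[Map0] * dlen1 + did[Map1]] = src[i0 * len1 + i1];
        }
}

int main()
{
    const float src[2 * 3] = {0, 1, 2, 3, 4, 5}; // 2x3, row-major
    float dst[3 * 2]       = {};
    copy_2d_reorder<1, 0>(src, 2, 3, dst);        // dst becomes the 3x2 transpose
    printf("%g %g %g\n", dst[0], dst[1], dst[2]); // prints 0 3 1
    return 0;
}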
src/include/threadwise_4d_tensor_op.hip.hpp  View file @ 19f17df4
@@ -42,7 +42,7 @@ template <class SrcData,
           class SrcDesc,
           class DstDesc,
           class SrcOpLengths,
-          class DstFromSrcReorder,
+          class MapDst2Src,
           class F>
 __device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
     SrcDesc,
@@ -50,7 +50,7 @@ __device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_d
     DstDesc,
     DstData* __restrict__ p_dst,
     SrcOpLengths,
-    DstFromSrcReorder,
+    MapDst2Src,
     F f)
 {
     constexpr auto I0 = Number<0>{};
@@ -58,10 +58,10 @@ __device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_d
     constexpr auto I2 = Number<2>{};
     constexpr auto I3 = Number<3>{};

-    constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0);
-    constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1);
-    constexpr index_t IR2 = DstFromSrcReorder{}.Get(I2);
-    constexpr index_t IR3 = DstFromSrcReorder{}.Get(I3);
+    constexpr index_t IR0 = MapDst2Src{}.Get(I0);
+    constexpr index_t IR1 = MapDst2Src{}.Get(I1);
+    constexpr index_t IR2 = MapDst2Src{}.Get(I2);
+    constexpr index_t IR3 = MapDst2Src{}.Get(I3);

     constexpr auto src_desc = SrcDesc{};
     constexpr auto dst_desc = DstDesc{};
@@ -82,7 +82,29 @@ __device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_d
                     const index_t bindex =
                         dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);

+#if 1
                     f(p_src[aindex], p_dst[bindex]);
+#else
+                    if(get_block_1d_id() == 0)
+                    {
+                        printf("tid %5u, "
+                               "src did %u %u %u %u, "
+                               "dst did %u %u %u %u, "
+                               "aindex %5u, "
+                               "bindex %5u\n",
+                               get_thread_local_1d_id(),
+                               did0,
+                               did1,
+                               did2,
+                               did3,
+                               did[IR0],
+                               did[IR1],
+                               did[IR2],
+                               did[IR3],
+                               aindex,
+                               bindex);
+                    }
+#endif
                 }
             }
         }
@@ -103,19 +125,19 @@ template <class SrcData,
           class SrcDesc,
           class DstDesc,
           class SrcOpLengths,
-          class DstFromSrcReorder>
+          class MapDst2Src>
 __device__ void
 threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
                                                       const SrcData* __restrict__ p_src,
                                                       DstDesc,
                                                       DstData* __restrict__ p_dst,
                                                       SrcOpLengths,
-                                                      DstFromSrcReorder)
+                                                      MapDst2Src)
 {
     auto f_copy = [](const SrcData& src, DstData& dst) { dst = static_cast<DstData>(src); };

     threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
-        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
+        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
 }

 template <class SrcData, class DstData, class SrcDesc, class DstDesc, class SrcOpLengths>
@@ -137,13 +159,12 @@ __device__ void threadwise_4d_tensor_copy_v2(SrcDesc,
                                              SrcOpLengths,
                                              Number<DataPerRead>)
 {
-    using Float2 = float2;
-    using Float4 = float4;
-
     static_assert(SrcDesc{}.GetDimension() == 4 && DstDesc{}.GetDimension() == 4 &&
-                      SrcOpLengths::nDim == 4,
+                      SrcOpLengths::GetSize() == 4,
                   "wrong! should be 4 dimension");

+    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
+
     constexpr auto I0 = Number<0>{};
     constexpr auto I1 = Number<1>{};
     constexpr auto I2 = Number<2>{};
@@ -183,24 +204,8 @@ __device__ void threadwise_4d_tensor_copy_v2(SrcDesc,
             const index_t dst_index =
                 dst_desc.Get1dIndex(did0, did1, did2, iloop_d3 * DataPerRead);

-            if(DataPerRead == 1)
-            {
-                p_dst[dst_index] = p_src[src_index];
-            }
-            else if(DataPerRead == 2)
-            {
-                *(reinterpret_cast<Float2*>(p_dst + dst_index)) =
-                    *(reinterpret_cast<const Float2*>(p_src + src_index));
-            }
-            else if(DataPerRead == 4)
-            {
-                *(reinterpret_cast<Float4*>(p_dst + dst_index)) =
-                    *(reinterpret_cast<const Float4*>(p_src + src_index));
-            }
-            else
-            {
-                assert(false);
-            }
+            *(reinterpret_cast<vector_t*>(&p_dst[dst_index])) =
+                *(reinterpret_cast<const vector_t*>(&p_src[src_index]));
         }
     }
 }
src/include/threadwise_nd_tensor_op.hip.hpp  View file @ 19f17df4
@@ -175,7 +175,7 @@ __device__ void threadwise_10d_tensor_copy(SrcDesc,
     using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

     static_assert(SrcDesc{}.GetDimension() == 10 && DstDesc{}.GetDimension() == 10 &&
-                      SrcOpLengths::nDim == 10,
+                      SrcOpLengths::GetSize() == 10,
                   "wrong! should be 10 dimension");

     constexpr auto I0 = Number<0>{};