gaoqiong / composable_kernel · Commits

Commit 766b0a9e, authored Mar 24, 2019 by Chao Liu

    experimenting

parent f35c64eb

Changes: 33. Showing 20 changed files with 1413 additions and 1376 deletions (+1413 −1376).
driver/device_direct_convolution_1.hpp                                +14   −14
driver/device_direct_convolution_2_nchw_kcyx_nkhw.hpp                 +35   −35
driver/device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp      +80   −80
driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp          +193  −193
driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp   +154  −154
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp          +128  −98
driver/driver.hip.cpp                                                 +185  −185
src/include/Array.hip.hpp                                             +4    −4
src/include/ConstantMatrixDescriptor.hip.hpp                          +14   −10
src/include/ConstantTensorDescriptor.hip.hpp                          +40   −36
src/include/Sequence.hip.hpp                                          +17   −17
src/include/blockwise_2d_tensor_op.hip.hpp                            +126  −128
src/include/blockwise_4d_tensor_op.hip.hpp                            +105  −105
src/include/blockwise_batched_gemm.hip.hpp                            +150  −150
src/include/blockwise_direct_convolution.hip.hpp                      +28   −28
src/include/blockwise_gemm.hip.hpp                                    +132  −133
src/include/common.hip.hpp                                            +3    −3
src/include/config.h.in                                               +2    −0
src/include/constant_integral.hip.hpp                                 +2    −2
src/include/data_type.hip.hpp                                         +1    −1
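Every hunk below makes the same substitution: integer tile sizes, loop counters, and dimension constants move from unsigned (or unsigned long) to the project's index_t type. The definition of index_t is not part of this page (the +2/−0 change to src/include/config.h.in may be where the alias is introduced, but that file is not shown), so the following is only a hedged sketch of an alias that would let the rewritten lines compile; the real type in composable_kernel may differ, for example it could be signed or 64-bit.

    #include <cstdint>

    // Assumption: index_t is an unsigned 32-bit alias. Placeholder only; not taken from this commit.
    using index_t = uint32_t;

    // The commit's recurring rewrite, on one representative line:
    // before: constexpr unsigned NPerBlock = 16;
    constexpr index_t NPerBlock = 16; // after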
driver/device_direct_convolution_1.hpp

@@ -10,7 +10,7 @@ void device_direct_convolution_1(InDesc,
                                  const Tensor<T>& wei,
                                  OutDesc,
                                  Tensor<T>& out,
-                                 unsigned nrepeat)
+                                 index_t nrepeat)
 {
     std::size_t data_sz = sizeof(T);
     DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace());
...
@@ -34,28 +34,28 @@ void device_direct_convolution_1(InDesc,
 #if 1
     // 3x3, 34x34
-    constexpr unsigned NPerBlock  = 2;
-    constexpr unsigned KPerBlock  = 16;
-    constexpr unsigned CPerBlock  = 2;
-    constexpr unsigned HoPerBlock = 4;
-    constexpr unsigned WoPerBlock = 32;
+    constexpr index_t NPerBlock  = 2;
+    constexpr index_t KPerBlock  = 16;
+    constexpr index_t CPerBlock  = 2;
+    constexpr index_t HoPerBlock = 4;
+    constexpr index_t WoPerBlock = 32;

-    constexpr unsigned NPerThread  = 2;
-    constexpr unsigned KPerThread  = 4;
-    constexpr unsigned CPerThread  = 2;
-    constexpr unsigned HoPerThread = 2;
-    constexpr unsigned WoPerThread = 2;
+    constexpr index_t NPerThread  = 2;
+    constexpr index_t KPerThread  = 4;
+    constexpr index_t CPerThread  = 2;
+    constexpr index_t HoPerThread = 2;
+    constexpr index_t WoPerThread = 2;

-    constexpr unsigned BlockSize = 128;
+    constexpr index_t BlockSize = 128;
 #endif

-    constexpr unsigned GridSize =
+    constexpr index_t GridSize =
         (out_desc.GetLength(I0) / NPerBlock) * (out_desc.GetLength(I1) / KPerBlock) *
         (out_desc.GetLength(I2) / HoPerBlock) * (out_desc.GetLength(I3) / WoPerBlock);

     printf("%s: BlockSize %u, GridSize %u\n", __func__, BlockSize, GridSize);

-    for(unsigned i = 0; i < nrepeat; ++i)
+    for(index_t i = 0; i < nrepeat; ++i)
     {
         float time = launch_kernel(gridwise_direct_convolution_1<T,
                                                                  InDesc,
...
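A side note on the GridSize expression above: it simply counts how many (NPerBlock x KPerBlock x HoPerBlock x WoPerBlock) tiles cover the output tensor, assuming each dimension divides evenly. A minimal standalone sketch follows; the output lengths N, K, Ho, Wo are hypothetical here (the real driver reads them from OutDesc), while the tile sizes are the "3x3, 34x34" configuration from this file.

    #include <cstdio>
    #include <cstdint>

    using index_t = uint32_t; // assumption, see the note after the file list

    int main()
    {
        // Hypothetical output tensor lengths (not taken from the commit).
        constexpr index_t N = 64, K = 64, Ho = 32, Wo = 32;

        // Tile sizes from the "3x3, 34x34" configuration above.
        constexpr index_t NPerBlock = 2, KPerBlock = 16, HoPerBlock = 4, WoPerBlock = 32;

        // Same formula as the driver: one block per output tile.
        constexpr index_t GridSize =
            (N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);

        printf("GridSize = %u\n", GridSize); // 32 * 4 * 8 * 1 = 1024
        return 0;
    }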
driver/device_direct_convolution_2_nchw_kcyx_nkhw.hpp

@@ -10,7 +10,7 @@ void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
                                                 const Tensor<T>& wei,
                                                 OutDesc,
                                                 Tensor<T>& out,
-                                                unsigned nrepeat)
+                                                index_t nrepeat)
 {
     std::size_t data_sz = sizeof(T);
     DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace());
...
@@ -34,49 +34,49 @@ void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
 #if 1
     // 3x3, 34x34, 128 thread
-    constexpr unsigned NPerBlock  = 2;
-    constexpr unsigned KPerBlock  = 32;
-    constexpr unsigned CPerBlock  = 4;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 32;
-    constexpr unsigned NPerThread  = 2;
-    constexpr unsigned KPerThread  = 4;
-    constexpr unsigned CPerThread  = 2;
-    constexpr unsigned HoPerThread = 2;
-    constexpr unsigned WoPerThread = 2;
-    constexpr unsigned InBlockCopyDataPerRead  = 2;
-    constexpr unsigned WeiBlockCopyDataPerRead = 4;
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t NPerBlock  = 2;
+    constexpr index_t KPerBlock  = 32;
+    constexpr index_t CPerBlock  = 4;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 32;
+    constexpr index_t NPerThread  = 2;
+    constexpr index_t KPerThread  = 4;
+    constexpr index_t CPerThread  = 2;
+    constexpr index_t HoPerThread = 2;
+    constexpr index_t WoPerThread = 2;
+    constexpr index_t InBlockCopyDataPerRead  = 2;
+    constexpr index_t WeiBlockCopyDataPerRead = 4;
+    constexpr index_t BlockSize = 128;
 #elif 1
     // 3x3, 34x34, 128 thread, fp16
-    constexpr unsigned NPerBlock  = 2;
-    constexpr unsigned KPerBlock  = 32;
-    constexpr unsigned CPerBlock  = 4;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 32;
-    constexpr unsigned NPerThread  = 2;
-    constexpr unsigned KPerThread  = 4;
-    constexpr unsigned CPerThread  = 2;
-    constexpr unsigned HoPerThread = 2;
-    constexpr unsigned WoPerThread = 2;
-    constexpr unsigned InBlockCopyDataPerRead  = 2;
-    constexpr unsigned WeiBlockCopyDataPerRead = 4;
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t NPerBlock  = 2;
+    constexpr index_t KPerBlock  = 32;
+    constexpr index_t CPerBlock  = 4;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 32;
+    constexpr index_t NPerThread  = 2;
+    constexpr index_t KPerThread  = 4;
+    constexpr index_t CPerThread  = 2;
+    constexpr index_t HoPerThread = 2;
+    constexpr index_t WoPerThread = 2;
+    constexpr index_t InBlockCopyDataPerRead  = 2;
+    constexpr index_t WeiBlockCopyDataPerRead = 4;
+    constexpr index_t BlockSize = 128;
 #endif

-    constexpr unsigned GridSize =
+    constexpr index_t GridSize =
         (out_desc.GetLength(I0) / NPerBlock) * (out_desc.GetLength(I1) / KPerBlock) *
         (out_desc.GetLength(I2) / HoPerBlock) * (out_desc.GetLength(I3) / WoPerBlock);

     printf("%s: BlockSize %u, GridSize %u\n", __func__, BlockSize, GridSize);

-    for(unsigned i = 0; i < nrepeat; ++i)
+    for(index_t i = 0; i < nrepeat; ++i)
     {
         float time = launch_kernel(gridwise_direct_convolution_2_nchw_kcyx_nkhw<T,
...
driver/device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp

@@ -10,10 +10,10 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
                                                            const Tensor<TInWei>& wei_kcyx,
                                                            OutDesc,
                                                            Tensor<TOut>& out_nkhw,
-                                                           unsigned nrepeat)
+                                                           index_t nrepeat)
 {
     // this suppose in / wei data type is int8x4
-    constexpr unsigned NVector = 4;
+    constexpr index_t NVector = 4;

     using accum_t      = int32_t;
     using vector_t     = vector_type<TInWei, NVector>;
     using vector_mem_t = typename vector_t::MemoryType;
...
@@ -27,17 +27,17 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
     constexpr auto wei_kcyx_desc = WeiDesc{};
     constexpr auto out_nkhw_desc = OutDesc{};

-    constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
-    constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
+    constexpr index_t Hi = in_nchw_desc.GetLength(I2);
+    constexpr index_t Wi = in_nchw_desc.GetLength(I3);

-    constexpr unsigned N  = out_nkhw_desc.GetLength(I0);
-    constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
-    constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
+    constexpr index_t N  = out_nkhw_desc.GetLength(I0);
+    constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
+    constexpr index_t Wo = out_nkhw_desc.GetLength(I3);

-    constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
-    constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
-    constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
-    constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
+    constexpr index_t K = wei_kcyx_desc.GetLength(I0);
+    constexpr index_t C = wei_kcyx_desc.GetLength(I1);
+    constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
+    constexpr index_t X = wei_kcyx_desc.GetLength(I3);

     // vectorized input
     auto in_nchw_vec_desc = make_ConstantTensorDescriptor(Sequence<N, C / NVector, Hi, Wi>{});
...
@@ -96,84 +96,84 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
 #if 0
     // 3x3, 34x34, 128 thread, fp32, vector = 1
-    constexpr unsigned NPerBlock  = 2;
-    constexpr unsigned KPerBlock  = 32;
-    constexpr unsigned CPerBlock  = 4;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 32;
-    constexpr unsigned NPerThread  = 2;
-    constexpr unsigned KPerThread  = 4;
-    constexpr unsigned CPerThread  = 2;
-    constexpr unsigned HoPerThread = 2;
-    constexpr unsigned WoPerThread = 2;
-    constexpr unsigned InBlockCopyDataPerRead  = 2;
-    constexpr unsigned WeiBlockCopyDataPerRead = 2;
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t NPerBlock  = 2;
+    constexpr index_t KPerBlock  = 32;
+    constexpr index_t CPerBlock  = 4;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 32;
+    constexpr index_t NPerThread  = 2;
+    constexpr index_t KPerThread  = 4;
+    constexpr index_t CPerThread  = 2;
+    constexpr index_t HoPerThread = 2;
+    constexpr index_t WoPerThread = 2;
+    constexpr index_t InBlockCopyDataPerRead  = 2;
+    constexpr index_t WeiBlockCopyDataPerRead = 2;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // 3x3, 34x34, 128 thread, fp32, vector = 2
-    constexpr unsigned NPerBlock  = 2;
-    constexpr unsigned KPerBlock  = 32;
-    constexpr unsigned CPerBlock  = 2;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 32;
-    constexpr unsigned NPerThread  = 2;
-    constexpr unsigned KPerThread  = 4;
-    constexpr unsigned CPerThread  = 1;
-    constexpr unsigned HoPerThread = 2;
-    constexpr unsigned WoPerThread = 2;
-    constexpr unsigned InBlockCopyDataPerRead  = 2;
-    constexpr unsigned WeiBlockCopyDataPerRead = 2;
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t NPerBlock  = 2;
+    constexpr index_t KPerBlock  = 32;
+    constexpr index_t CPerBlock  = 2;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 32;
+    constexpr index_t NPerThread  = 2;
+    constexpr index_t KPerThread  = 4;
+    constexpr index_t CPerThread  = 1;
+    constexpr index_t HoPerThread = 2;
+    constexpr index_t WoPerThread = 2;
+    constexpr index_t InBlockCopyDataPerRead  = 2;
+    constexpr index_t WeiBlockCopyDataPerRead = 2;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // 3x3, 34x34, 128 thread, int8, vector = 4
-    constexpr unsigned NPerBlock  = 2;
-    constexpr unsigned KPerBlock  = 32;
-    constexpr unsigned CPerBlock  = 8;
-    constexpr unsigned HoPerBlock = 4;
-    constexpr unsigned WoPerBlock = 32;
-    constexpr unsigned NPerThread  = 1;
-    constexpr unsigned KPerThread  = 8;
-    constexpr unsigned CPerThread  = 2;
-    constexpr unsigned HoPerThread = 4;
-    constexpr unsigned WoPerThread = 2;
-    constexpr unsigned InBlockCopyDataPerRead  = 2;
-    constexpr unsigned WeiBlockCopyDataPerRead = 2;
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t NPerBlock  = 2;
+    constexpr index_t KPerBlock  = 32;
+    constexpr index_t CPerBlock  = 8;
+    constexpr index_t HoPerBlock = 4;
+    constexpr index_t WoPerBlock = 32;
+    constexpr index_t NPerThread  = 1;
+    constexpr index_t KPerThread  = 8;
+    constexpr index_t CPerThread  = 2;
+    constexpr index_t HoPerThread = 4;
+    constexpr index_t WoPerThread = 2;
+    constexpr index_t InBlockCopyDataPerRead  = 2;
+    constexpr index_t WeiBlockCopyDataPerRead = 2;
+    constexpr index_t BlockSize = 128;
 #elif 1
     // 1x1, 32x32, 128 thread, int8, vector = 4
-    constexpr unsigned NPerBlock  = 1;
-    constexpr unsigned KPerBlock  = 64;
-    constexpr unsigned CPerBlock  = 16;
-    constexpr unsigned HoPerBlock = 4;
-    constexpr unsigned WoPerBlock = 32;
-    constexpr unsigned NPerThread  = 1;
-    constexpr unsigned KPerThread  = 8;
-    constexpr unsigned CPerThread  = 2;
-    constexpr unsigned HoPerThread = 4;
-    constexpr unsigned WoPerThread = 2;
-    constexpr unsigned InBlockCopyDataPerRead  = 2;
-    constexpr unsigned WeiBlockCopyDataPerRead = 2;
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t NPerBlock  = 1;
+    constexpr index_t KPerBlock  = 64;
+    constexpr index_t CPerBlock  = 16;
+    constexpr index_t HoPerBlock = 4;
+    constexpr index_t WoPerBlock = 32;
+    constexpr index_t NPerThread  = 1;
+    constexpr index_t KPerThread  = 8;
+    constexpr index_t CPerThread  = 2;
+    constexpr index_t HoPerThread = 4;
+    constexpr index_t WoPerThread = 2;
+    constexpr index_t InBlockCopyDataPerRead  = 2;
+    constexpr index_t WeiBlockCopyDataPerRead = 2;
+    constexpr index_t BlockSize = 128;
 #endif

-    constexpr unsigned GridSize =
+    constexpr index_t GridSize =
         (N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);

     printf("%s: BlockSize %u, GridSize %u\n", __func__, BlockSize, GridSize);

-    for(unsigned i = 0; i < nrepeat; ++i)
+    for(index_t i = 0; i < nrepeat; ++i)
     {
         float time = launch_kernel(gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw<TInWei,
...
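The vectorized driver above packs NVector = 4 int8 values per element, so the input descriptor is built with C / NVector channels of the vector type (Sequence<N, C / NVector, Hi, Wi>). Below is a small sketch of that packing arithmetic only; the vector_type and descriptor machinery live in the library and are not reproduced here, and the NCHW shape used is hypothetical.

    #include <cstdint>
    #include <cstdio>

    using index_t = uint32_t; // assumption

    int main()
    {
        // int8x4 packing as in the vectorized driver: 4 input/weight values per element.
        constexpr index_t NVector = 4;

        // Hypothetical NCHW input shape (the real one comes from InDesc).
        constexpr index_t N = 2, C = 16, Hi = 34, Wi = 34;

        // The channel dimension of the vectorized descriptor shrinks by the vector width.
        constexpr index_t C_vec = C / NVector;

        printf("vectorized shape: %u x %u x %u x %u\n", N, C_vec, Hi, Wi); // 2 x 4 x 34 x 34
        return 0;
    }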
driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp

@@ -10,7 +10,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
                                                        const Tensor<T>& wei_kcyx,
                                                        OutDesc,
                                                        Tensor<T>& out_nkhw,
-                                                       unsigned nrepeat)
+                                                       index_t nrepeat)
 {
     constexpr auto I0 = Number<0>{};
     constexpr auto I1 = Number<1>{};
...
@@ -21,17 +21,17 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
     constexpr auto wei_kcyx_desc = WeiDesc{};
     constexpr auto out_nkhw_desc = OutDesc{};

-    constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
-    constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
+    constexpr index_t Hi = in_nchw_desc.GetLength(I2);
+    constexpr index_t Wi = in_nchw_desc.GetLength(I3);

-    constexpr unsigned N  = out_nkhw_desc.GetLength(I0);
-    constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
-    constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
+    constexpr index_t N  = out_nkhw_desc.GetLength(I0);
+    constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
+    constexpr index_t Wo = out_nkhw_desc.GetLength(I3);

-    constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
-    constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
-    constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
-    constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
+    constexpr index_t K = wei_kcyx_desc.GetLength(I0);
+    constexpr index_t C = wei_kcyx_desc.GetLength(I1);
+    constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
+    constexpr index_t X = wei_kcyx_desc.GetLength(I3);

     // reorder weight
     auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
...
@@ -76,218 +76,218 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
 #if 0
     // for 3x3, 34x34
-    constexpr unsigned NPerBlock  = 16;
-    constexpr unsigned KPerBlock  = 64;
-    constexpr unsigned CPerBlock  = 4;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 4;
-    constexpr unsigned NPerThread  = 8;
-    constexpr unsigned KPerThread  = 8;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
-    constexpr unsigned InBlockCopy_ThreadPerDimC = 4;
-    constexpr unsigned InBlockCopy_ThreadPerDimH = 4;
-    constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
-    constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
-    constexpr unsigned InBlockCopyDataPerRead  = 4;
-    constexpr unsigned WeiBlockCopyDataPerRead = 4;
-    constexpr unsigned GemmMPerThreadSubC = 4;
-    constexpr unsigned GemmNPerThreadSubC = 4;
-    constexpr unsigned GemmMLevel0Cluster = 4;
-    constexpr unsigned GemmNLevel0Cluster = 2;
-    constexpr unsigned GemmMLevel1Cluster = 2;
-    constexpr unsigned GemmNLevel1Cluster = 4;
-    constexpr unsigned GemmKPerThreadLoop = 1;
-    constexpr unsigned OutThreadCopyDataPerWrite = 2;
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t NPerBlock  = 16;
+    constexpr index_t KPerBlock  = 64;
+    constexpr index_t CPerBlock  = 4;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 4;
+    constexpr index_t NPerThread  = 8;
+    constexpr index_t KPerThread  = 8;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
+    constexpr index_t InBlockCopy_ThreadPerDimC = 4;
+    constexpr index_t InBlockCopy_ThreadPerDimH = 4;
+    constexpr index_t InBlockCopy_ThreadPerDimW = 2;
+    constexpr index_t InBlockCopy_ThreadPerDimN = 4;
+    constexpr index_t InBlockCopyDataPerRead  = 4;
+    constexpr index_t WeiBlockCopyDataPerRead = 4;
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 2;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t OutThreadCopyDataPerWrite = 2;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // for 5x5, 36x36
-    constexpr unsigned NPerBlock  = 16;
-    constexpr unsigned KPerBlock  = 64;
-    constexpr unsigned CPerBlock  = 2;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 4;
-    constexpr unsigned NPerThread  = 8;
-    constexpr unsigned KPerThread  = 8;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
-    constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-    constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
-    constexpr unsigned InBlockCopy_ThreadPerDimC = 2;
-    constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
-    constexpr unsigned InBlockCopy_ThreadPerDimW = 4;
-    constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
-    constexpr unsigned InBlockCopyDataPerRead  = 4;
-    constexpr unsigned WeiBlockCopyDataPerRead = 2;
-    constexpr unsigned GemmMPerThreadSubC = 4;
-    constexpr unsigned GemmNPerThreadSubC = 4;
-    constexpr unsigned GemmMLevel0Cluster = 4;
-    constexpr unsigned GemmNLevel0Cluster = 2;
-    constexpr unsigned GemmMLevel1Cluster = 2;
-    constexpr unsigned GemmNLevel1Cluster = 4;
-    constexpr unsigned GemmKPerThreadLoop = 1;
-    constexpr unsigned OutThreadCopyDataPerWrite = 2;
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t NPerBlock  = 16;
+    constexpr index_t KPerBlock  = 64;
+    constexpr index_t CPerBlock  = 2;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 4;
+    constexpr index_t NPerThread  = 8;
+    constexpr index_t KPerThread  = 8;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
+    constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+    constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
+    constexpr index_t InBlockCopy_ThreadPerDimC = 2;
+    constexpr index_t InBlockCopy_ThreadPerDimH = 2;
+    constexpr index_t InBlockCopy_ThreadPerDimW = 4;
+    constexpr index_t InBlockCopy_ThreadPerDimN = 4;
+    constexpr index_t InBlockCopyDataPerRead  = 4;
+    constexpr index_t WeiBlockCopyDataPerRead = 2;
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 2;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t OutThreadCopyDataPerWrite = 2;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // 3x3 58x58, NKC = 64, 64, 256
-    constexpr unsigned NPerBlock  = 16;
-    constexpr unsigned KPerBlock  = 64;
-    constexpr unsigned CPerBlock  = 4;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 4;
+    constexpr index_t NPerBlock  = 16;
+    constexpr index_t KPerBlock  = 64;
+    constexpr index_t CPerBlock  = 4;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 4;

-    constexpr unsigned NPerThread  = 4;
-    constexpr unsigned KPerThread  = 16;
-    constexpr unsigned CPerThread  = 1;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
+    constexpr index_t NPerThread  = 4;
+    constexpr index_t KPerThread  = 16;
+    constexpr index_t CPerThread  = 1;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;

-    constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-    constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
+    constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+    constexpr index_t WeiBlockCopyThreadPerDim1 = 32;

-    constexpr unsigned InBlockCopyDataPerRead  = 2; // not used, yet
-    constexpr unsigned WeiBlockCopyDataPerRead = 4;
+    constexpr index_t InBlockCopyDataPerRead  = 2; // not used, yet
+    constexpr index_t WeiBlockCopyDataPerRead = 4;

-    constexpr unsigned BlockSize = 128;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // 3x3 58x58, NKC = 16,256,128
-    constexpr unsigned NPerBlock  = 8;
-    constexpr unsigned KPerBlock  = 64;
-    constexpr unsigned CPerBlock  = 2;
-    constexpr unsigned HoPerBlock = 4;
-    constexpr unsigned WoPerBlock = 4;
-    constexpr unsigned NPerThread  = 4;
-    constexpr unsigned KPerThread  = 16;
-    constexpr unsigned CPerThread  = 1;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t NPerBlock  = 8;
+    constexpr index_t KPerBlock  = 64;
+    constexpr index_t CPerBlock  = 2;
+    constexpr index_t HoPerBlock = 4;
+    constexpr index_t WoPerBlock = 4;
+    constexpr index_t NPerThread  = 4;
+    constexpr index_t KPerThread  = 16;
+    constexpr index_t CPerThread  = 1;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // for 7x7, 38x38
-    constexpr unsigned NPerBlock  = 8;
-    constexpr unsigned KPerBlock  = 64;
-    constexpr unsigned CPerBlock  = 1;
-    constexpr unsigned HoPerBlock = 4;
-    constexpr unsigned WoPerBlock = 4;
+    constexpr index_t NPerBlock  = 8;
+    constexpr index_t KPerBlock  = 64;
+    constexpr index_t CPerBlock  = 1;
+    constexpr index_t HoPerBlock = 4;
+    constexpr index_t WoPerBlock = 4;

-    constexpr unsigned NPerThread  = 4;
-    constexpr unsigned KPerThread  = 16;
-    constexpr unsigned CPerThread  = 1;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
+    constexpr index_t NPerThread  = 4;
+    constexpr index_t KPerThread  = 16;
+    constexpr index_t CPerThread  = 1;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;

-    constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-    constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
+    constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+    constexpr index_t WeiBlockCopyThreadPerDim1 = 32;

-    constexpr unsigned InBlockCopyDataPerRead  = 4; // not used, yet
-    constexpr unsigned WeiBlockCopyDataPerRead = 4;
+    constexpr index_t InBlockCopyDataPerRead  = 4; // not used, yet
+    constexpr index_t WeiBlockCopyDataPerRead = 4;

-    constexpr unsigned BlockSize = 128;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // for 3x3, 56x56
-    constexpr unsigned NPerBlock  = 32;
-    constexpr unsigned KPerBlock  = 64;
-    constexpr unsigned CPerBlock  = 4;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 2;
-    constexpr unsigned NPerThread  = 4;
-    constexpr unsigned KPerThread  = 16;
-    constexpr unsigned CPerThread  = 1;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t NPerBlock  = 32;
+    constexpr index_t KPerBlock  = 64;
+    constexpr index_t CPerBlock  = 4;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 2;
+    constexpr index_t NPerThread  = 4;
+    constexpr index_t KPerThread  = 16;
+    constexpr index_t CPerThread  = 1;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // for 1x1, 28x28
-    constexpr unsigned NPerBlock  = 16;
-    constexpr unsigned KPerBlock  = 128;
-    constexpr unsigned CPerBlock  = 8;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 2;
-    constexpr unsigned NPerThread  = 4;
-    constexpr unsigned KPerThread  = 16;
-    constexpr unsigned CPerThread  = 1;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
-    constexpr unsigned InBlockCopy_ThreadPerDimC = 8;
-    constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
-    constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
-    constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
-    constexpr unsigned InBlockCopyDataPerRead  = 4;
-    constexpr unsigned WeiBlockCopyDataPerRead = 4;
-    constexpr unsigned GemmMPerThreadSubC = 4;
-    constexpr unsigned GemmNPerThreadSubC = 4;
-    constexpr unsigned GemmMLevel0Cluster = 4;
-    constexpr unsigned GemmNLevel0Cluster = 2;
-    constexpr unsigned GemmMLevel1Cluster = 2;
-    constexpr unsigned GemmNLevel1Cluster = 4;
-    constexpr unsigned GemmKPerThreadLoop = 1;
-    constexpr unsigned OutThreadCopyDataPerWrite = 2;
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t NPerBlock  = 16;
+    constexpr index_t KPerBlock  = 128;
+    constexpr index_t CPerBlock  = 8;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 2;
+    constexpr index_t NPerThread  = 4;
+    constexpr index_t KPerThread  = 16;
+    constexpr index_t CPerThread  = 1;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
+    constexpr index_t InBlockCopy_ThreadPerDimC = 8;
+    constexpr index_t InBlockCopy_ThreadPerDimH = 2;
+    constexpr index_t InBlockCopy_ThreadPerDimW = 2;
+    constexpr index_t InBlockCopy_ThreadPerDimN = 4;
+    constexpr index_t InBlockCopyDataPerRead  = 4;
+    constexpr index_t WeiBlockCopyDataPerRead = 4;
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 2;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t OutThreadCopyDataPerWrite = 2;
+    constexpr index_t BlockSize = 128;
 #elif 1
     // for 1x1, 14x14
-    constexpr unsigned NPerBlock  = 16;
-    constexpr unsigned KPerBlock  = 128;
-    constexpr unsigned CPerBlock  = 8;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 2;
-    constexpr unsigned NPerThread  = 4;
-    constexpr unsigned KPerThread  = 16;
-    constexpr unsigned CPerThread  = 1;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
-    constexpr unsigned InBlockCopy_ThreadPerDimC = 8;
-    constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
-    constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
-    constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
-    constexpr unsigned InBlockCopyDataPerRead  = 4;
-    constexpr unsigned WeiBlockCopyDataPerRead = 4;
-    constexpr unsigned GemmMPerThreadSubC = 4;
-    constexpr unsigned GemmNPerThreadSubC = 4;
-    constexpr unsigned GemmMLevel0Cluster = 4;
-    constexpr unsigned GemmNLevel0Cluster = 2;
-    constexpr unsigned GemmMLevel1Cluster = 2;
-    constexpr unsigned GemmNLevel1Cluster = 4;
-    constexpr unsigned GemmKPerThreadLoop = 1;
-    constexpr unsigned OutThreadCopyDataPerWrite = 2;
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t NPerBlock  = 16;
+    constexpr index_t KPerBlock  = 128;
+    constexpr index_t CPerBlock  = 8;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 2;
+    constexpr index_t NPerThread  = 4;
+    constexpr index_t KPerThread  = 16;
+    constexpr index_t CPerThread  = 1;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
+    constexpr index_t InBlockCopy_ThreadPerDimC = 8;
+    constexpr index_t InBlockCopy_ThreadPerDimH = 2;
+    constexpr index_t InBlockCopy_ThreadPerDimW = 2;
+    constexpr index_t InBlockCopy_ThreadPerDimN = 4;
+    constexpr index_t InBlockCopyDataPerRead  = 4;
+    constexpr index_t WeiBlockCopyDataPerRead = 4;
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 2;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t OutThreadCopyDataPerWrite = 2;
+    constexpr index_t BlockSize = 128;
 #endif

-    constexpr unsigned GridSize =
+    constexpr index_t GridSize =
         ((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
         ((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);

     printf("%s: BlockSize %u, GridSize %u\n", __func__, BlockSize, GridSize);

-    for(unsigned i = 0; i < nrepeat; ++i)
+    for(index_t i = 0; i < nrepeat; ++i)
     {
         float time = launch_kernel(gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn<GridSize,
...
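Unlike the direct-convolution drivers above, the implicit-GEMM drivers round each per-dimension block count up with the ((x + per_block - 1) / per_block) idiom, so partially filled tiles still get a block. A tiny sketch of that ceiling division, with made-up sizes for illustration:

    #include <cstdint>
    #include <cstdio>

    using index_t = uint32_t; // assumption

    // Same rounding the driver writes inline: ceil(x / per_block) in integer arithmetic.
    constexpr index_t ceil_div(index_t x, index_t per_block)
    {
        return (x + per_block - 1) / per_block;
    }

    int main()
    {
        // e.g. Ho = 14 with HoPerBlock = 2 needs 7 blocks, but Ho = 15 needs 8 because of the round-up.
        printf("%u %u\n", ceil_div(14, 2), ceil_div(15, 2)); // 7 8
        return 0;
    }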
driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp

@@ -12,7 +12,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
                                                               Tensor<T>& out_nkhw,
                                                               LowerPads,
                                                               UpperPads,
-                                                              unsigned nrepeat)
+                                                              index_t nrepeat)
 {
     constexpr auto I0 = Number<0>{};
     constexpr auto I1 = Number<1>{};
...
@@ -23,17 +23,17 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
     constexpr auto wei_kcyx_desc = WeiDesc{};
     constexpr auto out_nkhw_desc = OutDesc{};

-    constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
-    constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
+    constexpr index_t Hi = in_nchw_desc.GetLength(I2);
+    constexpr index_t Wi = in_nchw_desc.GetLength(I3);

-    constexpr unsigned N  = out_nkhw_desc.GetLength(I0);
-    constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
-    constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
+    constexpr index_t N  = out_nkhw_desc.GetLength(I0);
+    constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
+    constexpr index_t Wo = out_nkhw_desc.GetLength(I3);

-    constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
-    constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
-    constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
-    constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
+    constexpr index_t K = wei_kcyx_desc.GetLength(I0);
+    constexpr index_t C = wei_kcyx_desc.GetLength(I1);
+    constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
+    constexpr index_t X = wei_kcyx_desc.GetLength(I3);

     // reorder weight
     auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
...
@@ -77,177 +77,177 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
     out_khwn_device_buf.ToDevice(out_khwn.mData.data());

 #if 0
-    constexpr unsigned NPerBlock  = 1;
-    constexpr unsigned KPerBlock  = 1;
-    constexpr unsigned CPerBlock  = 1;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 4;
-    constexpr unsigned NPerThread  = 1;
-    constexpr unsigned KPerThread  = 1;
-    constexpr unsigned CPerThread  = 1;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
-    constexpr unsigned WeiBlockCopyThreadPerDim0 = 1;
-    constexpr unsigned WeiBlockCopyThreadPerDim1 = 1;
-    constexpr unsigned BlockSize = 8;
+    constexpr index_t NPerBlock  = 1;
+    constexpr index_t KPerBlock  = 1;
+    constexpr index_t CPerBlock  = 1;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 4;
+    constexpr index_t NPerThread  = 1;
+    constexpr index_t KPerThread  = 1;
+    constexpr index_t CPerThread  = 1;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
+    constexpr index_t WeiBlockCopyThreadPerDim0 = 1;
+    constexpr index_t WeiBlockCopyThreadPerDim1 = 1;
+    constexpr index_t BlockSize = 8;
 #elif 1
     // for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256
-    constexpr unsigned NPerBlock  = 16;
-    constexpr unsigned KPerBlock  = 64;
-    constexpr unsigned CPerBlock  = 4;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 4;
-    constexpr unsigned NPerThread  = 4;
-    constexpr unsigned KPerThread  = 16;
-    constexpr unsigned CPerThread  = 1;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
-    constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-    constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t NPerBlock  = 16;
+    constexpr index_t KPerBlock  = 64;
+    constexpr index_t CPerBlock  = 4;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 4;
+    constexpr index_t NPerThread  = 4;
+    constexpr index_t KPerThread  = 16;
+    constexpr index_t CPerThread  = 1;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
+    constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+    constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // 3x3 58x58, NKC = 16,256,128
-    constexpr unsigned NPerBlock  = 8;
-    constexpr unsigned KPerBlock  = 64;
-    constexpr unsigned CPerBlock  = 2;
-    constexpr unsigned HoPerBlock = 4;
-    constexpr unsigned WoPerBlock = 4;
-    constexpr unsigned NPerThread  = 4;
-    constexpr unsigned KPerThread  = 16;
-    constexpr unsigned CPerThread  = 1;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t NPerBlock  = 8;
+    constexpr index_t KPerBlock  = 64;
+    constexpr index_t CPerBlock  = 2;
+    constexpr index_t HoPerBlock = 4;
+    constexpr index_t WoPerBlock = 4;
+    constexpr index_t NPerThread  = 4;
+    constexpr index_t KPerThread  = 16;
+    constexpr index_t CPerThread  = 1;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // for 5x5, 36x36
-    constexpr unsigned NPerBlock  = 16;
-    constexpr unsigned KPerBlock  = 64;
-    constexpr unsigned CPerBlock  = 2;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 4;
-    constexpr unsigned NPerThread  = 4;
-    constexpr unsigned KPerThread  = 16;
-    constexpr unsigned CPerThread  = 1;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t NPerBlock  = 16;
+    constexpr index_t KPerBlock  = 64;
+    constexpr index_t CPerBlock  = 2;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 4;
+    constexpr index_t NPerThread  = 4;
+    constexpr index_t KPerThread  = 16;
+    constexpr index_t CPerThread  = 1;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // for 7x7, 38x38
-    constexpr unsigned NPerBlock  = 8;
-    constexpr unsigned KPerBlock  = 64;
-    constexpr unsigned CPerBlock  = 2;
-    constexpr unsigned HoPerBlock = 4;
-    constexpr unsigned WoPerBlock = 4;
-    constexpr unsigned NPerThread  = 4;
-    constexpr unsigned KPerThread  = 16;
-    constexpr unsigned CPerThread  = 1;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t NPerBlock  = 8;
+    constexpr index_t KPerBlock  = 64;
+    constexpr index_t CPerBlock  = 2;
+    constexpr index_t HoPerBlock = 4;
+    constexpr index_t WoPerBlock = 4;
+    constexpr index_t NPerThread  = 4;
+    constexpr index_t KPerThread  = 16;
+    constexpr index_t CPerThread  = 1;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // for 3x3, 56x56
-    constexpr unsigned NPerBlock  = 32;
-    constexpr unsigned KPerBlock  = 64;
-    constexpr unsigned CPerBlock  = 4;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 2;
-    constexpr unsigned NPerThread  = 4;
-    constexpr unsigned KPerThread  = 16;
-    constexpr unsigned CPerThread  = 1;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t NPerBlock  = 32;
+    constexpr index_t KPerBlock  = 64;
+    constexpr index_t CPerBlock  = 4;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 2;
+    constexpr index_t NPerThread  = 4;
+    constexpr index_t KPerThread  = 16;
+    constexpr index_t CPerThread  = 1;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
+    constexpr index_t BlockSize = 128;
 #elif 1
     // 3x3 56x56, NKC = 16,256,128, with padding
     // 3x3 28x28, NKC = 16,512,256, with padding
     // 3x3 20x84, NKC = 16,256,256, with padding
-    constexpr unsigned NPerBlock  = 16;
-    constexpr unsigned KPerBlock  = 64;
-    constexpr unsigned CPerBlock  = 2;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 4;
-    constexpr unsigned NPerThread  = 4;
-    constexpr unsigned KPerThread  = 16;
-    constexpr unsigned CPerThread  = 1;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
-    constexpr unsigned WeiBlockCopyThreadPerDim0 = 2;
-    constexpr unsigned WeiBlockCopyThreadPerDim1 = 64;
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t NPerBlock  = 16;
+    constexpr index_t KPerBlock  = 64;
+    constexpr index_t CPerBlock  = 2;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 4;
+    constexpr index_t NPerThread  = 4;
+    constexpr index_t KPerThread  = 16;
+    constexpr index_t CPerThread  = 1;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
+    constexpr index_t WeiBlockCopyThreadPerDim0 = 2;
+    constexpr index_t WeiBlockCopyThreadPerDim1 = 64;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // for 5x5 filter, 20x84 image, 1x1 padding
-    constexpr unsigned NPerBlock  = 16;
-    constexpr unsigned KPerBlock  = 64;
-    constexpr unsigned CPerBlock  = 1;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 4;
-    constexpr unsigned NPerThread  = 4;
-    constexpr unsigned KPerThread  = 16;
-    constexpr unsigned CPerThread  = 1;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t NPerBlock  = 16;
+    constexpr index_t KPerBlock  = 64;
+    constexpr index_t CPerBlock  = 1;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 4;
+    constexpr index_t NPerThread  = 4;
+    constexpr index_t KPerThread  = 16;
+    constexpr index_t CPerThread  = 1;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // 5x5 filter, 28x28 image, 2x2 padding
-    constexpr unsigned NPerBlock  = 16;
-    constexpr unsigned KPerBlock  = 32;
-    constexpr unsigned CPerBlock  = 2;
-    constexpr unsigned HoPerBlock = 4;
-    constexpr unsigned WoPerBlock = 4;
-    constexpr unsigned NPerThread  = 4;
-    constexpr unsigned KPerThread  = 16;
-    constexpr unsigned CPerThread  = 1;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t NPerBlock  = 16;
+    constexpr index_t KPerBlock  = 32;
+    constexpr index_t CPerBlock  = 2;
+    constexpr index_t HoPerBlock = 4;
+    constexpr index_t WoPerBlock = 4;
+    constexpr index_t NPerThread  = 4;
+    constexpr index_t KPerThread  = 16;
+    constexpr index_t CPerThread  = 1;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // for 1x1, 28x28
-    constexpr unsigned NPerBlock  = 16;
-    constexpr unsigned KPerBlock  = 128;
-    constexpr unsigned CPerBlock  = 8;
-    constexpr unsigned HoPerBlock = 2;
-    constexpr unsigned WoPerBlock = 2;
-    constexpr unsigned NPerThread  = 4;
-    constexpr unsigned KPerThread  = 16;
-    constexpr unsigned CPerThread  = 2;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 1;
-    constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-    constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
-    constexpr unsigned BlockSize = 128;
+    constexpr index_t NPerBlock  = 16;
+    constexpr index_t KPerBlock  = 128;
+    constexpr index_t CPerBlock  = 8;
+    constexpr index_t HoPerBlock = 2;
+    constexpr index_t WoPerBlock = 2;
+    constexpr index_t NPerThread  = 4;
+    constexpr index_t KPerThread  = 16;
+    constexpr index_t CPerThread  = 2;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
+    constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+    constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
+    constexpr index_t BlockSize = 128;
 #endif

-    constexpr unsigned GridSize =
+    constexpr index_t GridSize =
         ((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
         ((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);

     printf("%s: BlockSize %u, GridSize %u\n", __func__, BlockSize, GridSize);

-    for(unsigned i = 0; i < nrepeat; ++i)
+    for(index_t i = 0; i < nrepeat; ++i)
     {
         float time = launch_kernel(gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded<GridSize,
...
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp

@@ -11,7 +11,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
                                                        const Tensor<T>& wei_kcyx,
                                                        OutDesc,
                                                        Tensor<T>& out_nkhw,
-                                                       unsigned nrepeat)
+                                                       index_t nrepeat)
 {
     constexpr auto I0 = Number<0>{};
     constexpr auto I1 = Number<1>{};
...
@@ -22,19 +22,19 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
     constexpr auto wei_kcyx_desc = WeiDesc{};
     constexpr auto out_nkhw_desc = OutDesc{};

-    constexpr unsigned N  = in_nchw_desc.GetLength(I0);
-    constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
-    constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
+    constexpr index_t N  = in_nchw_desc.GetLength(I0);
+    constexpr index_t Hi = in_nchw_desc.GetLength(I2);
+    constexpr index_t Wi = in_nchw_desc.GetLength(I3);

-    constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
-    constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
+    constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
+    constexpr index_t Wo = out_nkhw_desc.GetLength(I3);

-    constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
-    constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
-    constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
-    constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
+    constexpr index_t K = wei_kcyx_desc.GetLength(I0);
+    constexpr index_t C = wei_kcyx_desc.GetLength(I1);
+    constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
+    constexpr index_t X = wei_kcyx_desc.GetLength(I3);

-    constexpr unsigned BGhostRead = (Y - 1) * Wi + (X - 1);
+    constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1);

     // convert in_nchw to in_cnhw
     auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence<C, Hi, Wi, N>{});
...
@@ -71,128 +71,158 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
 #if 0
     // 3x3, 34x34
     // need to use register double buffer for GEMM
-    constexpr unsigned BPerBlock = 128;
-    constexpr unsigned KPerBlock = 64;
-    constexpr unsigned CPerBlock = 4;
+    constexpr index_t BPerBlock = 128;
+    constexpr index_t KPerBlock = 64;
+    constexpr index_t CPerBlock = 4;

-    constexpr unsigned BPerThread = 8;
-    constexpr unsigned KPerThread = 8;
+    constexpr index_t BPerThread = 8;
+    constexpr index_t KPerThread = 8;

-    constexpr unsigned GemmMPerThreadSubC = 4;
-    constexpr unsigned GemmNPerThreadSubC = 4;
-    constexpr unsigned GemmMLevel0Cluster = 4;
-    constexpr unsigned GemmNLevel0Cluster = 2;
-    constexpr unsigned GemmMLevel1Cluster = 2;
-    constexpr unsigned GemmNLevel1Cluster = 8;
-    constexpr unsigned GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 2;
+    constexpr index_t GemmNLevel1Cluster = 8;
+    constexpr index_t GemmKPerThreadLoop = 1;

-    constexpr unsigned GemmThreadPerColumnPerCluster = 8;
-    constexpr unsigned GemmThreadPerRowPerCluster = 8;
+    constexpr index_t GemmThreadPerColumnPerCluster = 8;
+    constexpr index_t GemmThreadPerRowPerCluster = 8;

-    constexpr unsigned InBlockCopyThreadPerDim0 = 4;
-    constexpr unsigned InBlockCopyThreadPerDim1 = 16;
+    constexpr index_t InBlockCopyThreadPerDim0 = 4;
+    constexpr index_t InBlockCopyThreadPerDim1 = 16;

-    constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-    constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
+    constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+    constexpr index_t WeiBlockCopyThreadPerDim1 = 16;

-    constexpr unsigned InBlockCopyDataPerRead = 4;
-    constexpr unsigned WeiBlockCopyDataPerRead = 4;
+    constexpr index_t InBlockCopyDataPerRead = 4;
+    constexpr index_t WeiBlockCopyDataPerRead = 4;

-    constexpr unsigned BlockSize = 128;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // 1x1, 28x28, 64 threads
-    constexpr unsigned BPerBlock = 64;
-    constexpr unsigned KPerBlock = 64;
-    constexpr unsigned CPerBlock = 8;
+    constexpr index_t BPerBlock = 64;
+    constexpr index_t KPerBlock = 64;
+    constexpr index_t CPerBlock = 8;

-    constexpr unsigned BPerThread = 8;
-    constexpr unsigned KPerThread = 8;
+    constexpr index_t BPerThread = 8;
+    constexpr index_t KPerThread = 8;

-    constexpr unsigned GemmMPerThreadSubC = 4;
-    constexpr unsigned GemmNPerThreadSubC = 4;
-    constexpr unsigned GemmMLevel0Cluster = 4;
-    constexpr unsigned GemmNLevel0Cluster = 2;
-    constexpr unsigned GemmMLevel1Cluster = 2;
-    constexpr unsigned GemmNLevel1Cluster = 4;
-    constexpr unsigned GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 2;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;

-    constexpr unsigned GemmThreadPerColumnPerCluster = 8;
-    constexpr unsigned GemmThreadPerRowPerCluster = 8;
+    constexpr index_t GemmThreadPerColumnPerCluster = 8;
+    constexpr index_t GemmThreadPerRowPerCluster = 8;

-    constexpr unsigned InBlockCopyThreadPerDim0 = 4;
-    constexpr unsigned InBlockCopyThreadPerDim1 = 16;
+    constexpr index_t InBlockCopyThreadPerDim0 = 4;
+    constexpr index_t InBlockCopyThreadPerDim1 = 16;

-    constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-    constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
+    constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+    constexpr index_t WeiBlockCopyThreadPerDim1 = 16;

-    constexpr unsigned InBlockCopyDataPerRead = 4;
-    constexpr unsigned WeiBlockCopyDataPerRead = 4;
+    constexpr index_t InBlockCopyDataPerRead = 4;
+    constexpr index_t WeiBlockCopyDataPerRead = 4;

-    constexpr unsigned BlockSize = 64;
-#elif 1
+    constexpr index_t BlockSize = 64;
+#elif 0
     // 1x1, 28x28, 128 threads, no lds-double-buffer
     // 1x1, 28x28, 128 threads, with lds-double-buffer, max_register = 128
-    constexpr unsigned BPerBlock = 64;
-    constexpr unsigned KPerBlock = 128;
-    constexpr unsigned CPerBlock = 8;
+    constexpr index_t BPerBlock = 64;
+    constexpr index_t KPerBlock = 128;
+    constexpr index_t CPerBlock = 8;

-    constexpr unsigned BPerThread = 8;
-    constexpr unsigned KPerThread = 8;
+    constexpr index_t BPerThread = 8;
+    constexpr index_t KPerThread = 8;

-    constexpr unsigned GemmMPerThreadSubC = 4;
-    constexpr unsigned GemmNPerThreadSubC = 4;
-    constexpr unsigned GemmMLevel0Cluster = 4;
-    constexpr unsigned GemmNLevel0Cluster = 2;
-    constexpr unsigned GemmMLevel1Cluster = 4;
-    constexpr unsigned GemmNLevel1Cluster = 4;
-    constexpr unsigned GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;

-    constexpr unsigned GemmThreadPerColumnPerCluster = 8;
-    constexpr unsigned GemmThreadPerRowPerCluster = 8;
+    constexpr index_t GemmThreadPerColumnPerCluster = 8;
+    constexpr index_t GemmThreadPerRowPerCluster = 8;

-    constexpr unsigned InBlockCopyThreadPerDim0 = 4;
-    constexpr unsigned InBlockCopyThreadPerDim1 = 16;
+    constexpr index_t InBlockCopyThreadPerDim0 = 4;
+    constexpr index_t InBlockCopyThreadPerDim1 = 16;

-    constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-    constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
+    constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+    constexpr index_t WeiBlockCopyThreadPerDim1 = 16;

-    constexpr unsigned InBlockCopyDataPerRead = 4;
-    constexpr unsigned WeiBlockCopyDataPerRead = 4;
+    constexpr index_t InBlockCopyDataPerRead = 4;
+    constexpr index_t WeiBlockCopyDataPerRead = 4;

-    constexpr unsigned BlockSize = 128;
+    constexpr index_t BlockSize = 128;
 #elif 0
     // 1x1, 28x28, 256 thread
-    constexpr unsigned BPerBlock = 128;
-    constexpr unsigned KPerBlock = 128;
-    constexpr unsigned CPerBlock = 8;
+    constexpr index_t BPerBlock = 128;
+    constexpr index_t KPerBlock = 128;
+    constexpr index_t CPerBlock = 8;

+    constexpr index_t BPerThread = 8;
+    constexpr index_t KPerThread = 8;

+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 4;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;

+    constexpr index_t GemmThreadPerColumnPerCluster = 8;
+    constexpr index_t GemmThreadPerRowPerCluster = 8;

+    constexpr index_t InBlockCopyThreadPerDim0 = 4;
+    constexpr index_t InBlockCopyThreadPerDim1 = 16;

+    constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+    constexpr index_t WeiBlockCopyThreadPerDim1 = 16;

+    constexpr index_t InBlockCopyDataPerRead = 4;
+    constexpr index_t WeiBlockCopyDataPerRead = 4;

+    constexpr index_t BlockSize = 256;
 #elif 1
     // 1x1, 14x14, Vega 10
+    constexpr index_t BPerBlock = 64;
+    constexpr index_t KPerBlock = 128;
+    constexpr index_t CPerBlock = 8;

-    constexpr unsigned BPerThread = 8;
-    constexpr unsigned KPerThread = 8;
+    constexpr index_t BPerThread = 8;
+    constexpr index_t KPerThread = 8;

-    constexpr unsigned GemmMPerThreadSubC = 4;
-    constexpr unsigned GemmNPerThreadSubC = 4;
-    constexpr unsigned GemmMLevel0Cluster = 4;
-    constexpr unsigned GemmNLevel0Cluster = 4;
-    constexpr unsigned GemmMLevel1Cluster = 4;
-    constexpr unsigned GemmNLevel1Cluster = 4;
-    constexpr unsigned GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;

-    constexpr unsigned GemmThreadPerColumnPerCluster = 8;
-    constexpr unsigned GemmThreadPerRowPerCluster = 8;
+    constexpr index_t GemmThreadPerColumnPerCluster = 8;
+    constexpr index_t GemmThreadPerRowPerCluster = 8;

-    constexpr unsigned InBlockCopyThreadPerDim0 = 4;
-    constexpr unsigned InBlockCopyThreadPerDim1 = 16;
+    constexpr index_t InBlockCopyThreadPerDim0 = 4;
+    constexpr index_t InBlockCopyThreadPerDim1 = 16;

-    constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
-    constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
+    constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+    constexpr index_t WeiBlockCopyThreadPerDim1 = 16;

-    constexpr unsigned InBlockCopyDataPerRead = 4;
-    constexpr unsigned WeiBlockCopyDataPerRead = 4;
+    constexpr index_t InBlockCopyDataPerRead = 4;
+    constexpr index_t WeiBlockCopyDataPerRead = 4;

-    constexpr unsigned BlockSize = 256;
+    constexpr index_t BlockSize = 128;
 #endif

-    constexpr unsigned GridSize =
+    constexpr index_t GridSize =
         ((N * Hi * Wi + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);

     printf("%s: BlockSize %u, GridSize %u\n", __func__, BlockSize, GridSize);
...
@@ -208,7 +238,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
     wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
     out_khwn_device_buf.ToDevice(out_khwn.mData.data());

-    for(unsigned i = 0; i < nrepeat; ++i)
+    for(index_t i = 0; i < nrepeat; ++i)
     {
         float time = launch_kernel(
 #if 1
...
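In this second implicit-GEMM driver the batch and spatial dimensions are folded into a single B dimension (N * Hi * Wi), so the grid is only two ceiling divisions, one over B and one over K. A short sketch with hypothetical problem sizes (the driver reads the real ones from the descriptors); the tile sizes are the "1x1, 14x14, Vega 10" configuration above.

    #include <cstdint>
    #include <cstdio>

    using index_t = uint32_t; // assumption

    int main()
    {
        // Hypothetical problem sizes (not taken from the commit).
        constexpr index_t N = 16, Hi = 28, Wi = 28, K = 128;

        // Block tile sizes from the "1x1, 14x14, Vega 10" configuration.
        constexpr index_t BPerBlock = 64, KPerBlock = 128;

        // Same formula as the driver: fold N*Hi*Wi into B, then tile B and K with round-up.
        constexpr index_t GridSize =
            ((N * Hi * Wi + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);

        printf("GridSize = %u\n", GridSize); // ceil(12544/64) * ceil(128/128) = 196 * 1 = 196
        return 0;
    }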
driver/driver.hip.cpp

@@ -40,11 +40,11 @@ struct GeneratorTensor_Checkboard
     template <class... Ts>
     double operator()(Ts... Xs) const
     {
-        std::array<unsigned long, sizeof...(Ts)> dims = {{Xs...}};
+        std::array<index_t, sizeof...(Ts)> dims = {{Xs...}};
         return std::accumulate(dims.begin(),
                                dims.end(),
                                true,
-                               [](bool init, unsigned long x) -> int { return init != (x % 2); })
+                               [](bool init, index_t x) -> int { return init != (x % 2); })
                    ? 1
                    : -1;
     }
...
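The GeneratorTensor_Checkboard functor above returns +1 or -1 depending on the parity of the index coordinates: the std::accumulate call folds the parity of every coordinate into a bool, flipping it on each odd coordinate. A standalone sketch of the same logic follows; the functor body is taken from the hunk, while the surrounding tensor machinery is omitted and the index_t alias is an assumption.

    #include <array>
    #include <cstdint>
    #include <cstdio>
    #include <numeric>

    using index_t = uint32_t; // assumption

    struct GeneratorTensor_Checkboard
    {
        template <class... Ts>
        double operator()(Ts... Xs) const
        {
            std::array<index_t, sizeof...(Ts)> dims = {{Xs...}};
            // init != (x % 2) flips the accumulated bool whenever a coordinate is odd,
            // producing a checkerboard of +1 / -1 over the index space.
            return std::accumulate(dims.begin(),
                                   dims.end(),
                                   true,
                                   [](bool init, index_t x) -> int { return init != (x % 2); })
                       ? 1
                       : -1;
        }
    };

    int main()
    {
        GeneratorTensor_Checkboard gen;
        printf("%.0f %.0f %.0f\n", gen(0u, 0u), gen(0u, 1u), gen(1u, 1u)); // 1 -1 1
        return 0;
    }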
@@ -80,9 +80,9 @@ auto make_TensorDescriptor(TConstTensorDesc)
     constexpr auto I3 = Number<3>{};
     constexpr auto desc = TConstTensorDesc{};

-    std::initializer_list<unsigned> lengths = {
+    std::initializer_list<index_t> lengths = {
         desc.GetLength(I0), desc.GetLength(I1), desc.GetLength(I2), desc.GetLength(I3)};
-    std::initializer_list<unsigned> strides = {
+    std::initializer_list<index_t> strides = {
         desc.GetStride(I0), desc.GetStride(I1), desc.GetStride(I2), desc.GetStride(I3)};

     return TensorDescriptor(lengths, strides);
...
@@ -95,11 +95,11 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
                              LowerPads,
                              UpperPads)
 {
-    unsigned h_pad_low = LowerPads{}.Get(Number<0>{});
-    unsigned w_pad_low = LowerPads{}.Get(Number<1>{});
+    index_t h_pad_low = LowerPads{}.Get(Number<0>{});
+    index_t w_pad_low = LowerPads{}.Get(Number<1>{});

-    unsigned h_pad_up = UpperPads{}.Get(Number<0>{});
-    unsigned w_pad_up = UpperPads{}.Get(Number<1>{});
+    index_t h_pad_up = UpperPads{}.Get(Number<0>{});
+    index_t w_pad_up = UpperPads{}.Get(Number<1>{});

     auto f = [&](auto n, auto k, auto ho, auto wo) {
         double v = 0;
...
@@ -153,11 +153,11 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
     std::size_t HO = out_nkhw.mDesc.GetLengths()[2];
     std::size_t WO = out_nkhw.mDesc.GetLengths()[3];

-    unsigned h_pad_low = LowerPads{}.Get(Number<0>{});
-    unsigned w_pad_low = LowerPads{}.Get(Number<1>{});
+    index_t h_pad_low = LowerPads{}.Get(Number<0>{});
+    index_t w_pad_low = LowerPads{}.Get(Number<1>{});

-    unsigned h_pad_up = UpperPads{}.Get(Number<0>{});
-    unsigned w_pad_up = UpperPads{}.Get(Number<1>{});
+    index_t h_pad_up = UpperPads{}.Get(Number<0>{});
+    index_t w_pad_up = UpperPads{}.Get(Number<1>{});

     std::size_t HiPerTile = HoPerTile + Y - 1;
     std::size_t WiPerTile = WoPerTile + X - 1;
...
@@ -399,211 +399,211 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
 int main(int argc, char* argv[])
 {
 #if 0
-    constexpr unsigned N  = 1;
-    constexpr unsigned C  = 1;
-    constexpr unsigned HI = 28;
-    constexpr unsigned WI = 28;
-    constexpr unsigned K  = 1;
-    constexpr unsigned Y  = 3;
-    constexpr unsigned X  = 3;
-    constexpr unsigned HPad = 0;
-    constexpr unsigned WPad = 0;
+    constexpr index_t N  = 1;
+    constexpr index_t C  = 1;
+    constexpr index_t HI = 28;
+    constexpr index_t WI = 28;
+    constexpr index_t K  = 1;
+    constexpr index_t Y  = 3;
+    constexpr index_t X  = 3;
+    constexpr index_t HPad = 0;
+    constexpr index_t WPad = 0;
 #elif 0
     // 3x3, 34x34
-    constexpr unsigned N  = 64;
-    constexpr unsigned C  = 256;
-    constexpr unsigned HI = 34;
-    constexpr unsigned WI = 34;
-    constexpr unsigned K  = 64;
-    constexpr unsigned Y  = 3;
-    constexpr unsigned X  = 3;
-    constexpr unsigned HPad = 0;
-    constexpr unsigned WPad = 0;
+    constexpr index_t N  = 64;
+    constexpr index_t C  = 256;
+    constexpr index_t HI = 34;
+    constexpr index_t WI = 34;
+    constexpr index_t K  = 64;
+    constexpr index_t Y  = 3;
+    constexpr index_t X  = 3;
+    constexpr index_t HPad = 0;
+    constexpr index_t WPad = 0;
 #elif 0
     // 3x3, 56x56
-    constexpr unsigned N  = 64;
-    constexpr unsigned C  = 64;
-    constexpr unsigned HI = 56;
-    constexpr unsigned WI = 56;
-    constexpr unsigned K  = 64;
-    constexpr unsigned Y  = 3;
-    constexpr unsigned X  = 3;
+    constexpr index_t N  = 64;
+    constexpr index_t C  = 64;
+    constexpr index_t HI = 56;
+    constexpr index_t WI = 56;
+    constexpr index_t K  = 64;
+    constexpr index_t Y  = 3;
+    constexpr index_t X  = 3;
 #elif 0
     // 3x3, 58x58
-    constexpr unsigned N  = 64;
-    constexpr unsigned C  = 64;
-    constexpr unsigned HI = 58;
-    constexpr unsigned WI = 58;
-    constexpr unsigned K  = 64;
-    constexpr unsigned Y  = 3;
-    constexpr unsigned X  = 3;
+    constexpr index_t N  = 64;
+    constexpr index_t C  = 64;
+    constexpr index_t HI = 58;
+    constexpr index_t WI = 58;
+    constexpr index_t K  = 64;
+    constexpr index_t Y  = 3;
+    constexpr index_t X  = 3;
 #elif 0
     // 5x5, 36x36
-    constexpr unsigned N  = 64;
-    constexpr unsigned C  = 256;
-    constexpr unsigned HI = 36;
-    constexpr unsigned WI = 36;
-    constexpr unsigned K  = 64;
-    constexpr unsigned Y  = 5;
-    constexpr unsigned X  = 5;
-    constexpr unsigned HPad = 0;
-    constexpr unsigned WPad = 0;
+    constexpr index_t N  = 64;
+    constexpr index_t C  = 256;
+    constexpr index_t HI = 36;
+    constexpr index_t WI = 36;
+    constexpr index_t K  = 64;
+    constexpr index_t Y  = 5;
+    constexpr index_t X  = 5;
+    constexpr index_t HPad = 0;
+    constexpr index_t WPad = 0;
 #elif 0
     // 7x7, 38x38
-    constexpr unsigned N  = 64;
-    constexpr unsigned C  = 256;
-    constexpr unsigned HI = 38;
-    constexpr unsigned WI = 38;
-    constexpr unsigned K  = 64;
-    constexpr unsigned Y  = 7;
-    constexpr unsigned X  = 7;
-    constexpr unsigned HPad = 0;
-    constexpr unsigned WPad = 0;
+    constexpr index_t N  = 64;
+    constexpr index_t C  = 256;
+    constexpr index_t HI = 38;
+    constexpr index_t WI = 38;
+    constexpr index_t K  = 64;
+    constexpr index_t Y  = 7;
+    constexpr index_t X  = 7;
+    constexpr index_t HPad = 0;
+    constexpr index_t WPad = 0;
 #elif 0
     // 3x3, 58x58
-    constexpr unsigned N  = 16;
-    constexpr unsigned C  = 128;
-    constexpr unsigned HI = 58;
-    constexpr unsigned WI = 58;
-    constexpr unsigned K  = 256;
-    constexpr unsigned Y  = 3;
-    constexpr unsigned X  = 3;
+    constexpr index_t N  = 16;
+    constexpr index_t C  = 128;
+    constexpr index_t HI = 58;
+    constexpr index_t WI = 58;
+    constexpr index_t K  = 256;
+    constexpr index_t Y  = 3;
+    constexpr index_t X  = 3;
 #elif 0
     // 3x3 filter, 58x58 image, 0x0 padding
-    constexpr unsigned N  = 16;
-    constexpr unsigned C  = 128;
-    constexpr unsigned HI = 58;
-    constexpr unsigned WI = 58;
-    constexpr unsigned K  = 256;
-    constexpr unsigned Y  = 3;
-    constexpr unsigned X  = 3;
-    constexpr unsigned HPad = 0;
-    constexpr unsigned WPad = 0;
+    constexpr index_t N  = 16;
+    constexpr index_t C  = 128;
+    constexpr index_t HI = 58;
+    constexpr index_t WI = 58;
+    constexpr index_t K  = 256;
+    constexpr index_t Y  = 3;
+    constexpr index_t X  = 3;
+    constexpr index_t HPad = 0;
+    constexpr index_t WPad = 0;
 #elif 0
     // 3x3 filter, 56x56 image, 1x1 padding
-    constexpr unsigned N  = 16;
-    constexpr unsigned C  = 128;
-    constexpr unsigned HI = 56;
-    constexpr unsigned WI = 56;
-    constexpr unsigned K  = 256;
-    constexpr unsigned Y  = 3;
-    constexpr unsigned X  = 3;
-    constexpr unsigned HPad = 1;
-    constexpr unsigned WPad = 1;
+    constexpr index_t N  = 16;
+    constexpr index_t C  = 128;
+    constexpr index_t HI = 56;
+    constexpr index_t WI = 56;
+    constexpr index_t K  = 256;
+    constexpr index_t Y  = 3;
+    constexpr index_t X  = 3;
+    constexpr index_t HPad = 1;
+    constexpr index_t WPad = 1;
 #elif 0
     // 3x3 filter, 28x28 image, 1x1 padding
-    constexpr unsigned N  = 16;
-    constexpr unsigned C  = 256;
-    constexpr unsigned HI = 28;
-    constexpr unsigned WI = 28;
-    constexpr unsigned K  = 512;
-    constexpr unsigned Y  = 3;
-    constexpr unsigned X  = 3;
-    constexpr unsigned HPad = 1;
-    constexpr unsigned WPad = 1;
+    constexpr index_t N  = 16;
+    constexpr index_t C  = 256;
+    constexpr index_t HI = 28;
+    constexpr index_t WI = 28;
+    constexpr index_t K  = 512;
+    constexpr index_t Y  = 3;
+    constexpr index_t X  = 3;
+    constexpr index_t HPad = 1;
+    constexpr index_t WPad = 1;
 #elif 0
     // 1x1 filter, 28x28 image
-    constexpr unsigned N  = 16;
-    constexpr unsigned C  = 256;
-    constexpr unsigned HI = 28;
-    constexpr unsigned WI = 28;
-    constexpr unsigned K  = 512;
-    constexpr unsigned Y  = 1;
-    constexpr unsigned X  = 1;
-    constexpr unsigned HPad = 0;
-    constexpr unsigned WPad = 0;
+    constexpr index_t N  = 16;
+    constexpr index_t C  = 256;
+    constexpr index_t HI = 28;
+    constexpr index_t WI = 28;
+    constexpr index_t K  = 512;
+    constexpr index_t Y  = 1;
+    constexpr index_t X  = 1;
+    constexpr index_t HPad = 0;
+    constexpr index_t WPad = 0;
 #elif 0
     // 3x3 filter, 20x84 image, 1x1 padding
-    constexpr unsigned N  = 16;
-    constexpr unsigned C  = 256;
-    constexpr unsigned HI = 20;
-    constexpr unsigned WI = 84;
-    constexpr unsigned K  = 256;
-    constexpr unsigned Y  = 3;
-    constexpr unsigned X  = 3;
-    constexpr unsigned HPad = 1;
-    constexpr unsigned WPad = 1;
+    constexpr index_t N  = 16;
+    constexpr index_t C  = 256;
+    constexpr index_t HI = 20;
+    constexpr index_t WI = 84;
+    constexpr index_t K  = 256;
+    constexpr index_t Y  = 3;
+    constexpr index_t X  = 3;
+    constexpr index_t HPad = 1;
+    constexpr index_t WPad = 1;
 #elif 0
     // 3x3 filter, 112x112 image, 1x1 padding
-    constexpr unsigned N  = 16;
-    constexpr unsigned C  = 64;
-    constexpr unsigned HI = 112;
-    constexpr unsigned WI = 112;
-    constexpr unsigned K  = 128;
-    constexpr unsigned Y  = 3;
-    constexpr unsigned X  = 3;
-    constexpr unsigned HPad = 1;
-    constexpr unsigned WPad = 1;
+    constexpr index_t N  = 16;
+    constexpr index_t C  = 64;
+    constexpr index_t HI = 112;
+    constexpr index_t WI = 112;
+    constexpr index_t K  = 128;
+    constexpr index_t Y  = 3;
+    constexpr index_t X  = 3;
+    constexpr index_t HPad = 1;
+    constexpr index_t WPad = 1;
 #elif 0
     // 5x5 filter, 20x86 image, 1x1 padding
-    constexpr unsigned N  = 16;
-    constexpr unsigned C  = 256;
-    constexpr unsigned HI = 20;
-    constexpr unsigned WI = 86;
-    constexpr unsigned K  = 512;
-    constexpr unsigned Y  = 5;
-    constexpr unsigned X  = 5;
-    constexpr unsigned HPad = 1;
-    constexpr unsigned WPad = 1;
+    constexpr index_t N  = 16;
+    constexpr index_t C  = 256;
+    constexpr index_t
HI
=
20
;
constexpr
index_t
WI
=
86
;
constexpr
index_t
K
=
512
;
constexpr
index_t
Y
=
5
;
constexpr
index_t
X
=
5
;
constexpr
index_t
HPad
=
1
;
constexpr
index_t
WPad
=
1
;
#elif 0
// 5x5 filter, 28x28 image, 2x2 padding
constexpr
unsigned
N
=
16
;
constexpr
unsigned
C
=
192
;
constexpr
unsigned
HI
=
28
;
constexpr
unsigned
WI
=
28
;
constexpr
unsigned
K
=
32
;
constexpr
unsigned
Y
=
5
;
constexpr
unsigned
X
=
5
;
constexpr
unsigned
HPad
=
2
;
constexpr
unsigned
WPad
=
2
;
constexpr
index_t
N
=
16
;
constexpr
index_t
C
=
192
;
constexpr
index_t
HI
=
28
;
constexpr
index_t
WI
=
28
;
constexpr
index_t
K
=
32
;
constexpr
index_t
Y
=
5
;
constexpr
index_t
X
=
5
;
constexpr
index_t
HPad
=
2
;
constexpr
index_t
WPad
=
2
;
#elif 0
// 1x1 filter, 32x32 image
constexpr
unsigned
N
=
64
;
constexpr
unsigned
C
=
256
;
constexpr
unsigned
HI
=
32
;
constexpr
unsigned
WI
=
32
;
constexpr
unsigned
K
=
512
;
constexpr
unsigned
Y
=
1
;
constexpr
unsigned
X
=
1
;
constexpr
unsigned
HPad
=
0
;
constexpr
unsigned
WPad
=
0
;
constexpr
index_t
N
=
64
;
constexpr
index_t
C
=
256
;
constexpr
index_t
HI
=
32
;
constexpr
index_t
WI
=
32
;
constexpr
index_t
K
=
512
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
constexpr
index_t
HPad
=
0
;
constexpr
index_t
WPad
=
0
;
#elif 0
// 1x1 filter, 14x14 image
constexpr
unsigned
N
=
128
;
constexpr
unsigned
C
=
2048
;
constexpr
unsigned
HI
=
14
;
constexpr
unsigned
WI
=
14
;
constexpr
unsigned
K
=
512
;
constexpr
unsigned
Y
=
1
;
constexpr
unsigned
X
=
1
;
constexpr
unsigned
HPad
=
0
;
constexpr
unsigned
WPad
=
0
;
// 1x1 filter, 14x14 image
, C = 2048
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
2048
;
constexpr
index_t
HI
=
14
;
constexpr
index_t
WI
=
14
;
constexpr
index_t
K
=
512
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
constexpr
index_t
HPad
=
0
;
constexpr
index_t
WPad
=
0
;
#elif 1
// 1x1 filter, 14x14 image, C = 512
constexpr
unsigned
N
=
128
;
constexpr
unsigned
C
=
512
;
constexpr
unsigned
HI
=
14
;
constexpr
unsigned
WI
=
14
;
constexpr
unsigned
K
=
512
;
constexpr
unsigned
Y
=
1
;
constexpr
unsigned
X
=
1
;
constexpr
unsigned
HPad
=
0
;
constexpr
unsigned
WPad
=
0
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
512
;
constexpr
index_t
HI
=
14
;
constexpr
index_t
WI
=
14
;
constexpr
index_t
K
=
512
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
constexpr
index_t
HPad
=
0
;
constexpr
index_t
WPad
=
0
;
#endif
auto
lower_pads
=
Sequence
<
HPad
,
WPad
>
{};
...
...
@@ -634,7 +634,7 @@ int main(int argc, char* argv[])
}
bool
do_verification
=
atoi
(
argv
[
1
]);
unsigned
nrepeat
=
atoi
(
argv
[
2
]);
index_t
nrepeat
=
atoi
(
argv
[
2
]);
if
(
do_verification
)
{
...
...
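For reference, the output spatial size implied by the active configuration follows the usual padded, stride-1 convolution relation. The sketch below is not part of the commit: it assumes unit stride and assumes index_t is an unsigned integer typedef (the actual definition is supplied by src/include/config.h.in, whose new contents are not shown in this diff).

    // Standalone sketch: output size for the active "1x1 filter, 14x14 image, C = 512" case.
    #include <cstdio>
    using index_t = unsigned; // assumption; see config.h.in in this commit
    int main()
    {
        constexpr index_t HI = 14, WI = 14, Y = 1, X = 1, HPad = 0, WPad = 0;
        constexpr index_t Ho = HI + 2 * HPad - Y + 1; // 14
        constexpr index_t Wo = WI + 2 * WPad - X + 1; // 14
        std::printf("Ho = %u, Wo = %u\n", Ho, Wo);
    }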
src/include/Array.hip.hpp  View file @ 766b0a9e
 #pragma once
-template <class TData, unsigned NSize>
+template <class TData, index_t NSize>
 struct Array
 {
     using Type = Array<TData, NSize>;
-    static constexpr unsigned nSize = NSize;
+    static constexpr index_t nSize = NSize;
-    unsigned mData[nSize];
+    index_t mData[nSize];
     template <class... Xs>
     __host__ __device__ Array(Xs... xs) : mData{static_cast<TData>(xs)...}
     {
     }
-    __host__ __device__ TData operator[](unsigned i) const { return mData[i]; }
+    __host__ __device__ TData operator[](index_t i) const { return mData[i]; }
 };
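A host-only restatement of the Array shown above (device qualifiers dropped, index_t assumed to be unsigned), just to illustrate how a fixed-size multi-index is packed and read back; this is a sketch, not code from the commit.

    #include <cstdio>
    using index_t = unsigned; // assumption; the real typedef comes from config.h.in
    template <class TData, index_t NSize>
    struct Array
    {
        static constexpr index_t nSize = NSize;
        index_t mData[nSize]; // mirrors the struct above, which stores index_t
        template <class... Xs>
        Array(Xs... xs) : mData{static_cast<TData>(xs)...} {}
        TData operator[](index_t i) const { return mData[i]; }
    };
    int main()
    {
        Array<index_t, 4> multi_id(1, 2, 3, 4); // e.g. an (n, c, h, w) index
        std::printf("%u %u %u %u\n", multi_id[0], multi_id[1], multi_id[2], multi_id[3]);
    }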
src/include/ConstantMatrixDescriptor.hip.hpp  View file @ 766b0a9e
 #pragma once
 #include "common.hip.hpp"
-template <unsigned NRow_, unsigned NCol_, unsigned RowStride_>
+template <index_t NRow_, index_t NCol_, index_t RowStride_>
 struct ConstantMatrixDescriptor
 {
     __host__ __device__ constexpr ConstantMatrixDescriptor()
...
@@ -9,24 +9,28 @@ struct ConstantMatrixDescriptor
         static_assert(NCol_ <= RowStride_, "wrong! NCol > RowStride!");
     }
-    __host__ __device__ constexpr unsigned NRow() const { return NRow_; }
+    __host__ __device__ constexpr index_t NRow() const { return NRow_; }
-    __host__ __device__ constexpr unsigned NCol() const { return NCol_; }
+    __host__ __device__ constexpr index_t NCol() const { return NCol_; }
-    __host__ __device__ constexpr unsigned RowStride() const { return RowStride_; }
+    __host__ __device__ constexpr index_t RowStride() const { return RowStride_; }
     __host__ __device__ constexpr auto GetLengths() const { return Sequence<NRow_, NCol_>{}; }
-    __host__ __device__ constexpr unsigned GetElementSize() const { return NRow_ * NCol_; }
+    __host__ __device__ constexpr index_t GetElementSize() const { return NRow_ * NCol_; }
-    __host__ __device__ constexpr unsigned GetElementSpace() const { return NRow_ * RowStride_; }
+    __host__ __device__ constexpr index_t GetElementSpace() const { return NRow_ * RowStride_; }
-    __host__ __device__ unsigned Get1dIndex(unsigned irow, unsigned icol) const
+    __host__ __device__ index_t Get1dIndex(index_t irow, index_t icol) const
     {
+#if DEVICE_BACKEND_HIP
+        return __mul24(irow, RowStride_) + icol;
+#else
         return irow * RowStride_ + icol;
+#endif
     }
-    template <unsigned SubNRow, unsigned SubNCol>
+    template <index_t SubNRow, index_t SubNCol>
     __host__ __device__ constexpr auto MakeSubMatrixDescriptor(Number<SubNRow>, Number<SubNCol>) const
     {
...
@@ -34,13 +38,13 @@ struct ConstantMatrixDescriptor
     }
 };
-template <unsigned NRow, unsigned NCol>
+template <index_t NRow, index_t NCol>
 __host__ __device__ constexpr auto make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>)
 {
     return ConstantMatrixDescriptor<NRow, NCol, NCol>{};
 }
-template <unsigned NRow, unsigned NCol, unsigned RowStride>
+template <index_t NRow, index_t NCol, index_t RowStride>
 __host__ __device__ constexpr auto
 make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>, Number<RowStride>)
 {
...
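The offset computed by Get1dIndex above is plain row-major arithmetic: irow * RowStride_ + icol. A standalone check, outside the commit and with index_t assumed to be unsigned:

    #include <cstdio>
    using index_t = unsigned; // assumption
    int main()
    {
        constexpr index_t NRow = 4, NCol = 3, RowStride = 8;
        constexpr index_t irow = 2, icol = 1;
        constexpr index_t offset        = irow * RowStride + icol; // 17
        constexpr index_t element_space = NRow * RowStride;        // GetElementSpace() = 32
        static_assert(NCol <= RowStride, "NCol > RowStride");
        std::printf("offset = %u, element space = %u\n", offset, element_space);
    }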
src/include/ConstantTensorDescriptor.hip.hpp  View file @ 766b0a9e
...
@@ -2,35 +2,35 @@
 #include "common.hip.hpp"
 // this is ugly, only for 2d
-template <unsigned L0, unsigned L1>
+template <index_t L0, index_t L1>
 __host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1>)
 {
     return Sequence<L1, 1>{};
 }
 // this is ugly, only for 4d
-template <unsigned L0, unsigned L1, unsigned L2, unsigned L3>
+template <index_t L0, index_t L1, index_t L2, index_t L3>
 __host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3>)
 {
     return Sequence<L1 * L2 * L3, L2 * L3, L3, 1>{};
 }
 // this is ugly, only for 6d
-template <unsigned L0, unsigned L1, unsigned L2, unsigned L3, unsigned L4, unsigned L5>
+template <index_t L0, index_t L1, index_t L2, index_t L3, index_t L4, index_t L5>
 __host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5>)
 {
     return Sequence<L1 * L2 * L3 * L4 * L5, L2 * L3 * L4 * L5, L3 * L4 * L5, L4 * L5, L5, 1>{};
 }
 // this is ugly, only for 8d
-template <unsigned L0, unsigned L1, unsigned L2, unsigned L3, unsigned L4, unsigned L5, unsigned L6, unsigned L7>
+template <index_t L0, index_t L1, index_t L2, index_t L3, index_t L4, index_t L5, index_t L6, index_t L7>
 __host__ __device__ constexpr auto
 calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5, L6, L7>)
 {
...
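For the packed 4-d case above, the stride of each dimension is the product of all faster-varying lengths. A worked example, outside the commit, with index_t assumed unsigned: lengths (8, 4, 3, 2) give strides (4*3*2, 3*2, 2, 1) = (24, 6, 2, 1).

    using index_t = unsigned; // assumption
    constexpr index_t L1 = 4, L2 = 3, L3 = 2; // lengths (8, 4, 3, 2), L0 does not enter the strides
    static_assert(L1 * L2 * L3 == 24 && L2 * L3 == 6 && L3 == 2, "packed-stride arithmetic");
    int main() {}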
@@ -45,20 +45,20 @@ __host__ __device__ constexpr auto
 }
 // this is ugly, only for 2d
-template <unsigned L0, unsigned L1, unsigned Align>
+template <index_t L0, index_t L1, index_t Align>
 __host__ __device__ constexpr auto
 calculate_default_strides_aligned(Sequence<L0, L1>, Number<Align>)
 {
-    constexpr unsigned L1_align = Align * ((L1 + Align - 1) / Align);
+    constexpr index_t L1_align = Align * ((L1 + Align - 1) / Align);
     return Sequence<L1_align, 1>{};
 }
 // this is ugly, only for 4d
-template <unsigned L0, unsigned L1, unsigned L2, unsigned L3, unsigned Align>
+template <index_t L0, index_t L1, index_t L2, index_t L3, index_t Align>
 __host__ __device__ constexpr auto
 calculate_default_strides_aligned(Sequence<L0, L1, L2, L3>, Number<Align>)
 {
-    constexpr unsigned L3_align = Align * ((L3 + Align - 1) / Align);
+    constexpr index_t L3_align = Align * ((L3 + Align - 1) / Align);
     return Sequence<L1 * L2 * L3_align, L2 * L3_align, L3_align, 1>{};
 }
...
@@ -66,27 +66,27 @@ template <class Lengths, class Strides>
 struct ConstantTensorDescriptor
 {
     using Type = ConstantTensorDescriptor<Lengths, Strides>;
-    static constexpr unsigned nDim = Lengths::nDim;
+    static constexpr index_t nDim = Lengths::nDim;
     __host__ __device__ constexpr ConstantTensorDescriptor()
     {
         static_assert(Lengths::nDim == Strides::nDim, "nDim not consistent");
     }
-    __host__ __device__ constexpr unsigned GetDimension() const { return nDim; }
+    __host__ __device__ constexpr index_t GetDimension() const { return nDim; }
     __host__ __device__ constexpr Lengths GetLengths() const { return Lengths{}; }
     __host__ __device__ constexpr Strides GetStrides() const { return Strides{}; }
-    template <unsigned I>
-    __host__ __device__ constexpr unsigned GetLength(Number<I>) const
+    template <index_t I>
+    __host__ __device__ constexpr index_t GetLength(Number<I>) const
     {
         return Lengths{}.Get(Number<I>{});
     }
-    template <unsigned I>
-    __host__ __device__ constexpr unsigned GetStride(Number<I>) const
+    template <index_t I>
+    __host__ __device__ constexpr index_t GetStride(Number<I>) const
     {
         return Strides{}.Get(Number<I>{});
     }
...
@@ -95,18 +95,18 @@ struct ConstantTensorDescriptor
     struct GetElementSize_f
     {
         template <class IDim>
-        __host__ __device__ constexpr unsigned operator()(IDim idim) const
+        __host__ __device__ constexpr index_t operator()(IDim idim) const
         {
             return Type{}.GetLength(idim);
         }
     };
-    __host__ __device__ constexpr unsigned GetElementSize() const
+    __host__ __device__ constexpr index_t GetElementSize() const
     {
         // c++14 doesn't support constexpr lambdas, has to use this trick instead
         struct multiply
         {
-            __host__ __device__ constexpr unsigned operator()(unsigned a, unsigned b) const
+            __host__ __device__ constexpr index_t operator()(index_t a, index_t b) const
             {
                 return a * b;
             }
...
@@ -119,19 +119,19 @@ struct ConstantTensorDescriptor
     struct GetElementSpace_f
     {
         template <class IDim>
-        __host__ __device__ constexpr unsigned operator()(IDim idim) const
+        __host__ __device__ constexpr index_t operator()(IDim idim) const
         {
             return (Type{}.GetLength(idim) - 1) * Type{}.GetStride(idim);
         }
     };
     template <class Align = Number<1>>
-    __host__ __device__ constexpr unsigned GetElementSpace(Align align = Align{}) const
+    __host__ __device__ constexpr index_t GetElementSpace(Align align = Align{}) const
     {
         // c++14 doesn't support constexpr lambdas, has to use this trick instead
         struct add
         {
-            __host__ __device__ constexpr unsigned operator()(unsigned a, unsigned b) const
+            __host__ __device__ constexpr index_t operator()(index_t a, index_t b) const
             {
                 return a + b;
             }
...
@@ -141,17 +141,21 @@ struct ConstantTensorDescriptor
     }
     template <class... Is>
-    __host__ __device__ unsigned Get1dIndex(Is... is) const
+    __host__ __device__ index_t Get1dIndex(Is... is) const
     {
         static_assert(sizeof...(Is) == nDim, "number of multi-index is wrong");
-        const auto multi_id = Array<unsigned, nDim>(is...);
+        const auto multi_id = Array<index_t, nDim>(is...);
-        unsigned id = 0;
+        index_t id = 0;
         static_loop_n<nDim>{}([&](auto IDim) {
-            constexpr unsigned idim = IDim.Get();
+            constexpr index_t idim = IDim.Get();
+#if DEVICE_BACKEND_HIP
+            id += __mul24(multi_id[idim], GetStride(IDim));
+#else
             id += multi_id[idim] * GetStride(IDim);
+#endif
         });
         return id;
...
@@ -163,7 +167,7 @@ struct ConstantTensorDescriptor
         return ConstantTensorDescriptor<Lengths, decltype(default_strides)>{};
     }
-    template <unsigned IDim, unsigned NVector>
+    template <index_t IDim, index_t NVector>
     __host__ __device__ constexpr auto Vectorize(Number<IDim>, Number<NVector>) const
     {
         assert(false); // not implemented
...
@@ -183,7 +187,7 @@ __host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Stride
     return ConstantTensorDescriptor<Lengths, Strides>{};
 }
-template <class Lengths, unsigned Align>
+template <class Lengths, index_t Align>
 __host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths, Number<Align>)
 {
     using Strides = decltype(calculate_default_strides_aligned(Lengths{}, Number<Align>{}));
...
@@ -194,7 +198,7 @@ template <class TDesc>
 __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
 {
     constexpr auto desc = TDesc{};
-    constexpr unsigned ndim = desc.GetDimension();
+    constexpr index_t ndim = desc.GetDimension();
     static_assert(ndim >= 2 && ndim <= 8, "wrong!");
...
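What Get1dIndex above computes is the dot product of the multi-index with the descriptor's strides. A standalone restatement, outside the commit and with index_t assumed unsigned, using the packed strides of a 2x3x4x5 tensor (60, 20, 5, 1):

    #include <cstdio>
    using index_t = unsigned; // assumption
    int main()
    {
        constexpr index_t stride[4] = {60, 20, 5, 1};
        constexpr index_t id[4]     = {1, 2, 3, 4};
        index_t offset = 0;
        for(index_t d = 0; d < 4; ++d)
            offset += id[d] * stride[d]; // 60 + 40 + 15 + 4 = 119
        std::printf("offset = %u\n", offset);
    }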
src/include/Sequence.hip.hpp  View file @ 766b0a9e
...
@@ -2,38 +2,38 @@
 #include "constant_integral.hip.hpp"
 #include "functional.hip.hpp"
-template <unsigned... Is>
+template <index_t... Is>
 struct Sequence
 {
     using Type = Sequence<Is...>;
-    static constexpr unsigned nDim = sizeof...(Is);
+    static constexpr index_t nDim = sizeof...(Is);
-    const unsigned mData[nDim] = {Is...};
+    const index_t mData[nDim] = {Is...};
-    template <unsigned I>
-    __host__ __device__ constexpr unsigned Get(Number<I>) const
+    template <index_t I>
+    __host__ __device__ constexpr index_t Get(Number<I>) const
     {
         return mData[I];
     }
     // this is ugly, only for nDIm = 4
-    template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
+    template <index_t I0, index_t I1, index_t I2, index_t I3>
     __host__ __device__ constexpr auto ReorderByGetNewFromOld(Sequence<I0, I1, I2, I3>) const
     {
         static_assert(nDim == 4, "nDim != 4");
         constexpr auto old_sequence = Type{};
-        constexpr unsigned NR0 = old_sequence.mData[I0];
-        constexpr unsigned NR1 = old_sequence.mData[I1];
-        constexpr unsigned NR2 = old_sequence.mData[I2];
-        constexpr unsigned NR3 = old_sequence.mData[I3];
+        constexpr index_t NR0 = old_sequence.mData[I0];
+        constexpr index_t NR1 = old_sequence.mData[I1];
+        constexpr index_t NR2 = old_sequence.mData[I2];
+        constexpr index_t NR3 = old_sequence.mData[I3];
         return Sequence<NR0, NR1, NR2, NR3>{};
     }
-    template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
+    template <index_t I0, index_t I1, index_t I2, index_t I3>
     __host__ __device__ constexpr auto ReorderByPutOldToNew(Sequence<I0, I1, I2, I3>) const
     {
         // don't know how to implement this
...
@@ -41,7 +41,7 @@ struct Sequence
         assert(false);
     }
-    template <unsigned I>
+    template <index_t I>
     __host__ __device__ constexpr auto PushBack(Number<I>) const
     {
         return Sequence<Is..., I>{};
...
@@ -56,14 +56,14 @@ struct Sequence
     }
 };
-template <unsigned... Is, unsigned I>
+template <index_t... Is, index_t I>
 __host__ __device__ constexpr auto sequence_pop_back(Sequence<Is..., I>)
 {
     static_assert(sizeof...(Is) >= 1, "empty Sequence!");
     return Sequence<Is...>{};
 }
-template <class F, unsigned... Xs, unsigned... Ys>
+template <class F, index_t... Xs, index_t... Ys>
 __host__ __device__ constexpr auto sequence_sequence_op(Sequence<Xs...>, Sequence<Ys...>, F f)
 {
     static_assert(Sequence<Xs...>::nDim == Sequence<Ys...>::nDim, "Dim not the same");
...
@@ -71,12 +71,12 @@ __host__ __device__ constexpr auto sequence_sequence_op(Sequence<Xs...>, Sequenc
     return Sequence<f(Xs, Ys)...>{};
 }
-template <unsigned... Xs, unsigned... Ys>
+template <index_t... Xs, index_t... Ys>
 __host__ __device__ constexpr auto sequence_sequence_add(Sequence<Xs...>, Sequence<Ys...>)
 {
     struct add
     {
-        __host__ __device__ constexpr unsigned operator()(unsigned x, unsigned y) const
+        __host__ __device__ constexpr index_t operator()(index_t x, index_t y) const
        {
            return x + y;
        }
...
@@ -85,7 +85,7 @@ __host__ __device__ constexpr auto sequence_sequence_add(Sequence<Xs...>, Sequen
     return sequence_sequence_op(Sequence<Xs...>{}, Sequence<Ys...>{}, add{});
 }
-template <unsigned... Is>
+template <index_t... Is>
 __host__ __device__ constexpr auto Sequence<Is...>::PopBack() const
 {
     return sequence_pop_back(Type{});
...
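sequence_sequence_add above builds a new Sequence whose elements are the pairwise sums, so Sequence<1, 2, 3, 4> plus Sequence<10, 20, 30, 40> is Sequence<11, 22, 33, 44>. A plain constexpr restatement of that arithmetic (a sketch outside the commit, index_t assumed unsigned):

    using index_t = unsigned; // assumption
    constexpr index_t xs[4] = {1, 2, 3, 4};
    constexpr index_t ys[4] = {10, 20, 30, 40};
    static_assert(xs[0] + ys[0] == 11 && xs[3] + ys[3] == 44, "element-wise add");
    int main() {}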
src/include/blockwise_2d_tensor_op.hip.hpp  View file @ 766b0a9e
 #pragma once
 #include "ConstantTensorDescriptor.hip.hpp"
-template <unsigned BlockSize, class Float, class DstDesc, class F>
+template <index_t BlockSize, class Float, class DstDesc, class F>
 __device__ void
 blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f)
 {
...
@@ -20,19 +20,19 @@ blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
     }
 #endif
-    constexpr unsigned NLoop = desc.GetElementSize() / BlockSize;
+    constexpr index_t NLoop = desc.GetElementSize() / BlockSize;
-    for(unsigned iloop = 0; iloop < NLoop; ++iloop)
+    for(index_t iloop = 0; iloop < NLoop; ++iloop)
     {
-        unsigned is = threadIdx.x + iloop * BlockSize;
+        index_t is = threadIdx.x + iloop * BlockSize;
-        const unsigned did0 = is / desc.GetStride(I0);
+        const index_t did0 = is / desc.GetStride(I0);
         is -= did0 * desc.GetStride(I0);
-        const unsigned did1 = is / desc.GetStride(I1);
+        const index_t did1 = is / desc.GetStride(I1);
-        const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
+        const index_t dindex = dst_desc.Get1dIndex(did0, did1);
         f(p_dst[dindex]);
     }
...
@@ -41,17 +41,17 @@ blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
     if(has_tail)
     {
-        unsigned is = threadIdx.x + NLoop * BlockSize;
+        index_t is = threadIdx.x + NLoop * BlockSize;
         if(is < desc.GetElementSize())
         {
-            const unsigned did0 = is / desc.GetStride(I0);
+            const index_t did0 = is / desc.GetStride(I0);
             is -= did0 * desc.GetStride(I0);
-            const unsigned did1 = is / desc.GetStride(I1);
+            const index_t did1 = is / desc.GetStride(I1);
-            const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
+            const index_t dindex = dst_desc.Get1dIndex(did0, did1);
             f(p_dst[dindex]);
         }
...
@@ -61,7 +61,7 @@ blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
 // Function: p_dst[reorder[i0], reorder[i1], reorder[i2], reorder[i3]] = p_src[i0,i1,i2,i3]
 // TODO: in order to optimize mem access for different mem type,
 // need to write specialized version
-template <unsigned BlockSize,
+template <index_t BlockSize,
           class Float,
           class SrcDesc,
           class DstDesc,
...
@@ -80,20 +80,20 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds
     constexpr auto I0 = Number<0>{};
     constexpr auto I1 = Number<1>{};
-    constexpr unsigned IR0 = DstFromSrcReorder{}.Get(I0);
-    constexpr unsigned IR1 = DstFromSrcReorder{}.Get(I1);
+    constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0);
+    constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1);
     constexpr auto src_desc = SrcDesc{};
     constexpr auto dst_desc = DstDesc{};
     constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
-    constexpr unsigned NLoop = ref_desc.GetElementSize() / BlockSize;
+    constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
-    for(unsigned iloop = 0; iloop < NLoop; ++iloop)
+    for(index_t iloop = 0; iloop < NLoop; ++iloop)
     {
-        unsigned is = threadIdx.x + iloop * BlockSize;
+        index_t is = threadIdx.x + iloop * BlockSize;
-        unsigned did[2];
+        index_t did[2];
         did[0] = is / ref_desc.GetStride(I0);
...
@@ -101,9 +101,9 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds
         did[1] = is / ref_desc.GetStride(I1);
-        const unsigned aindex = src_desc.Get1dIndex(did[0], did[1]);
+        const index_t aindex = src_desc.Get1dIndex(did[0], did[1]);
-        const unsigned bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
+        const index_t bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
         f(p_src[aindex], p_dst[bindex]);
     }
...
@@ -112,11 +112,11 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds
     if(has_tail)
     {
-        unsigned is = threadIdx.x + NLoop * BlockSize;
+        index_t is = threadIdx.x + NLoop * BlockSize;
         if(is < ref_desc.GetElementSize())
         {
-            unsigned did[2];
+            index_t did[2];
             did[0] = is / ref_desc.GetStride(I0);
...
@@ -124,16 +124,16 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds
             did[1] = is / ref_desc.GetStride(I1);
-            const unsigned aindex = src_desc.Get1dIndex(did[0], did[1]);
+            const index_t aindex = src_desc.Get1dIndex(did[0], did[1]);
-            const unsigned bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
+            const index_t bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
             f(p_src[aindex], p_dst[bindex]);
         }
     }
 }
-template <unsigned BlockSize, class Float, class DstDesc>
+template <index_t BlockSize, class Float, class DstDesc>
 __device__ void blockwise_2d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
 {
     auto f_set_zero = [](Float& v) { v = Float(0); };
...
@@ -141,7 +141,7 @@ __device__ void blockwise_2d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
     blockwise_2d_tensor_pointwise_operation_unary<BlockSize>(DstDesc{}, p_dst, f_set_zero);
 }
-template <unsigned BlockSize,
+template <index_t BlockSize,
           class Float,
           class SrcDesc,
           class DstDesc,
...
@@ -161,7 +161,7 @@ blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
         SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
 }
-template <unsigned BlockSize, class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
+template <index_t BlockSize, class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
 struct Blockwise2dTensorCopy1
 {
     __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
...
@@ -175,17 +175,17 @@ struct Blockwise2dTensorCopy1
 // need to be aligned to float4 and float2
 // stride1 need to be 1 for both source and destination
-template <unsigned BlockSize,
+template <index_t BlockSize,
           class Float,
           class SrcDesc,
           class DstDesc,
           class SrcOpLengths,
-          unsigned ThreadPerDim0,
-          unsigned ThreadPerDim1>
+          index_t ThreadPerDim0,
+          index_t ThreadPerDim1>
 struct Blockwise2dTensorCopy2
 {
-    unsigned mThreadId0;
-    unsigned mThreadId1;
+    index_t mThreadId0;
+    index_t mThreadId1;
     __device__ Blockwise2dTensorCopy2()
     {
...
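In the blockwise pointwise loops above, each thread turns its flat element index "is" into a 2-d coordinate by dividing by the reference strides. A standalone check of that decomposition, outside the commit and with index_t assumed unsigned, for a 4x32 reference tile (strides 32, 1) and is = 70:

    #include <cstdio>
    using index_t = unsigned; // assumption
    int main()
    {
        constexpr index_t stride0 = 32, stride1 = 1;
        index_t is = 70;
        const index_t did0 = is / stride0; // 2
        is -= did0 * stride0;
        const index_t did1 = is / stride1; // 6
        std::printf("did0 = %u, did1 = %u\n", did0, did1);
    }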
@@ -222,61 +222,61 @@ struct Blockwise2dTensorCopy2
         constexpr bool align_v2 =
             src_desc.GetStride(I0) % 2 == 0 && dst_desc.GetStride(I0) % 2 == 0;
-        constexpr unsigned L0 = SrcOpLengths{}.Get(I0);
-        constexpr unsigned L1 = SrcOpLengths{}.Get(I1);
+        constexpr index_t L0 = SrcOpLengths{}.Get(I0);
+        constexpr index_t L1 = SrcOpLengths{}.Get(I1);
-        constexpr unsigned Dim0Loop = L0 / ThreadPerDim0;
+        constexpr index_t Dim0Loop = L0 / ThreadPerDim0;
         constexpr bool d0_has_tail = (L0 > ThreadPerDim0 * Dim0Loop);
-        constexpr unsigned Dim1V4Loop = align_v4 ? L1 / (ThreadPerDim1 * 4) : 0;
+        constexpr index_t Dim1V4Loop = align_v4 ? L1 / (ThreadPerDim1 * 4) : 0;
-        constexpr unsigned Dim1V2Loop =
+        constexpr index_t Dim1V2Loop =
             align_v2 ? (L1 - Dim1V4Loop * (ThreadPerDim1 * 4)) / (ThreadPerDim1 * 2) : 0;
-        constexpr unsigned Dim1V1Loop =
+        constexpr index_t Dim1V1Loop =
             (L1 - Dim1V4Loop * (ThreadPerDim1 * 4) - Dim1V2Loop * (ThreadPerDim1 * 2)) / ThreadPerDim1;
         constexpr bool d1_has_tail =
             (L1 > ThreadPerDim1 * (4 * Dim1V4Loop + 2 * Dim1V2Loop + Dim1V1Loop));
-        for(unsigned d0loop = 0; d0loop < Dim0Loop; ++d0loop)
+        for(index_t d0loop = 0; d0loop < Dim0Loop; ++d0loop)
         {
-            unsigned did0 = d0loop * ThreadPerDim0 + mThreadId0;
+            index_t did0 = d0loop * ThreadPerDim0 + mThreadId0;
             // v4
-            for(unsigned d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
+            for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
             {
-                unsigned did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
+                index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
-                const unsigned sindex = src_desc.Get1dIndex(did0, did1);
-                const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
+                const index_t sindex = src_desc.Get1dIndex(did0, did1);
+                const index_t dindex = dst_desc.Get1dIndex(did0, did1);
                 *(reinterpret_cast<Float4*>(p_dst + dindex)) =
                     *(reinterpret_cast<const Float4*>(p_src + sindex));
             }
             // v2
-            for(unsigned d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
+            for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
             {
-                unsigned did1 =
+                index_t did1 =
                     Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 + 2 * mThreadId1;
-                const unsigned sindex = src_desc.Get1dIndex(did0, did1);
-                const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
+                const index_t sindex = src_desc.Get1dIndex(did0, did1);
+                const index_t dindex = dst_desc.Get1dIndex(did0, did1);
                 *(reinterpret_cast<Float2*>(p_dst + dindex)) =
                     *(reinterpret_cast<const Float2*>(p_src + sindex));
             }
             // v1
-            for(unsigned d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
+            for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
             {
-                unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
+                index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                     d1v1loop * ThreadPerDim1 + mThreadId1;
-                const unsigned sindex = src_desc.Get1dIndex(did0, did1);
-                const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
+                const index_t sindex = src_desc.Get1dIndex(did0, did1);
+                const index_t dindex = dst_desc.Get1dIndex(did0, did1);
                 p_dst[dindex] = p_src[sindex];
             }
...
@@ -284,13 +284,13 @@ struct Blockwise2dTensorCopy2
             // dim-1 tail
             if(d1_has_tail)
             {
-                unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
+                index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                     Dim1V1Loop * ThreadPerDim1 + mThreadId1;
                 if(did1 < L1)
                 {
-                    const unsigned sindex = src_desc.Get1dIndex(did0, did1);
-                    const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
+                    const index_t sindex = src_desc.Get1dIndex(did0, did1);
+                    const index_t dindex = dst_desc.Get1dIndex(did0, did1);
                     p_dst[dindex] = p_src[sindex];
                 }
...
@@ -300,45 +300,44 @@ struct Blockwise2dTensorCopy2
         // dim-0 tail
         if(d0_has_tail)
         {
-            unsigned did0 = Dim0Loop * ThreadPerDim0 + mThreadId0;
+            index_t did0 = Dim0Loop * ThreadPerDim0 + mThreadId0;
             if(did0 < L0)
             {
                 // v4
-                for(unsigned d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
+                for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
                 {
-                    unsigned did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
+                    index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
-                    const unsigned sindex = src_desc.Get1dIndex(did0, did1);
-                    const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
+                    const index_t sindex = src_desc.Get1dIndex(did0, did1);
+                    const index_t dindex = dst_desc.Get1dIndex(did0, did1);
                     *(reinterpret_cast<Float4*>(p_dst + dindex)) =
                         *(reinterpret_cast<const Float4*>(p_src + sindex));
                 }
                 // v2
-                for(unsigned d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
+                for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
                 {
-                    unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 +
+                    index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 +
                         2 * mThreadId1;
-                    const unsigned sindex = src_desc.Get1dIndex(did0, did1);
-                    const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
+                    const index_t sindex = src_desc.Get1dIndex(did0, did1);
+                    const index_t dindex = dst_desc.Get1dIndex(did0, did1);
                     *(reinterpret_cast<Float2*>(p_dst + dindex)) =
                         *(reinterpret_cast<const Float2*>(p_src + sindex));
                 }
                 // v1
-                for(unsigned d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
+                for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
                 {
-                    unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
-                        d1v1loop * ThreadPerDim1 + mThreadId1;
+                    index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
+                        d1v1loop * ThreadPerDim1 + mThreadId1;
-                    const unsigned sindex = src_desc.Get1dIndex(did0, did1);
-                    const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
+                    const index_t sindex = src_desc.Get1dIndex(did0, did1);
+                    const index_t dindex = dst_desc.Get1dIndex(did0, did1);
                     p_dst[dindex] = p_src[sindex];
                 }
...
@@ -346,14 +345,13 @@ struct Blockwise2dTensorCopy2
             // tail
             if(d1_has_tail)
             {
-                unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
-                    Dim1V1Loop * ThreadPerDim1 + mThreadId1;
+                index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
+                    Dim1V1Loop * ThreadPerDim1 + mThreadId1;
                 if(did1 < L1)
                 {
-                    const unsigned sindex = src_desc.Get1dIndex(did0, did1);
-                    const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
+                    const index_t sindex = src_desc.Get1dIndex(did0, did1);
+                    const index_t dindex = dst_desc.Get1dIndex(did0, did1);
                     p_dst[dindex] = p_src[sindex];
                 }
...
@@ -365,18 +363,18 @@ struct Blockwise2dTensorCopy2
 // starting point need to be aligned to float4 or float2 or float
 // stride1 need to be 1 for both source and destination
-template <unsigned BlockSize,
+template <index_t BlockSize,
           class Float,
           class SrcDesc,
           class DstDesc,
           class CopyLengths,
-          unsigned DataPerRead>
+          index_t DataPerRead>
 struct Blockwise2dTensorCopy3
 {
     using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
-    unsigned mSrcMyThreadOffset;
-    unsigned mDstMyThreadOffset;
+    index_t mSrcMyThreadOffset;
+    index_t mDstMyThreadOffset;
     __device__ Blockwise2dTensorCopy3()
     {
...
@@ -394,11 +392,11 @@ struct Blockwise2dTensorCopy3
             DstDesc{}.GetStride(I0) % DataPerRead == 0,
             "src and dst stride should be multiple of DataPerRead to keep alignment");
-        constexpr unsigned L0 = CopyLengths{}.Get(I0);
-        constexpr unsigned L1 = CopyLengths{}.Get(I1);
+        constexpr index_t L0 = CopyLengths{}.Get(I0);
+        constexpr index_t L1 = CopyLengths{}.Get(I1);
-        constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
-        constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
+        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
+        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
         // we allow out-of-bound read from src in D1 dimension,
         // but we need to make sure dst stride is big enough,
...
@@ -408,7 +406,7 @@ struct Blockwise2dTensorCopy3
         static_assert(thread_per_d0 >= 1, "wrong! not enough threads to cover one line\n");
-        constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1;
+        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
         if(BlockSize > num_active_thread)
         {
...
@@ -418,8 +416,8 @@ struct Blockwise2dTensorCopy3
             }
         }
-        const unsigned thread_id_d0 = get_thread_local_1d_id() / thread_per_d1;
-        const unsigned thread_id_d1 = get_thread_local_1d_id() - thread_id_d0 * thread_per_d1;
+        const index_t thread_id_d0 = get_thread_local_1d_id() / thread_per_d1;
+        const index_t thread_id_d1 = get_thread_local_1d_id() - thread_id_d0 * thread_per_d1;
         mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(thread_id_d0, thread_id_d1 * DataPerRead);
         mDstMyThreadOffset = DstDesc{}.Get1dIndex(thread_id_d0, thread_id_d1 * DataPerRead);
...
@@ -430,13 +428,13 @@ struct Blockwise2dTensorCopy3
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
-        constexpr unsigned L0 = CopyLengths{}.Get(I0);
-        constexpr unsigned L1 = CopyLengths{}.Get(I1);
+        constexpr index_t L0 = CopyLengths{}.Get(I0);
+        constexpr index_t L1 = CopyLengths{}.Get(I1);
-        constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
-        constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
+        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
+        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
-        constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1;
+        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
         if(BlockSize > num_active_thread)
         {
...
@@ -446,18 +444,18 @@ struct Blockwise2dTensorCopy3
             }
         }
-        constexpr unsigned nloop_d0 = L0 / thread_per_d0;
+        constexpr index_t nloop_d0 = L0 / thread_per_d0;
-        constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
-        constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
+        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
+        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
-        auto f_copy = [&](unsigned iloop) {
+        auto f_copy = [&](index_t iloop) {
             *(reinterpret_cast<vector_t*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
                 *(reinterpret_cast<const vector_t*>(p_src + mSrcMyThreadOffset + iloop * src_loop_stride));
         };
-        for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
+        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
         {
             f_copy(iloop);
         }
...
@@ -466,7 +464,7 @@ struct Blockwise2dTensorCopy3
         if(has_tail_d0)
         {
-            constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0;
+            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;
             if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
             {
...
@@ -475,18 +473,18 @@ struct Blockwise2dTensorCopy3
         }
     }
-    __device__ constexpr unsigned GetRegisterClipboardSize() const
+    __device__ constexpr index_t GetRegisterClipboardSize() const
     {
         static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
-        constexpr unsigned L0 = CopyLengths{}.Get(I0);
-        constexpr unsigned L1 = CopyLengths{}.Get(I1);
+        constexpr index_t L0 = CopyLengths{}.Get(I0);
+        constexpr index_t L1 = CopyLengths{}.Get(I1);
-        constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
-        constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
+        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
+        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
         return DataPerRead * (L0 + thread_per_d0 - 1) / thread_per_d0;
     }
...
@@ -497,13 +495,13 @@ struct Blockwise2dTensorCopy3
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
-        constexpr unsigned L0 = CopyLengths{}.Get(I0);
-        constexpr unsigned L1 = CopyLengths{}.Get(I1);
+        constexpr index_t L0 = CopyLengths{}.Get(I0);
+        constexpr index_t L1 = CopyLengths{}.Get(I1);
-        constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
-        constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
+        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
+        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
-        constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1;
+        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
         if(BlockSize > num_active_thread)
         {
...
@@ -513,18 +511,18 @@ struct Blockwise2dTensorCopy3
             }
         }
-        constexpr unsigned nloop_d0 = L0 / thread_per_d0;
+        constexpr index_t nloop_d0 = L0 / thread_per_d0;
-        constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
-        constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
+        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
+        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
-        auto f_copy = [&](unsigned iloop) {
+        auto f_copy = [&](index_t iloop) {
             *(reinterpret_cast<vector_t*>(p_clipboard + iloop * 4)) =
                 *(reinterpret_cast<const vector_t*>(p_src + mSrcMyThreadOffset + iloop * src_loop_stride));
         };
-        for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
+        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
         {
             f_copy(iloop);
         }
...
@@ -533,7 +531,7 @@ struct Blockwise2dTensorCopy3
         if(has_tail_d0)
        {
-            constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0;
+            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;
            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
...
@@ -548,13 +546,13 @@ struct Blockwise2dTensorCopy3
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
-        constexpr unsigned L0 = CopyLengths{}.Get(I0);
-        constexpr unsigned L1 = CopyLengths{}.Get(I1);
+        constexpr index_t L0 = CopyLengths{}.Get(I0);
+        constexpr index_t L1 = CopyLengths{}.Get(I1);
-        constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
-        constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
+        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
+        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
-        constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1;
+        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
         if(BlockSize > num_active_thread)
         {
...
@@ -564,17 +562,17 @@ struct Blockwise2dTensorCopy3
             }
         }
-        constexpr unsigned nloop_d0 = L0 / thread_per_d0;
+        constexpr index_t nloop_d0 = L0 / thread_per_d0;
-        constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
-        constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
+        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
+        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
-        auto f_copy = [&](unsigned iloop) {
+        auto f_copy = [&](index_t iloop) {
             *(reinterpret_cast<vector_t*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
                 *(reinterpret_cast<const vector_t*>(p_clipboard + iloop * 4));
         };
-        for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
+        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
         {
             f_copy(iloop);
         }
...
@@ -583,7 +581,7 @@ struct Blockwise2dTensorCopy3
         if(has_tail_d0)
         {
-            constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0;
+            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;
             if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
             {
...
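The thread partition used by Blockwise2dTensorCopy3 above splits a row into vector reads and then divides the block's threads between rows. A worked example on assumed numbers (BlockSize = 128, copy lengths 32 x 34, DataPerRead = 4), outside the commit and with index_t assumed unsigned:

    #include <cstdio>
    using index_t = unsigned; // assumption
    int main()
    {
        constexpr index_t BlockSize = 128, L0 = 32, L1 = 34, DataPerRead = 4;
        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead; // 9
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;            // 14
        constexpr index_t nloop_d0      = L0 / thread_per_d0;                   // 2
        constexpr index_t tail_d0       = L0 - nloop_d0 * thread_per_d0;        // 4
        std::printf("%u threads per row, %u rows per pass, %u full passes, tail %u\n",
                    thread_per_d1, thread_per_d0, nloop_d0, tail_d0);
    }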
src/include/blockwise_4d_tensor_op.hip.hpp
View file @
766b0a9e
#pragma once
#include "ConstantTensorDescriptor.hip.hpp"
template
<
unsigned
BlockSize
,
class
Float
,
class
DstDesc
,
class
F
>
template
<
index_t
BlockSize
,
class
Float
,
class
DstDesc
,
class
F
>
__device__
void
blockwise_4d_tensor_pointwise_operation_unary
(
DstDesc
,
Float
*
__restrict__
p_dst
,
F
f
)
{
...
...
@@ -22,27 +22,27 @@ blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
}
#endif
constexpr
unsigned
NLoop
=
desc
.
GetElementSize
()
/
BlockSize
;
constexpr
index_t
NLoop
=
desc
.
GetElementSize
()
/
BlockSize
;
for
(
unsigned
iloop
=
0
;
iloop
<
NLoop
;
++
iloop
)
for
(
index_t
iloop
=
0
;
iloop
<
NLoop
;
++
iloop
)
{
unsigned
is
=
threadIdx
.
x
+
iloop
*
BlockSize
;
index_t
is
=
threadIdx
.
x
+
iloop
*
BlockSize
;
const
unsigned
did0
=
is
/
desc
.
GetStride
(
I0
);
const
index_t
did0
=
is
/
desc
.
GetStride
(
I0
);
is
-=
did0
*
desc
.
GetStride
(
I0
);
const
unsigned
did1
=
is
/
desc
.
GetStride
(
I1
);
const
index_t
did1
=
is
/
desc
.
GetStride
(
I1
);
is
-=
did1
*
desc
.
GetStride
(
I1
);
const
unsigned
did2
=
is
/
desc
.
GetStride
(
I2
);
const
index_t
did2
=
is
/
desc
.
GetStride
(
I2
);
is
-=
did2
*
desc
.
GetStride
(
I2
);
const
unsigned
did3
=
is
/
desc
.
GetStride
(
I3
);
const
index_t
did3
=
is
/
desc
.
GetStride
(
I3
);
const
unsigned
dindex
=
dst_desc
.
Get1dIndex
(
did0
,
did1
,
did2
,
did3
);
const
index_t
dindex
=
dst_desc
.
Get1dIndex
(
did0
,
did1
,
did2
,
did3
);
f
(
p_dst
[
dindex
]);
}
...
...
@@ -51,25 +51,25 @@ blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
if
(
has_tail
)
{
unsigned
is
=
threadIdx
.
x
+
NLoop
*
BlockSize
;
index_t
is
=
threadIdx
.
x
+
NLoop
*
BlockSize
;
if
(
is
<
desc
.
GetElementSize
())
{
const
unsigned
did0
=
is
/
desc
.
GetStride
(
I0
);
const
index_t
did0
=
is
/
desc
.
GetStride
(
I0
);
is
-=
did0
*
desc
.
GetStride
(
I0
);
const
unsigned
did1
=
is
/
desc
.
GetStride
(
I1
);
const
index_t
did1
=
is
/
desc
.
GetStride
(
I1
);
is
-=
did1
*
desc
.
GetStride
(
I1
);
const
unsigned
did2
=
is
/
desc
.
GetStride
(
I2
);
const
index_t
did2
=
is
/
desc
.
GetStride
(
I2
);
is
-=
did2
*
desc
.
GetStride
(
I2
);
const
unsigned
did3
=
is
/
desc
.
GetStride
(
I3
);
const
index_t
did3
=
is
/
desc
.
GetStride
(
I3
);
const
unsigned
dindex
=
dst_desc
.
Get1dIndex
(
did0
,
did1
,
did2
,
did3
);
const
index_t
dindex
=
dst_desc
.
Get1dIndex
(
did0
,
did1
,
did2
,
did3
);
f
(
p_dst
[
dindex
]);
}
...
...
@@ -79,7 +79,7 @@ blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
// Function: p_dst[reorder[i0], reorder[i1], reorder[i2], reorder[i3]] = p_src[i0,i1,i2,i3]
// TODO: in order to optimize mem access for different mem type,
// need to write specialized version
template
<
unsigned
BlockSize
,
template
<
index_t
BlockSize
,
class
Float
,
class
SrcDesc
,
class
DstDesc
,
...
...
@@ -100,22 +100,22 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
unsigned
IR0
=
DstFromSrcReorder
{}.
Get
(
I0
);
constexpr
unsigned
IR1
=
DstFromSrcReorder
{}.
Get
(
I1
);
constexpr
unsigned
IR2
=
DstFromSrcReorder
{}.
Get
(
I2
);
constexpr
unsigned
IR3
=
DstFromSrcReorder
{}.
Get
(
I3
);
constexpr
index_t
IR0
=
DstFromSrcReorder
{}.
Get
(
I0
);
constexpr
index_t
IR1
=
DstFromSrcReorder
{}.
Get
(
I1
);
constexpr
index_t
IR2
=
DstFromSrcReorder
{}.
Get
(
I2
);
constexpr
index_t
IR3
=
DstFromSrcReorder
{}.
Get
(
I3
);
constexpr
auto
src_desc
=
SrcDesc
{};
constexpr
auto
dst_desc
=
DstDesc
{};
constexpr
auto
ref_desc
=
make_ConstantTensorDescriptor
(
SrcOpLengths
{});
constexpr
unsigned
NLoop
=
ref_desc
.
GetElementSize
()
/
BlockSize
;
constexpr
index_t
NLoop
=
ref_desc
.
GetElementSize
()
/
BlockSize
;
for
(
unsigned
iloop
=
0
;
iloop
<
NLoop
;
++
iloop
)
for
(
index_t
iloop
=
0
;
iloop
<
NLoop
;
++
iloop
)
{
unsigned
is
=
threadIdx
.
x
+
iloop
*
BlockSize
;
index_t
is
=
threadIdx
.
x
+
iloop
*
BlockSize
;
unsigned
did
[
4
];
index_t
did
[
4
];
did
[
0
]
=
is
/
ref_desc
.
GetStride
(
I0
);
...
...
@@ -131,9 +131,9 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds
did
[
3
]
=
is
/
ref_desc
.
GetStride
(
I3
);
const
unsigned
src_index
=
src_desc
.
Get1dIndex
(
did
[
0
],
did
[
1
],
did
[
2
],
did
[
3
]);
const
index_t
src_index
=
src_desc
.
Get1dIndex
(
did
[
0
],
did
[
1
],
did
[
2
],
did
[
3
]);
const
unsigned
dst_index
=
dst_desc
.
Get1dIndex
(
did
[
IR0
],
did
[
IR1
],
did
[
IR2
],
did
[
IR3
]);
const
index_t
dst_index
=
dst_desc
.
Get1dIndex
(
did
[
IR0
],
did
[
IR1
],
did
[
IR2
],
did
[
IR3
]);
f
(
p_src
[
src_index
],
p_dst
[
dst_index
]);
}
...
...
@@ -142,11 +142,11 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds
if
(
has_tail
)
{
unsigned
is
=
threadIdx
.
x
+
NLoop
*
BlockSize
;
index_t
is
=
threadIdx
.
x
+
NLoop
*
BlockSize
;
if
(
is
<
ref_desc
.
GetElementSize
())
{
unsigned
did
[
4
];
index_t
did
[
4
];
did
[
0
]
=
is
/
ref_desc
.
GetStride
(
I0
);
...
...
@@ -162,16 +162,16 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds
did
[
3
]
=
is
/
ref_desc
.
GetStride
(
I3
);
const
unsigned
src_index
=
src_desc
.
Get1dIndex
(
did
[
0
],
did
[
1
],
did
[
2
],
did
[
3
]);
const
index_t
src_index
=
src_desc
.
Get1dIndex
(
did
[
0
],
did
[
1
],
did
[
2
],
did
[
3
]);
const
unsigned
dst_index
=
dst_desc
.
Get1dIndex
(
did
[
IR0
],
did
[
IR1
],
did
[
IR2
],
did
[
IR3
]);
const
index_t
dst_index
=
dst_desc
.
Get1dIndex
(
did
[
IR0
],
did
[
IR1
],
did
[
IR2
],
did
[
IR3
]);
f
(
p_src
[
src_index
],
p_dst
[
dst_index
]);
}
}
}
template
<
unsigned
BlockSize
,
class
Float
,
class
DstDesc
>
template
<
index_t
BlockSize
,
class
Float
,
class
DstDesc
>
__device__
void
blockwise_4d_tensor_set_zero
(
DstDesc
,
Float
*
__restrict__
p_dst
)
{
auto
f_set_zero
=
[](
Float
&
v
)
{
v
=
Float
(
0
);
};
...
...
@@ -179,7 +179,7 @@ __device__ void blockwise_4d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
blockwise_4d_tensor_pointwise_operation_unary
<
BlockSize
>
(
DstDesc
{},
p_dst
,
f_set_zero
);
}
template
<
unsigned
BlockSize
,
template
<
index_t
BlockSize
,
class
Float
,
class
SrcDesc
,
class
DstDesc
,
...
...
@@ -199,12 +199,12 @@ blockwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
SrcDesc
{},
p_src
,
DstDesc
{},
p_dst
,
SrcOpLengths
{},
DstFromSrcReorder
{},
f_copy
);
}
template
<
unsigned
BlockSize
,
template
<
index_t
BlockSize
,
class
Float
,
class
SrcDesc
,
class
DstDesc
,
class
CopyLengths
,
unsigned
DataPerRead
>
index_t
DataPerRead
>
struct
Blockwise4dTensorCopy1
{
using
vector_t
=
typename
vector_type
<
Float
,
DataPerRead
>::
MemoryType
;
...
...
@@ -230,8 +230,8 @@ struct Blockwise4dTensorCopy1
// we allow out-of-bound read from src in D3 dimension,
// but we need to make sure dst stride2 is big enough,
// so that the out-of-bound write won't contaminate next line in dst
constexpr
unsigned
L3
=
CopyLengths
{}.
Get
(
I3
);
constexpr
unsigned
read_per_d3
=
integer_divide_ceil
(
L3
,
DataPerRead
);
constexpr
index_t
L3
=
CopyLengths
{}.
Get
(
I3
);
constexpr
index_t
read_per_d3
=
integer_divide_ceil
(
L3
,
DataPerRead
);
static_assert
(
read_per_d3
*
DataPerRead
<=
DstDesc
{}.
GetStride
(
I2
),
"wrong! out-of-bound write will contaminate next line!
\n
"
);
...
...
@@ -247,20 +247,20 @@ struct Blockwise4dTensorCopy1
constexpr
auto
src_desc
=
SrcDesc
{};
constexpr
auto
dst_desc
=
DstDesc
{};
constexpr
unsigned
L0
=
CopyLengths
{}.
Get
(
I0
);
constexpr
unsigned
L1
=
CopyLengths
{}.
Get
(
I1
);
constexpr
unsigned
L2
=
CopyLengths
{}.
Get
(
I2
);
constexpr
unsigned
L3
=
CopyLengths
{}.
Get
(
I3
);
constexpr
index_t
L0
=
CopyLengths
{}.
Get
(
I0
);
constexpr
index_t
L1
=
CopyLengths
{}.
Get
(
I1
);
constexpr
index_t
L2
=
CopyLengths
{}.
Get
(
I2
);
constexpr
index_t
L3
=
CopyLengths
{}.
Get
(
I3
);
constexpr
unsigned
read_per_d3
=
integer_divide_ceil
(
L3
,
DataPerRead
);
constexpr
index_t
read_per_d3
=
integer_divide_ceil
(
L3
,
DataPerRead
);
constexpr
auto
ref_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
L0
,
L1
,
L2
,
read_per_d3
>
{});
constexpr
unsigned
NLoop
=
ref_desc
.
GetElementSize
()
/
BlockSize
;
constexpr
index_t
NLoop
=
ref_desc
.
GetElementSize
()
/
BlockSize
;
auto
f_copy
=
[
&
](
unsigned
is
)
{
unsigned
did
[
4
];
auto
f_copy
=
[
&
](
index_t
is
)
{
index_t
did
[
4
];
did
[
0
]
=
is
/
ref_desc
.
GetStride
(
I0
);
...
...
@@ -276,18 +276,18 @@ struct Blockwise4dTensorCopy1
did
[
3
]
=
is
/
ref_desc
.
GetStride
(
I3
);
const
unsigned
src_index
=
const
index_t
src_index
=
src_desc
.
Get1dIndex
(
did
[
0
],
did
[
1
],
did
[
2
],
did
[
3
]
*
DataPerRead
);
const
unsigned
dst_index
=
const
index_t
dst_index
=
dst_desc
.
Get1dIndex
(
did
[
0
],
did
[
1
],
did
[
2
],
did
[
3
]
*
DataPerRead
);
*
(
reinterpret_cast
<
vector_t
*>
(
p_dst
+
dst_index
))
=
*
(
reinterpret_cast
<
const
vector_t
*>
(
p_src
+
src_index
));
};
for
(
unsigned
iloop
=
0
;
iloop
<
NLoop
;
++
iloop
)
for
(
index_t
iloop
=
0
;
iloop
<
NLoop
;
++
iloop
)
{
unsigned
is
=
threadIdx
.
x
+
iloop
*
BlockSize
;
index_t
is
=
threadIdx
.
x
+
iloop
*
BlockSize
;
f_copy
(
is
);
}
...
...
@@ -296,7 +296,7 @@ struct Blockwise4dTensorCopy1
if
(
has_tail
)
{
unsigned
is
=
threadIdx
.
x
+
NLoop
*
BlockSize
;
index_t
is
=
threadIdx
.
x
+
NLoop
*
BlockSize
;
if
(
is
<
ref_desc
.
GetElementSize
())
{
...
...
@@ -306,7 +306,7 @@ struct Blockwise4dTensorCopy1
}
};
template
<
unsigned
BlockSize
,
template
<
index_t
BlockSize
,
class
Float
,
class
SrcDesc
,
class
DstDesc
,
...
...
@@ -315,15 +315,15 @@ template <unsigned BlockSize,
struct
BlockwiseChwnTensorCopyPadded
{
__device__
void
Run
(
const
Float
*
__restrict__
p_src
,
unsigned
c_block_data_begin
,
unsigned
ho_block_data_begin
,
unsigned
wo_block_data_begin
,
unsigned
n_block_data_begin
,
index_t
c_block_data_begin
,
index_t
ho_block_data_begin
,
index_t
wo_block_data_begin
,
index_t
n_block_data_begin
,
Float
*
__restrict__
p_dst
,
unsigned
h_block_pad_low
,
unsigned
w_block_pad_low
,
unsigned
h_block_pad_up
,
unsigned
w_block_pad_up
)
const
index_t
h_block_pad_low
,
index_t
w_block_pad_low
,
index_t
h_block_pad_up
,
index_t
w_block_pad_up
)
const
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -337,7 +337,7 @@ struct BlockwiseChwnTensorCopyPadded
constexpr
auto
h_global_pad_low
=
GlobalLowerPads
{}.
Get
(
I0
);
constexpr
auto
w_global_pad_low
=
GlobalLowerPads
{}.
Get
(
I1
);
constexpr
unsigned
NLoop
=
ref_desc
.
GetElementSize
()
/
BlockSize
;
constexpr
index_t
NLoop
=
ref_desc
.
GetElementSize
()
/
BlockSize
;
const
Float
*
p_src_tmp
=
p_src
+
...
...
@@ -368,11 +368,11 @@ struct BlockwiseChwnTensorCopyPadded
}
#endif
for
(
unsigned
iloop
=
0
;
iloop
<
NLoop
;
++
iloop
)
for
(
index_t
iloop
=
0
;
iloop
<
NLoop
;
++
iloop
)
{
unsigned
is
=
threadIdx
.
x
+
iloop
*
BlockSize
;
index_t
is
=
threadIdx
.
x
+
iloop
*
BlockSize
;
unsigned
did
[
4
];
index_t
did
[
4
];
did
[
0
]
=
is
/
ref_desc
.
GetStride
(
I0
);
...
...
@@ -388,7 +388,7 @@ struct BlockwiseChwnTensorCopyPadded
did
[
3
]
=
is
/
ref_desc
.
GetStride
(
I3
);
const
unsigned
bindex
=
dst_desc
.
Get1dIndex
(
did
[
0
],
did
[
1
],
did
[
2
],
did
[
3
]);
const
index_t
bindex
=
dst_desc
.
Get1dIndex
(
did
[
0
],
did
[
1
],
did
[
2
],
did
[
3
]);
p_dst
[
bindex
]
=
(
did
[
1
]
<
h_block_pad_low
||
did
[
1
]
+
h_block_pad_up
>=
ref_desc
.
GetLength
(
I1
)
||
...
...
@@ -401,11 +401,11 @@ struct BlockwiseChwnTensorCopyPadded
         if(has_tail)
         {
-            unsigned is = threadIdx.x + NLoop * BlockSize;
+            index_t is = threadIdx.x + NLoop * BlockSize;

             if(is < ref_desc.GetElementSize())
             {
-                unsigned did[4];
+                index_t did[4];

                 did[0] = is / ref_desc.GetStride(I0);
@@ -421,7 +421,7 @@ struct BlockwiseChwnTensorCopyPadded
                 did[3] = is / ref_desc.GetStride(I3);

-                const unsigned bindex = dst_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
+                const index_t bindex = dst_desc.Get1dIndex(did[0], did[1], did[2], did[3]);

                 p_dst[bindex] = (did[1] < h_block_pad_low ||
@@ -436,19 +436,19 @@ struct BlockwiseChwnTensorCopyPadded
 // starting point need to be aligned to float4 or float2 or float
 // stride3 need to be 1 for both source and destination
-template <unsigned BlockSize,
+template <index_t BlockSize,
           class Float,
           class SrcDesc,
           class DstDesc,
           class CopyLengths,
           class ThreadPerDims,
-          unsigned DataPerRead>
+          index_t DataPerRead>
 struct Blockwise4dTensorCopy3
 {
     using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

-    unsigned mSrcMyThreadOffset;
-    unsigned mDstMyThreadOffset;
+    index_t mSrcMyThreadOffset;
+    index_t mDstMyThreadOffset;

     __device__ Blockwise4dTensorCopy3()
     {
@@ -469,20 +469,20 @@ struct Blockwise4dTensorCopy3
             DstDesc{}.GetStride(I2) % DataPerRead == 0,
             "wrong! src and dst stride2 should be multiple of DataPerRead to keep alignment");

-        constexpr unsigned L0 = CopyLengths{}.Get(I0);
-        constexpr unsigned L1 = CopyLengths{}.Get(I1);
-        constexpr unsigned L2 = CopyLengths{}.Get(I2);
-        constexpr unsigned L3 = CopyLengths{}.Get(I3);
+        constexpr index_t L0 = CopyLengths{}.Get(I0);
+        constexpr index_t L1 = CopyLengths{}.Get(I1);
+        constexpr index_t L2 = CopyLengths{}.Get(I2);
+        constexpr index_t L3 = CopyLengths{}.Get(I3);

-        constexpr unsigned thread_per_d0 = ThreadPerDims{}.Get(I0);
-        constexpr unsigned thread_per_d1 = ThreadPerDims{}.Get(I1);
-        constexpr unsigned thread_per_d2 = ThreadPerDims{}.Get(I2);
-        constexpr unsigned thread_per_d3 = ThreadPerDims{}.Get(I3);
+        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
+        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
+        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
+        constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);

         // we allow out-of-bound read from src in D3 dimension,
         // but we need to make sure dst stride is big enough,
         // so that the out-of-bound write won't contaminate next line in dst
-        constexpr unsigned nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
+        constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);

         static_assert(nloop_d3 * thread_per_d3 * DataPerRead <= DstDesc{}.GetStride(I2),
                       "wrong! out-of-bound write will contaminate next line!\n");
@@ -493,7 +493,7 @@ struct Blockwise4dTensorCopy3
         static_assert(BlockSize >= thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3,
                       "wrrong! BlockSize is not big enough for ThreadPerDims!");

-        constexpr unsigned num_active_thread =
+        constexpr index_t num_active_thread =
             thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;

         if(BlockSize > num_active_thread)
@@ -504,14 +504,14 @@ struct Blockwise4dTensorCopy3
             }
         }

-        const unsigned thread_id_d0 =
+        const index_t thread_id_d0 =
             get_thread_local_1d_id() / (thread_per_d1 * thread_per_d2 * thread_per_d3);
-        unsigned itmp = get_thread_local_1d_id() -
+        index_t itmp = get_thread_local_1d_id() -
                         thread_id_d0 * (thread_per_d1 * thread_per_d2 * thread_per_d3);
-        const unsigned thread_id_d1 = itmp / (thread_per_d2 * thread_per_d3);
+        const index_t thread_id_d1 = itmp / (thread_per_d2 * thread_per_d3);
         itmp -= thread_id_d1 * (thread_per_d2 * thread_per_d3);
-        const unsigned thread_id_d2 = itmp / thread_per_d3;
-        const unsigned thread_id_d3 = itmp - thread_id_d2 * thread_per_d3;
+        const index_t thread_id_d2 = itmp / thread_per_d3;
+        const index_t thread_id_d3 = itmp - thread_id_d2 * thread_per_d3;

         mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(
             thread_id_d0, thread_id_d1, thread_id_d2, thread_id_d3 * DataPerRead);
@@ -526,17 +526,17 @@ struct Blockwise4dTensorCopy3
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};

-        constexpr unsigned L0 = CopyLengths{}.Get(I0);
-        constexpr unsigned L1 = CopyLengths{}.Get(I1);
-        constexpr unsigned L2 = CopyLengths{}.Get(I2);
-        constexpr unsigned L3 = CopyLengths{}.Get(I3);
+        constexpr index_t L0 = CopyLengths{}.Get(I0);
+        constexpr index_t L1 = CopyLengths{}.Get(I1);
+        constexpr index_t L2 = CopyLengths{}.Get(I2);
+        constexpr index_t L3 = CopyLengths{}.Get(I3);

-        constexpr unsigned thread_per_d0 = ThreadPerDims{}.Get(I0);
-        constexpr unsigned thread_per_d1 = ThreadPerDims{}.Get(I1);
-        constexpr unsigned thread_per_d2 = ThreadPerDims{}.Get(I2);
-        constexpr unsigned thread_per_d3 = ThreadPerDims{}.Get(I3);
+        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
+        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
+        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
+        constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);

-        constexpr unsigned num_active_thread =
+        constexpr index_t num_active_thread =
             thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;

         if(BlockSize > num_active_thread)
@@ -547,30 +547,30 @@ struct Blockwise4dTensorCopy3
             }
         }

-        constexpr unsigned nloop_d0 = L0 / thread_per_d0;
-        constexpr unsigned nloop_d1 = L1 / thread_per_d1;
-        constexpr unsigned nloop_d2 = L2 / thread_per_d2;
-        constexpr unsigned nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
+        constexpr index_t nloop_d0 = L0 / thread_per_d0;
+        constexpr index_t nloop_d1 = L1 / thread_per_d1;
+        constexpr index_t nloop_d2 = L2 / thread_per_d2;
+        constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);

 #pragma unroll
-        for(unsigned iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
+        for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
         {
 #pragma unroll
-            for(unsigned iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
+            for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
             {
 #pragma unroll
-                for(unsigned iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
+                for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
                 {
 #pragma unroll
-                    for(unsigned iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
+                    for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
                     {
-                        const unsigned src_offset =
+                        const index_t src_offset =
                             SrcDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
                                                  iloop_d1 * thread_per_d1,
                                                  iloop_d2 * thread_per_d2,
                                                  iloop_d3 * thread_per_d3 * DataPerRead);

-                        const unsigned dst_offset =
+                        const index_t dst_offset =
                             DstDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
                                                  iloop_d1 * thread_per_d1,
                                                  iloop_d2 * thread_per_d2,
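A quick way to see what the Blockwise4dTensorCopy3 constructor above does with the linear thread id: it peels off one coordinate per dimension by divide-and-subtract. A minimal host-side sketch of that decomposition, with made-up ThreadPerDims values (not taken from this commit):

    // Hypothetical check of the div/sub decomposition used above
    // (thread_per_d values are made up for illustration).
    #include <cassert>
    #include <cstdint>
    using index_t = uint32_t;

    int main()
    {
        constexpr index_t thread_per_d1 = 2, thread_per_d2 = 4, thread_per_d3 = 8;

        index_t tid = 57; // linear thread id within the block

        const index_t d0 = tid / (thread_per_d1 * thread_per_d2 * thread_per_d3);
        index_t itmp     = tid - d0 * (thread_per_d1 * thread_per_d2 * thread_per_d3);
        const index_t d1 = itmp / (thread_per_d2 * thread_per_d3);
        itmp -= d1 * (thread_per_d2 * thread_per_d3);
        const index_t d2 = itmp / thread_per_d3;
        const index_t d3 = itmp - d2 * thread_per_d3;

        assert(d0 == 0 && d1 == 1 && d2 == 3 && d3 == 1); // 57 = 1*32 + 3*8 + 1
        return 0;
    }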
src/include/blockwise_batched_gemm.hip.hpp
 #pragma once
 #include "threadwise_gemm.hip.hpp"

-template <unsigned BlockSize,
+template <index_t BlockSize,
           class BlockMatrixA,
           class BlockMatrixB,
           class ThreadMatrixC,
           bool TransA,
           bool TransB,
           bool TransC,
-          unsigned BlockMatrixStrideA,
-          unsigned BlockMatrixStrideB,
-          unsigned ThreadMatrixStrideC,
-          unsigned BatchSize,
-          unsigned BatchPerThread,
-          unsigned KPerThreadLoop,
+          index_t BlockMatrixStrideA,
+          index_t BlockMatrixStrideB,
+          index_t ThreadMatrixStrideC,
+          index_t BatchSize,
+          index_t BatchPerThread,
+          index_t KPerThreadLoop,
           bool DistributeThreadAlongColumnFirst>
 struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
 {
-    unsigned mMyThreadOffsetA = 0;
-    unsigned mMyThreadOffsetB = 0;
+    index_t mMyThreadOffsetA = 0;
+    index_t mMyThreadOffsetB = 0;

     struct MatrixIndex
     {
-        unsigned batch;
-        unsigned row;
-        unsigned col;
+        index_t batch;
+        index_t row;
+        index_t col;
     };

     __device__ Blockwise1dStridedBatchedGemmBlockABlockBThreadC()
@@ -61,7 +61,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
 #endif
     }

-    __device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
+    __device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
     {
         if(TransA && (!TransB) && (!TransC))
@@ -72,22 +72,22 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
             static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
                           "wrong! k dimension not consistent!");

-            constexpr unsigned MPerBlock = a_block_mtx.NCol();
-            constexpr unsigned NPerBlock = b_block_mtx.NCol();
+            constexpr index_t MPerBlock = a_block_mtx.NCol();
+            constexpr index_t NPerBlock = b_block_mtx.NCol();

             constexpr auto c_thread_mtx = ThreadMatrixC{};

             // divide thread work
-            constexpr unsigned MPerThread = c_thread_mtx.NRow();
-            constexpr unsigned NPerThread = c_thread_mtx.NCol();
+            constexpr index_t MPerThread = c_thread_mtx.NRow();
+            constexpr index_t NPerThread = c_thread_mtx.NCol();

             static_assert(BatchSize % BatchPerThread == 0, "BatchSize % BatchPerThread != 0");
             static_assert(MPerBlock % MPerThread == 0, "MPerBlock % MPerThread != 0");
             static_assert(NPerBlock % NPerThread == 0, "NPerBlock % NPerThread != 0");

-            constexpr unsigned BatchThreadWork = (BatchSize + BatchPerThread - 1) / BatchPerThread;
-            constexpr unsigned MThreadWork = (MPerBlock + MPerThread - 1) / MPerThread;
-            constexpr unsigned NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
+            constexpr index_t BatchThreadWork = (BatchSize + BatchPerThread - 1) / BatchPerThread;
+            constexpr index_t MThreadWork = (MPerBlock + MPerThread - 1) / MPerThread;
+            constexpr index_t NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;

             static_assert(BlockSize == BatchThreadWork * MThreadWork * NThreadWork,
                           "wrong! wrong BlockSize");
@@ -95,10 +95,10 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
             if(DistributeThreadAlongColumnFirst)
             {
                 // num of operations can be reduced
-                const unsigned b_work_id = thread_id / (MThreadWork * NThreadWork);
-                unsigned itmp            = thread_id - b_work_id * (MThreadWork * NThreadWork);
-                const unsigned m_work_id = itmp / NThreadWork;
-                const unsigned n_work_id = itmp - m_work_id * NThreadWork;
+                const index_t b_work_id = thread_id / (MThreadWork * NThreadWork);
+                index_t itmp            = thread_id - b_work_id * (MThreadWork * NThreadWork);
+                const index_t m_work_id = itmp / NThreadWork;
+                const index_t n_work_id = itmp - m_work_id * NThreadWork;

                 return MatrixIndex{
                     b_work_id * BatchPerThread, m_work_id * MPerThread, n_work_id * NPerThread};
@@ -118,7 +118,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
     // this should be optimized away if input is known
     __device__ static MatrixIndex
-    GetDistanceFromBeginOfThreadMatrixC(unsigned batch_in_c, unsigned m_in_c, unsigned n_in_c)
+    GetDistanceFromBeginOfThreadMatrixC(index_t batch_in_c, index_t m_in_c, index_t n_in_c)
     {
         return MatrixIndex{batch_in_c, m_in_c, n_in_c};
     }
@@ -138,10 +138,10 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
         constexpr auto b_block_mtx  = BlockMatrixB{};
         constexpr auto c_thread_mtx = ThreadMatrixC{};

-        constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed
+        constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed

-        constexpr unsigned MPerThread = c_thread_mtx.NRow();
-        constexpr unsigned NPerThread = c_thread_mtx.NCol();
+        constexpr index_t MPerThread = c_thread_mtx.NRow();
+        constexpr index_t NPerThread = c_thread_mtx.NCol();

         // a is transposed, b is not
         constexpr auto a_thread_mtx =
@@ -154,7 +154,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
         FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

         // loop over k
-        for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
+        for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
         {
             // read first batch of a, b
             threadwise_matrix_copy(a_block_mtx,
@@ -172,7 +172,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
                                    b_thread_mtx.GetLengths());

             // loop over batch
-            for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib)
+            for(index_t ib = 0; ib + 1 < BatchPerThread; ++ib)
             {
                 // do current batch of gemm
                 threadwise_gemm(a_thread_mtx,
@@ -226,32 +226,32 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
     }
 };

-template <unsigned BlockSize,
+template <index_t BlockSize,
           class BlockMatrixA,
           class BlockMatrixB,
           class ThreadMatrixC,
-          unsigned BlockMatrixStrideA,
-          unsigned BlockMatrixStrideB,
-          unsigned ThreadMatrixStrideC,
-          unsigned BatchSize,
-          unsigned MPerThreadSubC,
-          unsigned NPerThreadSubC,
-          unsigned MLevel0Cluster,
-          unsigned NLevel0Cluster,
-          unsigned MLevel1Cluster,
-          unsigned NLevel1Cluster,
-          unsigned KPerThreadLoop,
-          unsigned BatchPerThread>
+          index_t BlockMatrixStrideA,
+          index_t BlockMatrixStrideB,
+          index_t ThreadMatrixStrideC,
+          index_t BatchSize,
+          index_t MPerThreadSubC,
+          index_t NPerThreadSubC,
+          index_t MLevel0Cluster,
+          index_t NLevel0Cluster,
+          index_t MLevel1Cluster,
+          index_t NLevel1Cluster,
+          index_t KPerThreadLoop,
+          index_t BatchPerThread>
 struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
 {
-    unsigned mMyThreadOffsetA = 0;
-    unsigned mMyThreadOffsetB = 0;
+    index_t mMyThreadOffsetA = 0;
+    index_t mMyThreadOffsetB = 0;

     struct MatrixIndex
     {
-        unsigned batch;
-        unsigned row;
-        unsigned col;
+        index_t batch;
+        index_t row;
+        index_t col;
     };

     __device__ BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2()
@@ -259,9 +259,9 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
         static_assert(BatchSize % BatchPerThread == 0,
                       "wrong! BatchSize is not dividable by BatchPerThread");

-        constexpr unsigned BatchThreadWork = BatchSize / BatchPerThread;
+        constexpr index_t BatchThreadWork = BatchSize / BatchPerThread;

-        constexpr unsigned ThreadPerLevel1Cluster =
+        constexpr index_t ThreadPerLevel1Cluster =
             MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

         static_assert(BlockSize == BatchThreadWork * ThreadPerLevel1Cluster,
@@ -274,31 +274,31 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
         static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
                       "wrong! K dimension not consistent\n");

-        constexpr unsigned M = a_block_mtx.NCol(); // A is transposed
-        constexpr unsigned N = b_block_mtx.NCol();
-        constexpr unsigned K = a_block_mtx.NRow();
+        constexpr index_t M = a_block_mtx.NCol(); // A is transposed
+        constexpr index_t N = b_block_mtx.NCol();
+        constexpr index_t K = a_block_mtx.NRow();

-        constexpr unsigned MPerThread = c_thread_mtx.NRow();
-        constexpr unsigned NPerThread = c_thread_mtx.NCol();
+        constexpr index_t MPerThread = c_thread_mtx.NRow();
+        constexpr index_t NPerThread = c_thread_mtx.NCol();

         static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
                       "wrong! Cannot evenly divide thread work among repeat\n");

-        constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
-        constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
+        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
+        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

         static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
                       "wrong! Cannot evenly divide work among repeat\n");

-        constexpr unsigned MPerLevel1Cluster = M / MRepeat;
-        constexpr unsigned NPerLevel1Cluster = N / NRepeat;
+        constexpr index_t MPerLevel1Cluster = M / MRepeat;
+        constexpr index_t NPerLevel1Cluster = N / NRepeat;

         static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
                           (NPerLevel1Cluster % NLevel1Cluster == 0),
                       "wrong! Cannot evenly divide work among Level1Cluster\n");

-        constexpr unsigned MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
-        constexpr unsigned NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
+        constexpr index_t MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
+        constexpr index_t NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;

         static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
                           (NPerLevel0Cluster % NLevel0Cluster == 0),
@@ -335,28 +335,28 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
 #endif
     }

-    __device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
+    __device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
     {
-        constexpr unsigned BatchThreadWork = BatchSize / BatchPerThread;
+        constexpr index_t BatchThreadWork = BatchSize / BatchPerThread;

-        constexpr unsigned ThreadPerLevel1Cluster =
+        constexpr index_t ThreadPerLevel1Cluster =
             MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

-        constexpr unsigned ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
+        constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;

-        unsigned batch_work_id = thread_id / ThreadPerLevel1Cluster;
-        unsigned cluster_id    = thread_id - batch_work_id * ThreadPerLevel1Cluster;
+        index_t batch_work_id = thread_id / ThreadPerLevel1Cluster;
+        index_t cluster_id    = thread_id - batch_work_id * ThreadPerLevel1Cluster;

-        unsigned level1_id   = cluster_id / ThreadPerLevel0Cluster;
-        unsigned level1_m_id = level1_id / NLevel1Cluster;
-        unsigned level1_n_id = level1_id % NLevel1Cluster;
+        index_t level1_id   = cluster_id / ThreadPerLevel0Cluster;
+        index_t level1_m_id = level1_id / NLevel1Cluster;
+        index_t level1_n_id = level1_id % NLevel1Cluster;

-        unsigned level0_id   = cluster_id % ThreadPerLevel0Cluster;
-        unsigned level0_m_id = level0_id / NLevel0Cluster;
-        unsigned level0_n_id = level0_id % NLevel0Cluster;
+        index_t level0_id   = cluster_id % ThreadPerLevel0Cluster;
+        index_t level0_m_id = level0_id / NLevel0Cluster;
+        index_t level0_n_id = level0_id % NLevel0Cluster;

-        constexpr unsigned MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
-        constexpr unsigned NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
+        constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
+        constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;

         return MatrixIndex{batch_work_id * BatchPerThread,
                            level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
@@ -365,24 +365,24 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
     // this should be optimized away if input is known
     __device__ static MatrixIndex
-    GetDistanceFromBeginOfThreadMatrixC(unsigned batch_in_c, unsigned m_in_c, unsigned n_in_c)
+    GetDistanceFromBeginOfThreadMatrixC(index_t batch_in_c, index_t m_in_c, index_t n_in_c)
     {
         constexpr auto c_thread_mtx = ThreadMatrixC{};

-        constexpr unsigned MPerThread = c_thread_mtx.NRow();
-        constexpr unsigned NPerThread = c_thread_mtx.NCol();
+        constexpr index_t MPerThread = c_thread_mtx.NRow();
+        constexpr index_t NPerThread = c_thread_mtx.NCol();

-        constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
-        constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
+        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
+        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

-        constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
-        constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
+        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
+        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

-        unsigned m_repeat = m_in_c / MPerThreadSubC;
-        unsigned n_repeat = n_in_c / NPerThreadSubC;
+        index_t m_repeat = m_in_c / MPerThreadSubC;
+        index_t n_repeat = n_in_c / NPerThreadSubC;

-        unsigned m_in_sub_c = m_in_c % MPerThreadSubC;
-        unsigned n_in_sub_c = n_in_c % NPerThreadSubC;
+        index_t m_in_sub_c = m_in_c % MPerThreadSubC;
+        index_t n_in_sub_c = n_in_c % NPerThreadSubC;

         return MatrixIndex{batch_in_c,
                            m_repeat * MPerLevel1Cluster + m_in_sub_c,
@@ -402,10 +402,10 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
         constexpr auto b_block_mtx  = BlockMatrixB{};
         constexpr auto c_thread_mtx = ThreadMatrixC{};

-        constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed
+        constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed

-        constexpr unsigned MPerThread = c_thread_mtx.NRow();
-        constexpr unsigned NPerThread = c_thread_mtx.NCol();
+        constexpr index_t MPerThread = c_thread_mtx.NRow();
+        constexpr index_t NPerThread = c_thread_mtx.NCol();

         // thread A, B for GEMM
         // A is transposed, b is not
@@ -425,20 +425,20 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
         FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
         FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

-        constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
-        constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
+        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
+        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

-        constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
-        constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
+        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
+        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

         // loop over k
 #pragma unroll
-        for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
+        for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
         {
             // read first batch of A, B
             // copy A-sub to form A
 #pragma unroll
-            for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
+            for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
             {
                 threadwise_matrix_copy(a_block_mtx,
@@ -451,7 +451,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
             // copy B-sub to form B
 #pragma unroll
-            for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
+            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
             {
                 threadwise_matrix_copy(b_block_mtx,
@@ -464,7 +464,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
             // loop over batch
 #pragma unroll
-            for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib)
+            for(index_t ib = 0; ib + 1 < BatchPerThread; ++ib)
             {
                 // do current batch of gemm
                 threadwise_gemm(a_thread_mtx,
@@ -482,7 +482,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
                 if(BlockMatrixStrideA != 0)
                 {
 #pragma unroll
-                    for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
+                    for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
                     {
                         threadwise_matrix_copy(a_block_mtx,
@@ -498,7 +498,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
                 if(BlockMatrixStrideB != 0)
                 {
 #pragma unroll
-                    for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
+                    for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
                     {
                         threadwise_matrix_copy(b_block_mtx,
@@ -539,10 +539,10 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
         constexpr auto b_block_mtx  = BlockMatrixB{};
         constexpr auto c_thread_mtx = ThreadMatrixC{};

-        constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed
+        constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed

-        constexpr unsigned MPerThread = c_thread_mtx.NRow();
-        constexpr unsigned NPerThread = c_thread_mtx.NCol();
+        constexpr index_t MPerThread = c_thread_mtx.NRow();
+        constexpr index_t NPerThread = c_thread_mtx.NCol();

         // thread A, B for GEMM
         // A is transposed, b is not
@@ -562,25 +562,25 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
         FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
         FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

-        constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
-        constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
+        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
+        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

-        constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
-        constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
+        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
+        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

         // loop over k
 //#pragma unroll
-        for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
+        for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
         {
             // read first batch of A, B
             // copy A-sub to form A
 //#pragma unroll
-            for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
+            for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
             {
-                for(unsigned i = 0; i < a_thread_sub_mtx.NRow(); ++i)
+                for(index_t i = 0; i < a_thread_sub_mtx.NRow(); ++i)
                 {
 #if 1
-                    for(unsigned j = 0; j < a_thread_sub_mtx.NCol(); ++j)
+                    for(index_t j = 0; j < a_thread_sub_mtx.NCol(); ++j)
                     {
                         p_a_thread[a_thread_mtx.Get1dIndex(i, m_repeat * MPerThreadSubC + j)] =
                             p_a_block[a_block_mtx.Get1dIndex(k_begin + i,
@@ -596,11 +596,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
             // copy B-sub to form B
 //#pragma unroll
-            for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
+            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
             {
-                for(unsigned i = 0; i < b_thread_sub_mtx.NRow(); ++i)
+                for(index_t i = 0; i < b_thread_sub_mtx.NRow(); ++i)
                 {
-                    for(unsigned j = 0; j < b_thread_sub_mtx.NCol(); ++j)
+                    for(index_t j = 0; j < b_thread_sub_mtx.NCol(); ++j)
                     {
                         p_b_thread[b_thread_mtx.Get1dIndex(i, n_repeat * NPerThreadSubC + j)] =
                             p_b_block[b_block_mtx.Get1dIndex(k_begin + i,
@@ -612,20 +612,20 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
             // loop over batch
 //#pragma unroll
-            for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib)
+            for(index_t ib = 0; ib + 1 < BatchPerThread; ++ib)
             {
                 // do current batch of gemm
-                for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k)
+                for(index_t k = 0; k < a_thread_mtx.NRow(); ++k)
                 {
 #if 0
-                    for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
+                    for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
                     {
-                        for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j)
+                        for(index_t j = 0; j < c_thread_mtx.NCol(); ++j)
                         {
-                            const unsigned aindex =
+                            const index_t aindex =
                                 a_thread_mtx.Get1dIndex(k, i); // A is transposed
-                            const unsigned bindex = b_thread_mtx.Get1dIndex(k, j);
-                            const unsigned cindex =
+                            const index_t bindex = b_thread_mtx.Get1dIndex(k, j);
+                            const index_t cindex =
                                 c_thread_mtx.Get1dIndex(i, j) + ib * ThreadMatrixStrideC;

                             f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
@@ -635,11 +635,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
                     static_assert(c_thread_mtx.NRow() == 16 && c_thread_mtx.NCol() == 4,
                                   "asm is only for 16x4");

-                    const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0);
+                    const index_t bindex = b_thread_mtx.Get1dIndex(k, 0);

-                    for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
+                    for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
                     {
-                        const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
-                        const unsigned cindex = c_thread_mtx.Get1dIndex(i, 0);
+                        const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
+                        const index_t cindex = c_thread_mtx.Get1dIndex(i, 0);

                         asm volatile("\n \
                             v_mac_f32 %0, %4, %5 \n \
@@ -668,11 +668,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
                 if(BlockMatrixStrideA != 0)
                 {
 //#pragma unroll
-                    for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
+                    for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
                     {
-                        for(unsigned i = 0; i < a_thread_sub_mtx.NRow(); ++i)
+                        for(index_t i = 0; i < a_thread_sub_mtx.NRow(); ++i)
                         {
-                            for(unsigned j = 0; j < a_thread_sub_mtx.NCol(); ++j)
+                            for(index_t j = 0; j < a_thread_sub_mtx.NCol(); ++j)
                             {
                                 p_a_thread[a_thread_mtx.Get1dIndex(i, m_repeat * MPerThreadSubC + j)] =
@@ -687,11 +687,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
                 if(BlockMatrixStrideB != 0)
                 {
 //#pragma unroll
-                    for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
+                    for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
                     {
-                        for(unsigned i = 0; i < b_thread_sub_mtx.NRow(); ++i)
+                        for(index_t i = 0; i < b_thread_sub_mtx.NRow(); ++i)
                         {
-                            for(unsigned j = 0; j < b_thread_sub_mtx.NCol(); ++j)
+                            for(index_t j = 0; j < b_thread_sub_mtx.NCol(); ++j)
                             {
                                 p_b_thread[b_thread_mtx.Get1dIndex(i, n_repeat * NPerThreadSubC + j)] =
@@ -705,16 +705,16 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
             }

             // do last batch of gemm
-            for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k)
+            for(index_t k = 0; k < a_thread_mtx.NRow(); ++k)
             {
 #if 0
-                for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
+                for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
                 {
-                    for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j)
+                    for(index_t j = 0; j < c_thread_mtx.NCol(); ++j)
                     {
-                        const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
-                        const unsigned bindex = b_thread_mtx.Get1dIndex(k, j);
-                        const unsigned cindex = c_thread_mtx.Get1dIndex(i, j) +
+                        const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
+                        const index_t bindex = b_thread_mtx.Get1dIndex(k, j);
+                        const index_t cindex = c_thread_mtx.Get1dIndex(i, j) +
                             (BatchPerThread - 1) * ThreadMatrixStrideC;

                         f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
@@ -724,11 +724,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
                 static_assert(c_thread_mtx.NRow() == 16 && c_thread_mtx.NCol() == 4,
                               "asm is only for 16x4");

-                const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0);
+                const index_t bindex = b_thread_mtx.Get1dIndex(k, 0);

-                for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
+                for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
                 {
-                    const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
-                    const unsigned cindex =
+                    const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
+                    const index_t cindex =
                         c_thread_mtx.Get1dIndex(i, 0) + (BatchPerThread - 1) * ThreadMatrixStrideC;

                     asm volatile("\n \
@@ -756,34 +756,34 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
         }
     }

-    template <class BlockMatrixC, unsigned BlockMatrixStrideC, class FloatC>
+    template <class BlockMatrixC, index_t BlockMatrixStrideC, class FloatC>
     __device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread,
                                                     FloatC* __restrict__ p_c_block) const
     {
         constexpr auto c_block_mtx  = BlockMatrixC{};
         constexpr auto c_thread_mtx = ThreadMatrixC{};

-        constexpr unsigned MPerThread = c_thread_mtx.NRow();
-        constexpr unsigned NPerThread = c_thread_mtx.NCol();
+        constexpr index_t MPerThread = c_thread_mtx.NRow();
+        constexpr index_t NPerThread = c_thread_mtx.NCol();

         constexpr auto c_thread_sub_mtx = make_ConstantMatrixDescriptor(
             Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

-        constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
-        constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
+        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
+        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

-        constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
-        constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
+        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
+        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

         const auto c_thread_mtx_begin = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

-        const unsigned c_thread_offset =
+        const index_t c_thread_offset =
             c_thread_mtx_begin.batch * BlockMatrixStrideC +
             c_block_mtx.Get1dIndex(c_thread_mtx_begin.row, c_thread_mtx_begin.col);

-        for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
+        for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
         {
-            for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
+            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
             {
                 threadwise_matrix_copy(c_thread_sub_mtx,
src/include/blockwise_direct_convolution.hip.hpp
@@ -3,16 +3,16 @@
 #include "threadwise_4d_tensor_op.hip.hpp"
 #include "threadwise_direct_convolution.hip.hpp"

-template <unsigned BlockSize,
+template <index_t BlockSize,
           class Float,
           class InBlockDesc,
           class WeiBlockDesc,
           class OutBlockDesc,
-          unsigned NPerThread,
-          unsigned KPerThread,
-          unsigned CPerThread,
-          unsigned HoPerThread,
-          unsigned WoPerThread>
+          index_t NPerThread,
+          index_t KPerThread,
+          index_t CPerThread,
+          index_t HoPerThread,
+          index_t WoPerThread>
 __device__ void blockwise_direct_convolution(InBlockDesc,
                                              Float* const __restrict__ p_in_block,
                                              WeiBlockDesc,
@@ -29,17 +29,17 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
     constexpr auto wei_block_desc = WeiBlockDesc{};
     constexpr auto out_block_desc = OutBlockDesc{};

-    constexpr unsigned Y = wei_block_desc.GetLength(I2);
-    constexpr unsigned X = wei_block_desc.GetLength(I3);
+    constexpr index_t Y = wei_block_desc.GetLength(I2);
+    constexpr index_t X = wei_block_desc.GetLength(I3);

-    constexpr unsigned InTileSizeH = HoPerThread + Y - 1;
-    constexpr unsigned InTileSizeW = WoPerThread + X - 1;
+    constexpr index_t InTileSizeH = HoPerThread + Y - 1;
+    constexpr index_t InTileSizeW = WoPerThread + X - 1;

     // divide thread work
-    constexpr unsigned NThreadWork = (out_block_desc.GetLength(I0) + NPerThread - 1) / NPerThread;
-    constexpr unsigned KThreadWork = (out_block_desc.GetLength(I1) + KPerThread - 1) / KPerThread;
-    constexpr unsigned YThreadWork = (out_block_desc.GetLength(I2) + HoPerThread - 1) / HoPerThread;
-    constexpr unsigned XThreadWork = (out_block_desc.GetLength(I3) + WoPerThread - 1) / WoPerThread;
+    constexpr index_t NThreadWork = (out_block_desc.GetLength(I0) + NPerThread - 1) / NPerThread;
+    constexpr index_t KThreadWork = (out_block_desc.GetLength(I1) + KPerThread - 1) / KPerThread;
+    constexpr index_t YThreadWork = (out_block_desc.GetLength(I2) + HoPerThread - 1) / HoPerThread;
+    constexpr index_t XThreadWork = (out_block_desc.GetLength(I3) + WoPerThread - 1) / WoPerThread;

 #if 0
     if(threadIdx.x == 0)
@@ -68,27 +68,27 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
     constexpr auto out_thread_block_desc =
         make_ConstantTensorDescriptor(out_thread_desc.GetLengths(), out_block_desc.GetStrides());

-    const unsigned thread_id = threadIdx.x;
+    const index_t thread_id = threadIdx.x;

-    for(unsigned thread_work_id = thread_id;
+    for(index_t thread_work_id = thread_id;
         thread_work_id < NThreadWork * KThreadWork * YThreadWork * XThreadWork;
         thread_work_id += BlockSize)
     {
-        unsigned itmp             = thread_work_id;
-        unsigned n_thread_work_id = itmp / (KThreadWork * YThreadWork * XThreadWork);
+        index_t itmp             = thread_work_id;
+        index_t n_thread_work_id = itmp / (KThreadWork * YThreadWork * XThreadWork);
         itmp -= n_thread_work_id * (KThreadWork * YThreadWork * XThreadWork);
-        unsigned k_thread_work_id = itmp / (YThreadWork * XThreadWork);
+        index_t k_thread_work_id = itmp / (YThreadWork * XThreadWork);
         itmp -= k_thread_work_id * (YThreadWork * XThreadWork);
-        unsigned y_thread_work_id = itmp / XThreadWork;
-        unsigned x_thread_work_id = itmp - y_thread_work_id * XThreadWork;
+        index_t y_thread_work_id = itmp / XThreadWork;
+        index_t x_thread_work_id = itmp - y_thread_work_id * XThreadWork;

-        unsigned n_thread_data_begin  = n_thread_work_id * NPerThread;
-        unsigned k_thread_data_begin  = k_thread_work_id * KPerThread;
-        unsigned ho_thread_data_begin = y_thread_work_id * HoPerThread;
-        unsigned wo_thread_data_begin = x_thread_work_id * WoPerThread;
+        index_t n_thread_data_begin  = n_thread_work_id * NPerThread;
+        index_t k_thread_data_begin  = k_thread_work_id * KPerThread;
+        index_t ho_thread_data_begin = y_thread_work_id * HoPerThread;
+        index_t wo_thread_data_begin = x_thread_work_id * WoPerThread;

-        unsigned hi_thread_data_begin = ho_thread_data_begin; // minus padding
-        unsigned wi_thread_data_begin = wo_thread_data_begin; // minus padding
+        index_t hi_thread_data_begin = ho_thread_data_begin; // minus padding
+        index_t wi_thread_data_begin = wo_thread_data_begin; // minus padding

         Float p_out_thread[out_thread_desc.GetElementSpace()];
@@ -102,7 +102,7 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
                                    p_out_thread,
                                    out_thread_desc.GetLengths());

-        for(unsigned c_thread_data_begin = 0; c_thread_data_begin < in_block_desc.GetLength(I1);
+        for(index_t c_thread_data_begin = 0; c_thread_data_begin < in_block_desc.GetLength(I1);
             c_thread_data_begin += CPerThread)
         {
             // threadwise convolution
src/include/blockwise_gemm.hip.hpp
 #pragma once
 #include "threadwise_gemm.hip.hpp"

-template <unsigned BlockSize,
+template <index_t BlockSize,
           class BlockMatrixA,
           class BlockMatrixB,
           class ThreadMatrixC,
           bool TransA,
           bool TransB,
           bool TransC,
-          unsigned KPerThreadLoop,
-          unsigned MThreadPerCluster,
-          unsigned NThreadPerCluster,
+          index_t KPerThreadLoop,
+          index_t MThreadPerCluster,
+          index_t NThreadPerCluster,
           bool DistributeThreadAlongColumnFirst>
 struct BlockwiseGemmBlockABlockBThreadC
 {
-    unsigned mMyThreadOffsetA = 0;
-    unsigned mMyThreadOffsetB = 0;
+    index_t mMyThreadOffsetA = 0;
+    index_t mMyThreadOffsetB = 0;

     struct MatrixIndex
     {
-        unsigned row;
-        unsigned col;
+        index_t row;
+        index_t col;
     };

     __device__ BlockwiseGemmBlockABlockBThreadC()
@@ -55,7 +55,7 @@ struct BlockwiseGemmBlockABlockBThreadC
 #endif
     }

-    __device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
+    __device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
     {
         if(TransA && (!TransB) && (!TransC))
@@ -66,14 +66,14 @@ struct BlockwiseGemmBlockABlockBThreadC
             static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
                           "wrong! k dimension not consistent!");

-            constexpr unsigned MPerBlock = a_block_mtx.NCol();
-            constexpr unsigned NPerBlock = b_block_mtx.NCol();
+            constexpr index_t MPerBlock = a_block_mtx.NCol();
+            constexpr index_t NPerBlock = b_block_mtx.NCol();

             constexpr auto c_thread_mtx = ThreadMatrixC{};

             // divide thread work
-            constexpr unsigned MPerThread = c_thread_mtx.NRow();
-            constexpr unsigned NPerThread = c_thread_mtx.NCol();
+            constexpr index_t MPerThread = c_thread_mtx.NRow();
+            constexpr index_t NPerThread = c_thread_mtx.NCol();

             static_assert(MPerBlock % (MPerThread * MThreadPerCluster) == 0,
                           "MPerBlock % (MPerThread * MThreadPerCluster) != 0");
@@ -81,10 +81,10 @@ struct BlockwiseGemmBlockABlockBThreadC
             static_assert(NPerBlock % (NPerThread * NThreadPerCluster) == 0,
                           "NPerBlock % (NPerThread * NThreadPerCluster) != 0");

-            constexpr unsigned MClusterWork =
+            constexpr index_t MClusterWork =
                 (MPerBlock + MPerThread * MThreadPerCluster - 1) / (MPerThread * MThreadPerCluster);
-            constexpr unsigned NClusterWork =
+            constexpr index_t NClusterWork =
                 (NPerBlock + NPerThread * NThreadPerCluster - 1) / (NPerThread * NThreadPerCluster);

             static_assert(BlockSize ==
@@ -94,19 +94,18 @@ struct BlockwiseGemmBlockABlockBThreadC
             if(DistributeThreadAlongColumnFirst)
             {
-                const unsigned cluster_work_block_id =
+                const index_t cluster_work_block_id =
                     thread_id / (MThreadPerCluster * NThreadPerCluster);

-                const unsigned thread_work_cluster_id =
+                const index_t thread_work_cluster_id =
                     thread_id - cluster_work_block_id * (MThreadPerCluster * NThreadPerCluster);

-                const unsigned m_cluster_work_block_id = cluster_work_block_id / NClusterWork;
-                const unsigned n_cluster_work_block_id =
+                const index_t m_cluster_work_block_id = cluster_work_block_id / NClusterWork;
+                const index_t n_cluster_work_block_id =
                     cluster_work_block_id - m_cluster_work_block_id * NClusterWork;

-                const unsigned m_thread_work_cluster_id = thread_work_cluster_id / NThreadPerCluster;
-                const unsigned n_thread_work_cluster_id =
+                const index_t m_thread_work_cluster_id = thread_work_cluster_id / NThreadPerCluster;
+                const index_t n_thread_work_cluster_id =
                     thread_work_cluster_id - m_thread_work_cluster_id * NThreadPerCluster;

 #if 0
@@ -143,8 +142,8 @@ struct BlockwiseGemmBlockABlockBThreadC
     }

     // this should be optimized away if input is known
-    __device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(unsigned m_in_c,
-                                                                      unsigned n_in_c)
+    __device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c,
+                                                                      index_t n_in_c)
     {
         return MatrixIndex{m_in_c, n_in_c};
     }
@@ -164,10 +163,10 @@ struct BlockwiseGemmBlockABlockBThreadC
         constexpr auto b_block_mtx  = BlockMatrixB{};
         constexpr auto c_thread_mtx = ThreadMatrixC{};

-        constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed
+        constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed

-        constexpr unsigned MPerThread = c_thread_mtx.NRow();
-        constexpr unsigned NPerThread = c_thread_mtx.NCol();
+        constexpr index_t MPerThread = c_thread_mtx.NRow();
+        constexpr index_t NPerThread = c_thread_mtx.NCol();

         // a is transposed, b is not
         constexpr auto a_thread_mtx =
@@ -180,7 +179,7 @@ struct BlockwiseGemmBlockABlockBThreadC
         FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

         // loop over k
-        for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
+        for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
         {
             threadwise_matrix_copy(a_block_mtx,
                                    p_a_block + mMyThreadOffsetA +
@@ -213,31 +212,31 @@ struct BlockwiseGemmBlockABlockBThreadC
 // if following number are power of 2, index calculation shall be greatly reduced:
 // MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster
-template <unsigned BlockSize,
+template <index_t BlockSize,
           class BlockMatrixA,
           class BlockMatrixB,
           class ThreadMatrixC,
-          unsigned MPerThreadSubC,
-          unsigned NPerThreadSubC,
-          unsigned MLevel0Cluster,
-          unsigned NLevel0Cluster,
-          unsigned MLevel1Cluster,
-          unsigned NLevel1Cluster,
-          unsigned KPerThreadLoop>
+          index_t MPerThreadSubC,
+          index_t NPerThreadSubC,
+          index_t MLevel0Cluster,
+          index_t NLevel0Cluster,
+          index_t MLevel1Cluster,
+          index_t NLevel1Cluster,
+          index_t KPerThreadLoop>
 struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
 {
     struct MatrixIndex
     {
-        unsigned row;
-        unsigned col;
+        index_t row;
+        index_t col;
     };

-    unsigned mMyThreadOffsetA;
-    unsigned mMyThreadOffsetB;
+    index_t mMyThreadOffsetA;
+    index_t mMyThreadOffsetB;

     __device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2()
     {
-        constexpr unsigned ThreadPerLevel1Cluster =
+        constexpr index_t ThreadPerLevel1Cluster =
             MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

         static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");
@@ -249,31 +248,31 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
                       "wrong! K dimension not consistent\n");

-        constexpr unsigned M = a_block_mtx.NCol(); // A is transposed
-        constexpr unsigned N = b_block_mtx.NCol();
-        constexpr unsigned K = a_block_mtx.NRow();
+        constexpr index_t M = a_block_mtx.NCol(); // A is transposed
+        constexpr index_t N = b_block_mtx.NCol();
+        constexpr index_t K = a_block_mtx.NRow();

-        constexpr unsigned MPerThread = c_thread_mtx.NRow();
-        constexpr unsigned NPerThread = c_thread_mtx.NCol();
+        constexpr index_t MPerThread = c_thread_mtx.NRow();
+        constexpr index_t NPerThread = c_thread_mtx.NCol();

         static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
                       "wrong! Cannot evenly divide thread work among repeat\n");

-        constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
-        constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
+        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
+        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

         static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
                       "wrong! Cannot evenly divide work among repeat\n");

-        constexpr unsigned MPerLevel1Cluster = M / MRepeat;
-        constexpr unsigned NPerLevel1Cluster = N / NRepeat;
+        constexpr index_t MPerLevel1Cluster = M / MRepeat;
+        constexpr index_t NPerLevel1Cluster = N / NRepeat;

         static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
                           (NPerLevel1Cluster % NLevel1Cluster == 0),
                       "wrong! Cannot evenly divide work among Level1Cluster\n");

-        constexpr unsigned MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
-        constexpr unsigned NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
+        constexpr index_t MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
+        constexpr index_t NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;

         static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
                           (NPerLevel0Cluster % NLevel0Cluster == 0),
@@ -289,45 +288,45 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         mMyThreadOffsetB = b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col);
     }

-    __device__ static MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id)
+    __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
     {
-        constexpr unsigned ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
+        constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;

-        unsigned level1_id   = thread_id / ThreadPerLevel0Cluster;
-        unsigned level1_m_id = level1_id / NLevel1Cluster;
-        unsigned level1_n_id = level1_id % NLevel1Cluster;
+        index_t level1_id   = thread_id / ThreadPerLevel0Cluster;
+        index_t level1_m_id = level1_id / NLevel1Cluster;
+        index_t level1_n_id = level1_id % NLevel1Cluster;

-        unsigned level0_id   = thread_id % ThreadPerLevel0Cluster;
-        unsigned level0_m_id = level0_id / NLevel0Cluster;
-        unsigned level0_n_id = level0_id % NLevel0Cluster;
+        index_t level0_id   = thread_id % ThreadPerLevel0Cluster;
+        index_t level0_m_id = level0_id / NLevel0Cluster;
+        index_t level0_n_id = level0_id % NLevel0Cluster;

-        constexpr unsigned MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
-        constexpr unsigned NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
+        constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
+        constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;

         return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
                            level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
     }

     // this should be optimized away if input is known
-    __device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(unsigned m_in_c,
-                                                                      unsigned n_in_c)
+    __device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c,
+                                                                      index_t n_in_c)
     {
         constexpr auto c_thread_mtx = ThreadMatrixC{};

-        constexpr unsigned MPerThread = c_thread_mtx.NRow();
-        constexpr unsigned NPerThread = c_thread_mtx.NCol();
+        constexpr index_t MPerThread = c_thread_mtx.NRow();
+        constexpr index_t NPerThread = c_thread_mtx.NCol();

-        constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
-        constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
+        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
+        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

-        constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
-        constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
+        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
+        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

-        unsigned m_repeat = m_in_c / MPerThreadSubC;
-        unsigned n_repeat = n_in_c / NPerThreadSubC;
+        index_t m_repeat = m_in_c / MPerThreadSubC;
+        index_t n_repeat = n_in_c / NPerThreadSubC;

-        unsigned m_in_sub_c = m_in_c % MPerThreadSubC;
-        unsigned n_in_sub_c = n_in_c % NPerThreadSubC;
+        index_t m_in_sub_c = m_in_c % MPerThreadSubC;
+        index_t n_in_sub_c = n_in_c % NPerThreadSubC;

         return MatrixIndex{m_repeat * MPerLevel1Cluster + m_in_sub_c,
                            n_repeat * NPerLevel1Cluster + n_in_sub_c};
@@ -346,12 +345,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         constexpr auto b_block_mtx  = BlockMatrixB{};
         constexpr auto c_thread_mtx = ThreadMatrixC{};

-        constexpr unsigned M = a_block_mtx.NCol();
-        constexpr unsigned N = b_block_mtx.NCol();
-        constexpr unsigned K = a_block_mtx.NRow();
+        constexpr index_t M = a_block_mtx.NCol();
+        constexpr index_t N = b_block_mtx.NCol();
+        constexpr index_t K = a_block_mtx.NRow();

-        constexpr unsigned MPerThread = c_thread_mtx.NRow();
-        constexpr unsigned NPerThread = c_thread_mtx.NCol();
+        constexpr index_t MPerThread = c_thread_mtx.NRow();
+        constexpr index_t NPerThread = c_thread_mtx.NCol();

         // thread A, B for GEMM
         constexpr auto a_thread_mtx =
@@ -370,19 +369,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
         FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

-        constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
-        constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
+        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
+        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

-        constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
-        constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
+        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
+        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

 #pragma unroll
         // loop over k
-        for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
+        for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
         {
 #pragma unroll
             // copy A-sub to form A
-            for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
+            for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
             {
                 threadwise_matrix_copy(a_block_mtx,
@@ -395,7 +394,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
 #pragma unroll
             // copy B-sub to form B
-            for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
+            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
             {
                 threadwise_matrix_copy(b_block_mtx,
@@ -433,12 +432,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         constexpr auto b_block_mtx  = BlockMatrixB{};
         constexpr auto c_thread_mtx = ThreadMatrixC{};

-        constexpr unsigned M = a_block_mtx.NCol();
-        constexpr unsigned N = b_block_mtx.NCol();
-        constexpr unsigned K = a_block_mtx.NRow();
+        constexpr index_t M = a_block_mtx.NCol();
+        constexpr index_t N = b_block_mtx.NCol();
+        constexpr index_t K = a_block_mtx.NRow();

-        constexpr unsigned MPerThread = c_thread_mtx.NRow();
-        constexpr unsigned NPerThread = c_thread_mtx.NCol();
+        constexpr index_t MPerThread = c_thread_mtx.NRow();
+        constexpr index_t NPerThread = c_thread_mtx.NCol();

         // thread A, B for GEMM
         constexpr auto a_thread_mtx =
@@ -457,19 +456,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
         FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

-        constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
-        constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
+        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
+        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

-        constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
-        constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
+        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
+        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

 #pragma unroll
         // loop over k
-        for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
+        for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
         {
-#pragma unroll
+//#pragma unroll
             // copy A-sub to form A
-            for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
+            for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
             {
                 threadwise_matrix_copy(a_block_mtx,
@@ -480,9 +479,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                                        a_thread_sub_mtx.GetLengths());
             }

-#pragma unroll
+//#pragma unroll
             // copy B-sub to form B
-            for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
+            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
             {
                 threadwise_matrix_copy(b_block_mtx,
@@ -505,19 +504,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                         False,
                         p_c_thread,
                         f_accum);
-#else
+#elif 0
         // inline asm
         static_assert(c_thread_mtx.NRow() == 8 && c_thread_mtx.NCol() == 8, "asm is only for 8x8");

-        for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k) // A is transposed
+        for(index_t k = 0; k < a_thread_mtx.NRow(); ++k) // A is transposed
         {
-            const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0);
+            const index_t bindex = b_thread_mtx.Get1dIndex(k, 0);

-            for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
+            for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
             {
-                const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
-                const unsigned cindex = c_thread_mtx.Get1dIndex(i, 0);
+                const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
+                const index_t cindex = c_thread_mtx.Get1dIndex(i, 0);

                 asm volatile("\n \
                     v_mac_f32 %0, %8, %9 \n \
@@ -573,12 +572,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         constexpr auto b_block_mtx  = BlockMatrixB{};
         constexpr auto c_thread_mtx = ThreadMatrixC{};

-        constexpr unsigned M = a_block_mtx.NCol();
-        constexpr unsigned N = b_block_mtx.NCol();
-        constexpr unsigned K = a_block_mtx.NRow();
+        constexpr index_t M = a_block_mtx.NCol();
+        constexpr index_t N = b_block_mtx.NCol();
+        constexpr index_t K = a_block_mtx.NRow();

-        constexpr unsigned MPerThread = c_thread_mtx.NRow();
-        constexpr unsigned NPerThread = c_thread_mtx.NCol();
+        constexpr index_t MPerThread = c_thread_mtx.NRow();
+        constexpr index_t NPerThread = c_thread_mtx.NCol();

         // thread A, B for GEMM
         constexpr auto a_thread_mtx =
@@ -601,15 +600,15 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         FloatA p_a_thread_1[a_thread_mtx.GetElementSpace()];
         FloatB p_b_thread_1[b_thread_mtx.GetElementSpace()];

-        constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
-        constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
+        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
+        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

-        constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
-        constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
+        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
+        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

         // preload A, B
 #pragma unroll
-        for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
+        for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
         {
             // copy A-sub to form A
             threadwise_matrix_copy(a_block_mtx,
                                    p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster,
@@ -619,7 +618,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         }

 #pragma unroll
-        for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
+        for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
         {
             // copy B-sub to form B
             threadwise_matrix_copy(b_block_mtx,
                                    p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster,
@@ -631,7 +630,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         bool even_loop = true;

 #pragma unroll
-        for(unsigned k_begin = 0; k_begin + KPerThreadLoop < K;
+        for(index_t k_begin = 0; k_begin + KPerThreadLoop < K;
             k_begin += KPerThreadLoop, even_loop = !even_loop)
         { // loop over k
             FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
@@ -642,7 +641,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
             // preload next A, B
 #pragma unroll
-            for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
+            for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
             {
                 // copy A-sub to form A
                 threadwise_matrix_copy(a_block_mtx,
                                        p_a_block + mMyThreadOffsetA +
@@ -654,7 +653,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
             }

 #pragma unroll
-            for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
+            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
             {
                 // copy B-sub to form B
                 threadwise_matrix_copy(b_block_mtx,
                                        p_b_block + mMyThreadOffsetB +
@@ -710,12 +709,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         constexpr auto b_block_mtx  = BlockMatrixB{};
         constexpr auto c_thread_mtx = ThreadMatrixC{};

-        constexpr unsigned M = a_block_mtx.NCol();
-        constexpr unsigned N = b_block_mtx.NCol();
-        constexpr unsigned K = a_block_mtx.NRow();
+        constexpr index_t M = a_block_mtx.NCol();
+        constexpr index_t N = b_block_mtx.NCol();
+        constexpr index_t K = a_block_mtx.NRow();

-        constexpr unsigned MPerThread = c_thread_mtx.NRow();
-        constexpr unsigned NPerThread = c_thread_mtx.NCol();
+        constexpr index_t MPerThread = c_thread_mtx.NRow();
+        constexpr index_t NPerThread = c_thread_mtx.NCol();

         // thread A-sub, B-sub, C-sub
         constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
@@ -737,15 +736,15 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
         FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

-        constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
-        constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
+        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
+        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

-        constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
-        constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
+        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
+        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

 #pragma unroll
         // loop over k
-        for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
+        for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
         {
             // C-sub(s) in first row-wise subblock of C
             {
@@ -779,7 +778,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
 #pragma unroll
                 // copy next B-sub, and do GEMM
-                for(unsigned n_repeat = 1; n_repeat < NRepeat; ++n_repeat)
+                for(index_t n_repeat = 1; n_repeat < NRepeat; ++n_repeat)
                 {
                     threadwise_matrix_copy(b_block_mtx,
@@ -805,7 +804,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
 #pragma unroll
             // loop over rest of row-wise subblock
             // all B-sub(s) has been copied, so only A-sub(s) need to be copied
-            for(unsigned m_repeat = 1; m_repeat < MRepeat; ++m_repeat)
+            for(index_t m_repeat = 1; m_repeat < MRepeat; ++m_repeat)
             {
                 // copy a A-sub
                 threadwise_matrix_copy(
@@ -817,7 +816,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                     a_thread_sub_mtx.GetLengths());

                 // do some GEMMs
-                for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
+                for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
                 {
                     threadwise_gemm(a_thread_sub_mtx,
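The comment ahead of BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 notes that the index calculation gets much cheaper when MPerThreadSubC, NPerThreadSubC and the cluster sizes are powers of two. A hedged illustration of why, with assumed cluster sizes (the real values come from the kernel configuration, not from this diff):

    // Illustration only (values assumed): when the cluster sizes are powers of
    // two, the div/mod in GetBeginOfThreadMatrixC reduce to shifts and masks.
    #include <cassert>
    using index_t = unsigned int;

    int main()
    {
        constexpr index_t NLevel0Cluster         = 4;  // assumed, power of two
        constexpr index_t ThreadPerLevel0Cluster = 16; // MLevel0Cluster * NLevel0Cluster

        for(index_t thread_id = 0; thread_id < 64; ++thread_id)
        {
            const index_t level0_id   = thread_id % ThreadPerLevel0Cluster;
            const index_t level0_m_id = level0_id / NLevel0Cluster;
            const index_t level0_n_id = level0_id % NLevel0Cluster;

            // shift/mask equivalents the compiler can emit for power-of-two sizes
            assert(level0_id   == (thread_id & (ThreadPerLevel0Cluster - 1)));
            assert(level0_m_id == (level0_id >> 2)); // 2 == log2(NLevel0Cluster)
            assert(level0_n_id == (level0_id & (NLevel0Cluster - 1)));
        }
        return 0;
    }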
src/include/common.hip.hpp
@@ -5,9 +5,9 @@
 #include "Array.hip.hpp"
 #include "functional.hip.hpp"

-__device__ unsigned get_thread_local_1d_id() { return threadIdx.x; }
+__device__ index_t get_thread_local_1d_id() { return threadIdx.x; }

-__device__ unsigned get_block_1d_id() { return blockIdx.x; }
+__device__ index_t get_block_1d_id() { return blockIdx.x; }

 template <class T1, class T2>
 struct is_same
@@ -35,7 +35,7 @@ __host__ __device__ constexpr T min(T a, T b)
 }
 #endif

-__host__ __device__ constexpr unsigned integer_divide_ceil(unsigned a, unsigned b)
+__host__ __device__ constexpr index_t integer_divide_ceil(index_t a, index_t b)
 {
     return (a + b - 1) / b;
 }
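integer_divide_ceil is the rounding-up division used for nloop_d3 and the per-thread work splits above; after this change it operates on index_t. A small host-side check, with made-up values and the HIP attributes dropped so it compiles stand-alone:

    // Hedged sketch: sanity check of the rounding-up division (values made up).
    #include <cassert>
    #include <cstdint>

    using index_t = uint32_t;

    constexpr index_t integer_divide_ceil(index_t a, index_t b) { return (a + b - 1) / b; }

    int main()
    {
        assert(integer_divide_ceil(32, 8) == 4); // exact division
        assert(integer_divide_ceil(33, 8) == 5); // any remainder rounds up
        assert(integer_divide_ceil(1, 8) == 1);
        return 0;
    }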
src/include/config.h.in
@@ -11,3 +11,5 @@
 #include "nvToolsExt.h"
 #include "helper_cuda.h"
 #endif
+
+using index_t = uint32_t;
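This new alias is the point of the whole commit: every other hunk mechanically swaps unsigned for index_t, so the integer width used for tensor and matrix indexing is chosen in one place. A hedged sketch of what that buys; linear_offset is a hypothetical helper and the 64-bit variant is an assumption, not something this commit does:

    // What the alias buys (illustration only).
    using index_t = uint32_t;    // the line this commit adds to config.h.in
    // using index_t = uint64_t; // hypothetical alternative for very large tensors

    __device__ index_t linear_offset(index_t row, index_t col, index_t stride)
    {
        return row * stride + col; // all index arithmetic stays in the chosen width
    }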
src/include/constant_integral.hip.hpp
@@ -8,5 +8,5 @@ struct integral_constant
     __host__ __device__ constexpr T Get() const { return value; }
 };

-template <unsigned N>
-using Number = integral_constant<unsigned, N>;
+template <index_t N>
+using Number = integral_constant<index_t, N>;
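Number<N> is the compile-time index wrapper used throughout the descriptors above (for example CopyLengths{}.Get(I0) and Number<MPerThreadSubC>{}). A minimal sketch of how it composes, assuming the integral_constant shown in this header and the index_t alias from config.h.in:

    // Minimal sketch (assumed stand-alone excerpt, compiled with hipcc):
    // Number<N> carries an index in the type system, so descriptor getters
    // resolve to constants at compile time.
    using index_t = unsigned int; // stand-in for the alias from config.h.in

    template <class T, T v>
    struct integral_constant
    {
        static constexpr T value = v;
        __host__ __device__ constexpr T Get() const { return value; }
    };

    template <index_t N>
    using Number = integral_constant<index_t, N>;

    constexpr auto I0 = Number<0>{};
    static_assert(Number<3>{}.Get() == 3, "known at compile time");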
src/include/data_type.hip.hpp
 #pragma once
 #include "config.h"

-template <class T, unsigned N>
+template <class T, index_t N>
 struct vector_type
 {
 };
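The primary template is empty; Blockwise4dTensorCopy3 above relies on specializations that provide a MemoryType for vectorized loads (vector_type<Float, DataPerRead>::MemoryType). A hedged sketch of what such a specialization could look like; the float4 mapping is an assumption, since the actual specializations fall outside the lines shown in this diff:

    // Assumed shape of a specialization (not shown in this hunk): map (float, 4)
    // to the built-in float4 so a 4-wide read becomes a single load.
    template <>
    struct vector_type<float, 4>
    {
        using MemoryType = float4; // HIP/CUDA built-in vector type
    };

    // usage, as in Blockwise4dTensorCopy3:
    // using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;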