Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
766b0a9e
Commit
766b0a9e
authored
Mar 24, 2019
by
Chao Liu
Browse files
experimenting
parent
f35c64eb
Changes
33
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1413 additions
and
1376 deletions
+1413
-1376
driver/device_direct_convolution_1.hpp
driver/device_direct_convolution_1.hpp
+14
-14
driver/device_direct_convolution_2_nchw_kcyx_nkhw.hpp
driver/device_direct_convolution_2_nchw_kcyx_nkhw.hpp
+35
-35
driver/device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp
...device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp
+80
-80
driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp
driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp
+193
-193
driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp
...ice_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp
+154
-154
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp
+128
-98
driver/driver.hip.cpp
driver/driver.hip.cpp
+185
-185
src/include/Array.hip.hpp
src/include/Array.hip.hpp
+4
-4
src/include/ConstantMatrixDescriptor.hip.hpp
src/include/ConstantMatrixDescriptor.hip.hpp
+14
-10
src/include/ConstantTensorDescriptor.hip.hpp
src/include/ConstantTensorDescriptor.hip.hpp
+40
-36
src/include/Sequence.hip.hpp
src/include/Sequence.hip.hpp
+17
-17
src/include/blockwise_2d_tensor_op.hip.hpp
src/include/blockwise_2d_tensor_op.hip.hpp
+126
-128
src/include/blockwise_4d_tensor_op.hip.hpp
src/include/blockwise_4d_tensor_op.hip.hpp
+105
-105
src/include/blockwise_batched_gemm.hip.hpp
src/include/blockwise_batched_gemm.hip.hpp
+150
-150
src/include/blockwise_direct_convolution.hip.hpp
src/include/blockwise_direct_convolution.hip.hpp
+28
-28
src/include/blockwise_gemm.hip.hpp
src/include/blockwise_gemm.hip.hpp
+132
-133
src/include/common.hip.hpp
src/include/common.hip.hpp
+3
-3
src/include/config.h.in
src/include/config.h.in
+2
-0
src/include/constant_integral.hip.hpp
src/include/constant_integral.hip.hpp
+2
-2
src/include/data_type.hip.hpp
src/include/data_type.hip.hpp
+1
-1
No files found.
driver/device_direct_convolution_1.hpp
View file @
766b0a9e
...
@@ -10,7 +10,7 @@ void device_direct_convolution_1(InDesc,
...
@@ -10,7 +10,7 @@ void device_direct_convolution_1(InDesc,
const
Tensor
<
T
>&
wei
,
const
Tensor
<
T
>&
wei
,
OutDesc
,
OutDesc
,
Tensor
<
T
>&
out
,
Tensor
<
T
>&
out
,
unsigned
nrepeat
)
index_t
nrepeat
)
{
{
std
::
size_t
data_sz
=
sizeof
(
T
);
std
::
size_t
data_sz
=
sizeof
(
T
);
DeviceMem
in_device_buf
(
data_sz
*
in
.
mDesc
.
GetElementSpace
());
DeviceMem
in_device_buf
(
data_sz
*
in
.
mDesc
.
GetElementSpace
());
...
@@ -34,28 +34,28 @@ void device_direct_convolution_1(InDesc,
...
@@ -34,28 +34,28 @@ void device_direct_convolution_1(InDesc,
#if 1
#if 1
// 3x3, 34x34
// 3x3, 34x34
constexpr
unsigned
NPerBlock
=
2
;
constexpr
index_t
NPerBlock
=
2
;
constexpr
unsigned
KPerBlock
=
16
;
constexpr
index_t
KPerBlock
=
16
;
constexpr
unsigned
CPerBlock
=
2
;
constexpr
index_t
CPerBlock
=
2
;
constexpr
unsigned
HoPerBlock
=
4
;
constexpr
index_t
HoPerBlock
=
4
;
constexpr
unsigned
WoPerBlock
=
32
;
constexpr
index_t
WoPerBlock
=
32
;
constexpr
unsigned
NPerThread
=
2
;
constexpr
index_t
NPerThread
=
2
;
constexpr
unsigned
KPerThread
=
4
;
constexpr
index_t
KPerThread
=
4
;
constexpr
unsigned
CPerThread
=
2
;
constexpr
index_t
CPerThread
=
2
;
constexpr
unsigned
HoPerThread
=
2
;
constexpr
index_t
HoPerThread
=
2
;
constexpr
unsigned
WoPerThread
=
2
;
constexpr
index_t
WoPerThread
=
2
;
constexpr
unsigned
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#endif
#endif
constexpr
unsigned
GridSize
=
constexpr
index_t
GridSize
=
(
out_desc
.
GetLength
(
I0
)
/
NPerBlock
)
*
(
out_desc
.
GetLength
(
I1
)
/
KPerBlock
)
*
(
out_desc
.
GetLength
(
I0
)
/
NPerBlock
)
*
(
out_desc
.
GetLength
(
I1
)
/
KPerBlock
)
*
(
out_desc
.
GetLength
(
I2
)
/
HoPerBlock
)
*
(
out_desc
.
GetLength
(
I3
)
/
WoPerBlock
);
(
out_desc
.
GetLength
(
I2
)
/
HoPerBlock
)
*
(
out_desc
.
GetLength
(
I3
)
/
WoPerBlock
);
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
for
(
unsigned
i
=
0
;
i
<
nrepeat
;
++
i
)
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
{
{
float
time
=
launch_kernel
(
gridwise_direct_convolution_1
<
T
,
float
time
=
launch_kernel
(
gridwise_direct_convolution_1
<
T
,
InDesc
,
InDesc
,
...
...
driver/device_direct_convolution_2_nchw_kcyx_nkhw.hpp
View file @
766b0a9e
...
@@ -10,7 +10,7 @@ void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
...
@@ -10,7 +10,7 @@ void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
const
Tensor
<
T
>&
wei
,
const
Tensor
<
T
>&
wei
,
OutDesc
,
OutDesc
,
Tensor
<
T
>&
out
,
Tensor
<
T
>&
out
,
unsigned
nrepeat
)
index_t
nrepeat
)
{
{
std
::
size_t
data_sz
=
sizeof
(
T
);
std
::
size_t
data_sz
=
sizeof
(
T
);
DeviceMem
in_device_buf
(
data_sz
*
in
.
mDesc
.
GetElementSpace
());
DeviceMem
in_device_buf
(
data_sz
*
in
.
mDesc
.
GetElementSpace
());
...
@@ -34,49 +34,49 @@ void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
...
@@ -34,49 +34,49 @@ void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
#if 1
#if 1
// 3x3, 34x34, 128 thread
// 3x3, 34x34, 128 thread
constexpr
unsigned
NPerBlock
=
2
;
constexpr
index_t
NPerBlock
=
2
;
constexpr
unsigned
KPerBlock
=
32
;
constexpr
index_t
KPerBlock
=
32
;
constexpr
unsigned
CPerBlock
=
4
;
constexpr
index_t
CPerBlock
=
4
;
constexpr
unsigned
HoPerBlock
=
2
;
constexpr
index_t
HoPerBlock
=
2
;
constexpr
unsigned
WoPerBlock
=
32
;
constexpr
index_t
WoPerBlock
=
32
;
constexpr
unsigned
NPerThread
=
2
;
constexpr
index_t
NPerThread
=
2
;
constexpr
unsigned
KPerThread
=
4
;
constexpr
index_t
KPerThread
=
4
;
constexpr
unsigned
CPerThread
=
2
;
constexpr
index_t
CPerThread
=
2
;
constexpr
unsigned
HoPerThread
=
2
;
constexpr
index_t
HoPerThread
=
2
;
constexpr
unsigned
WoPerThread
=
2
;
constexpr
index_t
WoPerThread
=
2
;
constexpr
unsigned
InBlockCopyDataPerRead
=
2
;
constexpr
index_t
InBlockCopyDataPerRead
=
2
;
constexpr
unsigned
WeiBlockCopyDataPerRead
=
4
;
constexpr
index_t
WeiBlockCopyDataPerRead
=
4
;
constexpr
unsigned
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#elif 1
#elif 1
// 3x3, 34x34, 128 thread, fp16
// 3x3, 34x34, 128 thread, fp16
constexpr
unsigned
NPerBlock
=
2
;
constexpr
index_t
NPerBlock
=
2
;
constexpr
unsigned
KPerBlock
=
32
;
constexpr
index_t
KPerBlock
=
32
;
constexpr
unsigned
CPerBlock
=
4
;
constexpr
index_t
CPerBlock
=
4
;
constexpr
unsigned
HoPerBlock
=
2
;
constexpr
index_t
HoPerBlock
=
2
;
constexpr
unsigned
WoPerBlock
=
32
;
constexpr
index_t
WoPerBlock
=
32
;
constexpr
unsigned
NPerThread
=
2
;
constexpr
index_t
NPerThread
=
2
;
constexpr
unsigned
KPerThread
=
4
;
constexpr
index_t
KPerThread
=
4
;
constexpr
unsigned
CPerThread
=
2
;
constexpr
index_t
CPerThread
=
2
;
constexpr
unsigned
HoPerThread
=
2
;
constexpr
index_t
HoPerThread
=
2
;
constexpr
unsigned
WoPerThread
=
2
;
constexpr
index_t
WoPerThread
=
2
;
constexpr
unsigned
InBlockCopyDataPerRead
=
2
;
constexpr
index_t
InBlockCopyDataPerRead
=
2
;
constexpr
unsigned
WeiBlockCopyDataPerRead
=
4
;
constexpr
index_t
WeiBlockCopyDataPerRead
=
4
;
constexpr
unsigned
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#endif
#endif
constexpr
unsigned
GridSize
=
constexpr
index_t
GridSize
=
(
out_desc
.
GetLength
(
I0
)
/
NPerBlock
)
*
(
out_desc
.
GetLength
(
I1
)
/
KPerBlock
)
*
(
out_desc
.
GetLength
(
I0
)
/
NPerBlock
)
*
(
out_desc
.
GetLength
(
I1
)
/
KPerBlock
)
*
(
out_desc
.
GetLength
(
I2
)
/
HoPerBlock
)
*
(
out_desc
.
GetLength
(
I3
)
/
WoPerBlock
);
(
out_desc
.
GetLength
(
I2
)
/
HoPerBlock
)
*
(
out_desc
.
GetLength
(
I3
)
/
WoPerBlock
);
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
for
(
unsigned
i
=
0
;
i
<
nrepeat
;
++
i
)
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
{
{
float
time
=
float
time
=
launch_kernel
(
gridwise_direct_convolution_2_nchw_kcyx_nkhw
<
T
,
launch_kernel
(
gridwise_direct_convolution_2_nchw_kcyx_nkhw
<
T
,
...
...
driver/device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp
View file @
766b0a9e
...
@@ -10,13 +10,13 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
...
@@ -10,13 +10,13 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
const
Tensor
<
TInWei
>&
wei_kcyx
,
const
Tensor
<
TInWei
>&
wei_kcyx
,
OutDesc
,
OutDesc
,
Tensor
<
TOut
>&
out_nkhw
,
Tensor
<
TOut
>&
out_nkhw
,
unsigned
nrepeat
)
index_t
nrepeat
)
{
{
// this suppose in / wei data type is int8x4
// this suppose in / wei data type is int8x4
constexpr
unsigned
NVector
=
4
;
constexpr
index_t
NVector
=
4
;
using
accum_t
=
int32_t
;
using
accum_t
=
int32_t
;
using
vector_t
=
vector_type
<
TInWei
,
NVector
>
;
using
vector_t
=
vector_type
<
TInWei
,
NVector
>
;
using
vector_mem_t
=
typename
vector_t
::
MemoryType
;
using
vector_mem_t
=
typename
vector_t
::
MemoryType
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
@@ -27,17 +27,17 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
...
@@ -27,17 +27,17 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
constexpr
auto
wei_kcyx_desc
=
WeiDesc
{};
constexpr
auto
wei_kcyx_desc
=
WeiDesc
{};
constexpr
auto
out_nkhw_desc
=
OutDesc
{};
constexpr
auto
out_nkhw_desc
=
OutDesc
{};
constexpr
unsigned
Hi
=
in_nchw_desc
.
GetLength
(
I2
);
constexpr
index_t
Hi
=
in_nchw_desc
.
GetLength
(
I2
);
constexpr
unsigned
Wi
=
in_nchw_desc
.
GetLength
(
I3
);
constexpr
index_t
Wi
=
in_nchw_desc
.
GetLength
(
I3
);
constexpr
unsigned
N
=
out_nkhw_desc
.
GetLength
(
I0
);
constexpr
index_t
N
=
out_nkhw_desc
.
GetLength
(
I0
);
constexpr
unsigned
Ho
=
out_nkhw_desc
.
GetLength
(
I2
);
constexpr
index_t
Ho
=
out_nkhw_desc
.
GetLength
(
I2
);
constexpr
unsigned
Wo
=
out_nkhw_desc
.
GetLength
(
I3
);
constexpr
index_t
Wo
=
out_nkhw_desc
.
GetLength
(
I3
);
constexpr
unsigned
K
=
wei_kcyx_desc
.
GetLength
(
I0
);
constexpr
index_t
K
=
wei_kcyx_desc
.
GetLength
(
I0
);
constexpr
unsigned
C
=
wei_kcyx_desc
.
GetLength
(
I1
);
constexpr
index_t
C
=
wei_kcyx_desc
.
GetLength
(
I1
);
constexpr
unsigned
Y
=
wei_kcyx_desc
.
GetLength
(
I2
);
constexpr
index_t
Y
=
wei_kcyx_desc
.
GetLength
(
I2
);
constexpr
unsigned
X
=
wei_kcyx_desc
.
GetLength
(
I3
);
constexpr
index_t
X
=
wei_kcyx_desc
.
GetLength
(
I3
);
// vectorized input
// vectorized input
auto
in_nchw_vec_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
N
,
C
/
NVector
,
Hi
,
Wi
>
{});
auto
in_nchw_vec_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
N
,
C
/
NVector
,
Hi
,
Wi
>
{});
...
@@ -96,84 +96,84 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
...
@@ -96,84 +96,84 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
#if 0
#if 0
// 3x3, 34x34, 128 thread, fp32, vector = 1
// 3x3, 34x34, 128 thread, fp32, vector = 1
constexpr
unsigned
NPerBlock = 2;
constexpr
index_t
NPerBlock = 2;
constexpr
unsigned
KPerBlock = 32;
constexpr
index_t
KPerBlock = 32;
constexpr
unsigned
CPerBlock = 4;
constexpr
index_t
CPerBlock = 4;
constexpr
unsigned
HoPerBlock = 2;
constexpr
index_t
HoPerBlock = 2;
constexpr
unsigned
WoPerBlock = 32;
constexpr
index_t
WoPerBlock = 32;
constexpr
unsigned
NPerThread = 2;
constexpr
index_t
NPerThread = 2;
constexpr
unsigned
KPerThread = 4;
constexpr
index_t
KPerThread = 4;
constexpr
unsigned
CPerThread = 2;
constexpr
index_t
CPerThread = 2;
constexpr
unsigned
HoPerThread = 2;
constexpr
index_t
HoPerThread = 2;
constexpr
unsigned
WoPerThread = 2;
constexpr
index_t
WoPerThread = 2;
constexpr
unsigned
InBlockCopyDataPerRead = 2;
constexpr
index_t
InBlockCopyDataPerRead = 2;
constexpr
unsigned
WeiBlockCopyDataPerRead = 2;
constexpr
index_t
WeiBlockCopyDataPerRead = 2;
constexpr
unsigned
BlockSize = 128;
constexpr
index_t
BlockSize = 128;
#elif
0
#elif
0
// 3x3, 34x34, 128 thread, fp32, vector = 2
// 3x3, 34x34, 128 thread, fp32, vector = 2
constexpr
unsigned
NPerBlock
=
2
;
constexpr
index_t
NPerBlock
=
2
;
constexpr
unsigned
KPerBlock
=
32
;
constexpr
index_t
KPerBlock
=
32
;
constexpr
unsigned
CPerBlock
=
2
;
constexpr
index_t
CPerBlock
=
2
;
constexpr
unsigned
HoPerBlock
=
2
;
constexpr
index_t
HoPerBlock
=
2
;
constexpr
unsigned
WoPerBlock
=
32
;
constexpr
index_t
WoPerBlock
=
32
;
constexpr
unsigned
NPerThread
=
2
;
constexpr
index_t
NPerThread
=
2
;
constexpr
unsigned
KPerThread
=
4
;
constexpr
index_t
KPerThread
=
4
;
constexpr
unsigned
CPerThread
=
1
;
constexpr
index_t
CPerThread
=
1
;
constexpr
unsigned
HoPerThread
=
2
;
constexpr
index_t
HoPerThread
=
2
;
constexpr
unsigned
WoPerThread
=
2
;
constexpr
index_t
WoPerThread
=
2
;
constexpr
unsigned
InBlockCopyDataPerRead
=
2
;
constexpr
index_t
InBlockCopyDataPerRead
=
2
;
constexpr
unsigned
WeiBlockCopyDataPerRead
=
2
;
constexpr
index_t
WeiBlockCopyDataPerRead
=
2
;
constexpr
unsigned
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#elif 0
#elif 0
// 3x3, 34x34, 128 thread, int8, vector = 4
// 3x3, 34x34, 128 thread, int8, vector = 4
constexpr
unsigned
NPerBlock
=
2
;
constexpr
index_t
NPerBlock
=
2
;
constexpr
unsigned
KPerBlock
=
32
;
constexpr
index_t
KPerBlock
=
32
;
constexpr
unsigned
CPerBlock
=
8
;
constexpr
index_t
CPerBlock
=
8
;
constexpr
unsigned
HoPerBlock
=
4
;
constexpr
index_t
HoPerBlock
=
4
;
constexpr
unsigned
WoPerBlock
=
32
;
constexpr
index_t
WoPerBlock
=
32
;
constexpr
unsigned
NPerThread
=
1
;
constexpr
index_t
NPerThread
=
1
;
constexpr
unsigned
KPerThread
=
8
;
constexpr
index_t
KPerThread
=
8
;
constexpr
unsigned
CPerThread
=
2
;
constexpr
index_t
CPerThread
=
2
;
constexpr
unsigned
HoPerThread
=
4
;
constexpr
index_t
HoPerThread
=
4
;
constexpr
unsigned
WoPerThread
=
2
;
constexpr
index_t
WoPerThread
=
2
;
constexpr
unsigned
InBlockCopyDataPerRead
=
2
;
constexpr
index_t
InBlockCopyDataPerRead
=
2
;
constexpr
unsigned
WeiBlockCopyDataPerRead
=
2
;
constexpr
index_t
WeiBlockCopyDataPerRead
=
2
;
constexpr
unsigned
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#elif 1
#elif 1
// 1x1, 32x32, 128 thread, int8, vector = 4
// 1x1, 32x32, 128 thread, int8, vector = 4
constexpr
unsigned
NPerBlock
=
1
;
constexpr
index_t
NPerBlock
=
1
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
64
;
constexpr
unsigned
CPerBlock
=
16
;
constexpr
index_t
CPerBlock
=
16
;
constexpr
unsigned
HoPerBlock
=
4
;
constexpr
index_t
HoPerBlock
=
4
;
constexpr
unsigned
WoPerBlock
=
32
;
constexpr
index_t
WoPerBlock
=
32
;
constexpr
unsigned
NPerThread
=
1
;
constexpr
index_t
NPerThread
=
1
;
constexpr
unsigned
KPerThread
=
8
;
constexpr
index_t
KPerThread
=
8
;
constexpr
unsigned
CPerThread
=
2
;
constexpr
index_t
CPerThread
=
2
;
constexpr
unsigned
HoPerThread
=
4
;
constexpr
index_t
HoPerThread
=
4
;
constexpr
unsigned
WoPerThread
=
2
;
constexpr
index_t
WoPerThread
=
2
;
constexpr
unsigned
InBlockCopyDataPerRead
=
2
;
constexpr
index_t
InBlockCopyDataPerRead
=
2
;
constexpr
unsigned
WeiBlockCopyDataPerRead
=
2
;
constexpr
index_t
WeiBlockCopyDataPerRead
=
2
;
constexpr
unsigned
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#endif
#endif
constexpr
unsigned
GridSize
=
constexpr
index_t
GridSize
=
(
N
/
NPerBlock
)
*
(
K
/
KPerBlock
)
*
(
Ho
/
HoPerBlock
)
*
(
Wo
/
WoPerBlock
);
(
N
/
NPerBlock
)
*
(
K
/
KPerBlock
)
*
(
Ho
/
HoPerBlock
)
*
(
Wo
/
WoPerBlock
);
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
for
(
unsigned
i
=
0
;
i
<
nrepeat
;
++
i
)
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
{
{
float
time
=
launch_kernel
(
float
time
=
launch_kernel
(
gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw
<
TInWei
,
gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw
<
TInWei
,
...
...
driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp
View file @
766b0a9e
...
@@ -10,7 +10,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
...
@@ -10,7 +10,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
const
Tensor
<
T
>&
wei_kcyx
,
const
Tensor
<
T
>&
wei_kcyx
,
OutDesc
,
OutDesc
,
Tensor
<
T
>&
out_nkhw
,
Tensor
<
T
>&
out_nkhw
,
unsigned
nrepeat
)
index_t
nrepeat
)
{
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
@@ -21,17 +21,17 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
...
@@ -21,17 +21,17 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr
auto
wei_kcyx_desc
=
WeiDesc
{};
constexpr
auto
wei_kcyx_desc
=
WeiDesc
{};
constexpr
auto
out_nkhw_desc
=
OutDesc
{};
constexpr
auto
out_nkhw_desc
=
OutDesc
{};
constexpr
unsigned
Hi
=
in_nchw_desc
.
GetLength
(
I2
);
constexpr
index_t
Hi
=
in_nchw_desc
.
GetLength
(
I2
);
constexpr
unsigned
Wi
=
in_nchw_desc
.
GetLength
(
I3
);
constexpr
index_t
Wi
=
in_nchw_desc
.
GetLength
(
I3
);
constexpr
unsigned
N
=
out_nkhw_desc
.
GetLength
(
I0
);
constexpr
index_t
N
=
out_nkhw_desc
.
GetLength
(
I0
);
constexpr
unsigned
Ho
=
out_nkhw_desc
.
GetLength
(
I2
);
constexpr
index_t
Ho
=
out_nkhw_desc
.
GetLength
(
I2
);
constexpr
unsigned
Wo
=
out_nkhw_desc
.
GetLength
(
I3
);
constexpr
index_t
Wo
=
out_nkhw_desc
.
GetLength
(
I3
);
constexpr
unsigned
K
=
wei_kcyx_desc
.
GetLength
(
I0
);
constexpr
index_t
K
=
wei_kcyx_desc
.
GetLength
(
I0
);
constexpr
unsigned
C
=
wei_kcyx_desc
.
GetLength
(
I1
);
constexpr
index_t
C
=
wei_kcyx_desc
.
GetLength
(
I1
);
constexpr
unsigned
Y
=
wei_kcyx_desc
.
GetLength
(
I2
);
constexpr
index_t
Y
=
wei_kcyx_desc
.
GetLength
(
I2
);
constexpr
unsigned
X
=
wei_kcyx_desc
.
GetLength
(
I3
);
constexpr
index_t
X
=
wei_kcyx_desc
.
GetLength
(
I3
);
// reorder weight
// reorder weight
auto
wei_cyxk_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
C
,
Y
,
X
,
K
>
{});
auto
wei_cyxk_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
C
,
Y
,
X
,
K
>
{});
...
@@ -76,218 +76,218 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
...
@@ -76,218 +76,218 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
#if 0
#if 0
// for 3x3, 34x34
// for 3x3, 34x34
constexpr
unsigned
NPerBlock = 16;
constexpr
index_t
NPerBlock = 16;
constexpr
unsigned
KPerBlock = 64;
constexpr
index_t
KPerBlock = 64;
constexpr
unsigned
CPerBlock = 4;
constexpr
index_t
CPerBlock = 4;
constexpr
unsigned
HoPerBlock = 2;
constexpr
index_t
HoPerBlock = 2;
constexpr
unsigned
WoPerBlock = 4;
constexpr
index_t
WoPerBlock = 4;
constexpr
unsigned
NPerThread = 8;
constexpr
index_t
NPerThread = 8;
constexpr
unsigned
KPerThread = 8;
constexpr
index_t
KPerThread = 8;
constexpr
unsigned
HoPerThread = 1;
constexpr
index_t
HoPerThread = 1;
constexpr
unsigned
WoPerThread = 1;
constexpr
index_t
WoPerThread = 1;
constexpr
unsigned
InBlockCopy_ThreadPerDimC = 4;
constexpr
index_t
InBlockCopy_ThreadPerDimC = 4;
constexpr
unsigned
InBlockCopy_ThreadPerDimH = 4;
constexpr
index_t
InBlockCopy_ThreadPerDimH = 4;
constexpr
unsigned
InBlockCopy_ThreadPerDimW = 2;
constexpr
index_t
InBlockCopy_ThreadPerDimW = 2;
constexpr
unsigned
InBlockCopy_ThreadPerDimN = 4;
constexpr
index_t
InBlockCopy_ThreadPerDimN = 4;
constexpr
unsigned
InBlockCopyDataPerRead = 4;
constexpr
index_t
InBlockCopyDataPerRead = 4;
constexpr
unsigned
WeiBlockCopyDataPerRead = 4;
constexpr
index_t
WeiBlockCopyDataPerRead = 4;
constexpr
unsigned
GemmMPerThreadSubC = 4;
constexpr
index_t
GemmMPerThreadSubC = 4;
constexpr
unsigned
GemmNPerThreadSubC = 4;
constexpr
index_t
GemmNPerThreadSubC = 4;
constexpr
unsigned
GemmMLevel0Cluster = 4;
constexpr
index_t
GemmMLevel0Cluster = 4;
constexpr
unsigned
GemmNLevel0Cluster = 2;
constexpr
index_t
GemmNLevel0Cluster = 2;
constexpr
unsigned
GemmMLevel1Cluster = 2;
constexpr
index_t
GemmMLevel1Cluster = 2;
constexpr
unsigned
GemmNLevel1Cluster = 4;
constexpr
index_t
GemmNLevel1Cluster = 4;
constexpr
unsigned
GemmKPerThreadLoop = 1;
constexpr
index_t
GemmKPerThreadLoop = 1;
constexpr
unsigned
OutThreadCopyDataPerWrite = 2;
constexpr
index_t
OutThreadCopyDataPerWrite = 2;
constexpr
unsigned
BlockSize = 128;
constexpr
index_t
BlockSize = 128;
#elif
0
#elif
0
// for 5x5, 36x36
// for 5x5, 36x36
constexpr
unsigned
NPerBlock
=
16
;
constexpr
index_t
NPerBlock
=
16
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
64
;
constexpr
unsigned
CPerBlock
=
2
;
constexpr
index_t
CPerBlock
=
2
;
constexpr
unsigned
HoPerBlock
=
2
;
constexpr
index_t
HoPerBlock
=
2
;
constexpr
unsigned
WoPerBlock
=
4
;
constexpr
index_t
WoPerBlock
=
4
;
constexpr
unsigned
NPerThread
=
8
;
constexpr
index_t
NPerThread
=
8
;
constexpr
unsigned
KPerThread
=
8
;
constexpr
index_t
KPerThread
=
8
;
constexpr
unsigned
HoPerThread
=
1
;
constexpr
index_t
HoPerThread
=
1
;
constexpr
unsigned
WoPerThread
=
1
;
constexpr
index_t
WoPerThread
=
1
;
constexpr
unsigned
WeiBlockCopyThreadPerDim0
=
4
;
constexpr
index_t
WeiBlockCopyThreadPerDim0
=
4
;
constexpr
unsigned
WeiBlockCopyThreadPerDim1
=
32
;
constexpr
index_t
WeiBlockCopyThreadPerDim1
=
32
;
constexpr
unsigned
InBlockCopy_ThreadPerDimC
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimC
=
2
;
constexpr
unsigned
InBlockCopy_ThreadPerDimH
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimH
=
2
;
constexpr
unsigned
InBlockCopy_ThreadPerDimW
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimW
=
4
;
constexpr
unsigned
InBlockCopy_ThreadPerDimN
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimN
=
4
;
constexpr
unsigned
InBlockCopyDataPerRead
=
4
;
constexpr
index_t
InBlockCopyDataPerRead
=
4
;
constexpr
unsigned
WeiBlockCopyDataPerRead
=
2
;
constexpr
index_t
WeiBlockCopyDataPerRead
=
2
;
constexpr
unsigned
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
unsigned
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
unsigned
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
unsigned
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
unsigned
GemmMLevel1Cluster
=
2
;
constexpr
index_t
GemmMLevel1Cluster
=
2
;
constexpr
unsigned
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
unsigned
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
unsigned
OutThreadCopyDataPerWrite
=
2
;
constexpr
index_t
OutThreadCopyDataPerWrite
=
2
;
constexpr
unsigned
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#elif 0
#elif 0
// 3x3 58x58, NKC = 64, 64, 256
// 3x3 58x58, NKC = 64, 64, 256
constexpr
unsigned
NPerBlock
=
16
;
constexpr
index_t
NPerBlock
=
16
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
64
;
constexpr
unsigned
CPerBlock
=
4
;
constexpr
index_t
CPerBlock
=
4
;
constexpr
unsigned
HoPerBlock
=
2
;
constexpr
index_t
HoPerBlock
=
2
;
constexpr
unsigned
WoPerBlock
=
4
;
constexpr
index_t
WoPerBlock
=
4
;
constexpr
unsigned
NPerThread
=
4
;
constexpr
index_t
NPerThread
=
4
;
constexpr
unsigned
KPerThread
=
16
;
constexpr
index_t
KPerThread
=
16
;
constexpr
unsigned
CPerThread
=
1
;
constexpr
index_t
CPerThread
=
1
;
constexpr
unsigned
HoPerThread
=
1
;
constexpr
index_t
HoPerThread
=
1
;
constexpr
unsigned
WoPerThread
=
1
;
constexpr
index_t
WoPerThread
=
1
;
constexpr
unsigned
WeiBlockCopyThreadPerDim0
=
4
;
constexpr
index_t
WeiBlockCopyThreadPerDim0
=
4
;
constexpr
unsigned
WeiBlockCopyThreadPerDim1
=
32
;
constexpr
index_t
WeiBlockCopyThreadPerDim1
=
32
;
constexpr
unsigned
InBlockCopyDataPerRead
=
2
;
// not used, yet
constexpr
index_t
InBlockCopyDataPerRead
=
2
;
// not used, yet
constexpr
unsigned
WeiBlockCopyDataPerRead
=
4
;
constexpr
index_t
WeiBlockCopyDataPerRead
=
4
;
constexpr
unsigned
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#elif 0
#elif 0
// 3x3 58x58, NKC = 16,256,128
// 3x3 58x58, NKC = 16,256,128
constexpr
unsigned
NPerBlock
=
8
;
constexpr
index_t
NPerBlock
=
8
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
64
;
constexpr
unsigned
CPerBlock
=
2
;
constexpr
index_t
CPerBlock
=
2
;
constexpr
unsigned
HoPerBlock
=
4
;
constexpr
index_t
HoPerBlock
=
4
;
constexpr
unsigned
WoPerBlock
=
4
;
constexpr
index_t
WoPerBlock
=
4
;
constexpr
unsigned
NPerThread
=
4
;
constexpr
index_t
NPerThread
=
4
;
constexpr
unsigned
KPerThread
=
16
;
constexpr
index_t
KPerThread
=
16
;
constexpr
unsigned
CPerThread
=
1
;
constexpr
index_t
CPerThread
=
1
;
constexpr
unsigned
HoPerThread
=
1
;
constexpr
index_t
HoPerThread
=
1
;
constexpr
unsigned
WoPerThread
=
1
;
constexpr
index_t
WoPerThread
=
1
;
constexpr
unsigned
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#elif 0
#elif 0
// for 7x7, 38x38
// for 7x7, 38x38
constexpr
unsigned
NPerBlock
=
8
;
constexpr
index_t
NPerBlock
=
8
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
64
;
constexpr
unsigned
CPerBlock
=
1
;
constexpr
index_t
CPerBlock
=
1
;
constexpr
unsigned
HoPerBlock
=
4
;
constexpr
index_t
HoPerBlock
=
4
;
constexpr
unsigned
WoPerBlock
=
4
;
constexpr
index_t
WoPerBlock
=
4
;
constexpr
unsigned
NPerThread
=
4
;
constexpr
index_t
NPerThread
=
4
;
constexpr
unsigned
KPerThread
=
16
;
constexpr
index_t
KPerThread
=
16
;
constexpr
unsigned
CPerThread
=
1
;
constexpr
index_t
CPerThread
=
1
;
constexpr
unsigned
HoPerThread
=
1
;
constexpr
index_t
HoPerThread
=
1
;
constexpr
unsigned
WoPerThread
=
1
;
constexpr
index_t
WoPerThread
=
1
;
constexpr
unsigned
WeiBlockCopyThreadPerDim0
=
4
;
constexpr
index_t
WeiBlockCopyThreadPerDim0
=
4
;
constexpr
unsigned
WeiBlockCopyThreadPerDim1
=
32
;
constexpr
index_t
WeiBlockCopyThreadPerDim1
=
32
;
constexpr
unsigned
InBlockCopyDataPerRead
=
4
;
// not used, yet
constexpr
index_t
InBlockCopyDataPerRead
=
4
;
// not used, yet
constexpr
unsigned
WeiBlockCopyDataPerRead
=
4
;
constexpr
index_t
WeiBlockCopyDataPerRead
=
4
;
constexpr
unsigned
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#elif 0
#elif 0
// for 3x3, 56x56
// for 3x3, 56x56
constexpr
unsigned
NPerBlock
=
32
;
constexpr
index_t
NPerBlock
=
32
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
64
;
constexpr
unsigned
CPerBlock
=
4
;
constexpr
index_t
CPerBlock
=
4
;
constexpr
unsigned
HoPerBlock
=
2
;
constexpr
index_t
HoPerBlock
=
2
;
constexpr
unsigned
WoPerBlock
=
2
;
constexpr
index_t
WoPerBlock
=
2
;
constexpr
unsigned
NPerThread
=
4
;
constexpr
index_t
NPerThread
=
4
;
constexpr
unsigned
KPerThread
=
16
;
constexpr
index_t
KPerThread
=
16
;
constexpr
unsigned
CPerThread
=
1
;
constexpr
index_t
CPerThread
=
1
;
constexpr
unsigned
HoPerThread
=
1
;
constexpr
index_t
HoPerThread
=
1
;
constexpr
unsigned
WoPerThread
=
1
;
constexpr
index_t
WoPerThread
=
1
;
constexpr
unsigned
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#elif 0
#elif 0
// for 1x1, 28x28
// for 1x1, 28x28
constexpr
unsigned
NPerBlock
=
16
;
constexpr
index_t
NPerBlock
=
16
;
constexpr
unsigned
KPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
128
;
constexpr
unsigned
CPerBlock
=
8
;
constexpr
index_t
CPerBlock
=
8
;
constexpr
unsigned
HoPerBlock
=
2
;
constexpr
index_t
HoPerBlock
=
2
;
constexpr
unsigned
WoPerBlock
=
2
;
constexpr
index_t
WoPerBlock
=
2
;
constexpr
unsigned
NPerThread
=
4
;
constexpr
index_t
NPerThread
=
4
;
constexpr
unsigned
KPerThread
=
16
;
constexpr
index_t
KPerThread
=
16
;
constexpr
unsigned
CPerThread
=
1
;
constexpr
index_t
CPerThread
=
1
;
constexpr
unsigned
HoPerThread
=
1
;
constexpr
index_t
HoPerThread
=
1
;
constexpr
unsigned
WoPerThread
=
1
;
constexpr
index_t
WoPerThread
=
1
;
constexpr
unsigned
InBlockCopy_ThreadPerDimC
=
8
;
constexpr
index_t
InBlockCopy_ThreadPerDimC
=
8
;
constexpr
unsigned
InBlockCopy_ThreadPerDimH
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimH
=
2
;
constexpr
unsigned
InBlockCopy_ThreadPerDimW
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimW
=
2
;
constexpr
unsigned
InBlockCopy_ThreadPerDimN
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimN
=
4
;
constexpr
unsigned
InBlockCopyDataPerRead
=
4
;
constexpr
index_t
InBlockCopyDataPerRead
=
4
;
constexpr
unsigned
WeiBlockCopyDataPerRead
=
4
;
constexpr
index_t
WeiBlockCopyDataPerRead
=
4
;
constexpr
unsigned
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
unsigned
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
unsigned
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
unsigned
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
unsigned
GemmMLevel1Cluster
=
2
;
constexpr
index_t
GemmMLevel1Cluster
=
2
;
constexpr
unsigned
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
unsigned
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
unsigned
OutThreadCopyDataPerWrite
=
2
;
constexpr
index_t
OutThreadCopyDataPerWrite
=
2
;
constexpr
unsigned
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#elif 1
#elif 1
// for 1x1, 14x14
// for 1x1, 14x14
constexpr
unsigned
NPerBlock
=
16
;
constexpr
index_t
NPerBlock
=
16
;
constexpr
unsigned
KPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
128
;
constexpr
unsigned
CPerBlock
=
8
;
constexpr
index_t
CPerBlock
=
8
;
constexpr
unsigned
HoPerBlock
=
2
;
constexpr
index_t
HoPerBlock
=
2
;
constexpr
unsigned
WoPerBlock
=
2
;
constexpr
index_t
WoPerBlock
=
2
;
constexpr
unsigned
NPerThread
=
4
;
constexpr
index_t
NPerThread
=
4
;
constexpr
unsigned
KPerThread
=
16
;
constexpr
index_t
KPerThread
=
16
;
constexpr
unsigned
CPerThread
=
1
;
constexpr
index_t
CPerThread
=
1
;
constexpr
unsigned
HoPerThread
=
1
;
constexpr
index_t
HoPerThread
=
1
;
constexpr
unsigned
WoPerThread
=
1
;
constexpr
index_t
WoPerThread
=
1
;
constexpr
unsigned
InBlockCopy_ThreadPerDimC
=
8
;
constexpr
index_t
InBlockCopy_ThreadPerDimC
=
8
;
constexpr
unsigned
InBlockCopy_ThreadPerDimH
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimH
=
2
;
constexpr
unsigned
InBlockCopy_ThreadPerDimW
=
2
;
constexpr
index_t
InBlockCopy_ThreadPerDimW
=
2
;
constexpr
unsigned
InBlockCopy_ThreadPerDimN
=
4
;
constexpr
index_t
InBlockCopy_ThreadPerDimN
=
4
;
constexpr
unsigned
InBlockCopyDataPerRead
=
4
;
constexpr
index_t
InBlockCopyDataPerRead
=
4
;
constexpr
unsigned
WeiBlockCopyDataPerRead
=
4
;
constexpr
index_t
WeiBlockCopyDataPerRead
=
4
;
constexpr
unsigned
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
unsigned
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
unsigned
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
unsigned
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
unsigned
GemmMLevel1Cluster
=
2
;
constexpr
index_t
GemmMLevel1Cluster
=
2
;
constexpr
unsigned
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
unsigned
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
unsigned
OutThreadCopyDataPerWrite
=
2
;
constexpr
index_t
OutThreadCopyDataPerWrite
=
2
;
constexpr
unsigned
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#endif
#endif
constexpr
unsigned
GridSize
=
constexpr
index_t
GridSize
=
((
N
+
NPerBlock
-
1
)
/
NPerBlock
)
*
((
K
+
KPerBlock
-
1
)
/
KPerBlock
)
*
((
N
+
NPerBlock
-
1
)
/
NPerBlock
)
*
((
K
+
KPerBlock
-
1
)
/
KPerBlock
)
*
((
Ho
+
HoPerBlock
-
1
)
/
HoPerBlock
)
*
((
Wo
+
WoPerBlock
-
1
)
/
WoPerBlock
);
((
Ho
+
HoPerBlock
-
1
)
/
HoPerBlock
)
*
((
Wo
+
WoPerBlock
-
1
)
/
WoPerBlock
);
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
for
(
unsigned
i
=
0
;
i
<
nrepeat
;
++
i
)
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
{
{
float
time
=
launch_kernel
(
float
time
=
launch_kernel
(
gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn
<
GridSize
,
gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn
<
GridSize
,
...
...
driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp
View file @
766b0a9e
...
@@ -12,7 +12,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
...
@@ -12,7 +12,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
Tensor
<
T
>&
out_nkhw
,
Tensor
<
T
>&
out_nkhw
,
LowerPads
,
LowerPads
,
UpperPads
,
UpperPads
,
unsigned
nrepeat
)
index_t
nrepeat
)
{
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
@@ -23,17 +23,17 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
...
@@ -23,17 +23,17 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
constexpr
auto
wei_kcyx_desc
=
WeiDesc
{};
constexpr
auto
wei_kcyx_desc
=
WeiDesc
{};
constexpr
auto
out_nkhw_desc
=
OutDesc
{};
constexpr
auto
out_nkhw_desc
=
OutDesc
{};
constexpr
unsigned
Hi
=
in_nchw_desc
.
GetLength
(
I2
);
constexpr
index_t
Hi
=
in_nchw_desc
.
GetLength
(
I2
);
constexpr
unsigned
Wi
=
in_nchw_desc
.
GetLength
(
I3
);
constexpr
index_t
Wi
=
in_nchw_desc
.
GetLength
(
I3
);
constexpr
unsigned
N
=
out_nkhw_desc
.
GetLength
(
I0
);
constexpr
index_t
N
=
out_nkhw_desc
.
GetLength
(
I0
);
constexpr
unsigned
Ho
=
out_nkhw_desc
.
GetLength
(
I2
);
constexpr
index_t
Ho
=
out_nkhw_desc
.
GetLength
(
I2
);
constexpr
unsigned
Wo
=
out_nkhw_desc
.
GetLength
(
I3
);
constexpr
index_t
Wo
=
out_nkhw_desc
.
GetLength
(
I3
);
constexpr
unsigned
K
=
wei_kcyx_desc
.
GetLength
(
I0
);
constexpr
index_t
K
=
wei_kcyx_desc
.
GetLength
(
I0
);
constexpr
unsigned
C
=
wei_kcyx_desc
.
GetLength
(
I1
);
constexpr
index_t
C
=
wei_kcyx_desc
.
GetLength
(
I1
);
constexpr
unsigned
Y
=
wei_kcyx_desc
.
GetLength
(
I2
);
constexpr
index_t
Y
=
wei_kcyx_desc
.
GetLength
(
I2
);
constexpr
unsigned
X
=
wei_kcyx_desc
.
GetLength
(
I3
);
constexpr
index_t
X
=
wei_kcyx_desc
.
GetLength
(
I3
);
// reorder weight
// reorder weight
auto
wei_cyxk_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
C
,
Y
,
X
,
K
>
{});
auto
wei_cyxk_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
C
,
Y
,
X
,
K
>
{});
...
@@ -77,177 +77,177 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
...
@@ -77,177 +77,177 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
out_khwn_device_buf
.
ToDevice
(
out_khwn
.
mData
.
data
());
out_khwn_device_buf
.
ToDevice
(
out_khwn
.
mData
.
data
());
#if 0
#if 0
constexpr
unsigned
NPerBlock = 1;
constexpr
index_t
NPerBlock = 1;
constexpr
unsigned
KPerBlock = 1;
constexpr
index_t
KPerBlock = 1;
constexpr
unsigned
CPerBlock = 1;
constexpr
index_t
CPerBlock = 1;
constexpr
unsigned
HoPerBlock = 2;
constexpr
index_t
HoPerBlock = 2;
constexpr
unsigned
WoPerBlock = 4;
constexpr
index_t
WoPerBlock = 4;
constexpr
unsigned
NPerThread = 1;
constexpr
index_t
NPerThread = 1;
constexpr
unsigned
KPerThread = 1;
constexpr
index_t
KPerThread = 1;
constexpr
unsigned
CPerThread = 1;
constexpr
index_t
CPerThread = 1;
constexpr
unsigned
HoPerThread = 1;
constexpr
index_t
HoPerThread = 1;
constexpr
unsigned
WoPerThread = 1;
constexpr
index_t
WoPerThread = 1;
constexpr
unsigned
WeiBlockCopyThreadPerDim0 = 1;
constexpr
index_t
WeiBlockCopyThreadPerDim0 = 1;
constexpr
unsigned
WeiBlockCopyThreadPerDim1 = 1;
constexpr
index_t
WeiBlockCopyThreadPerDim1 = 1;
constexpr
unsigned
BlockSize = 8;
constexpr
index_t
BlockSize = 8;
#elif
1
#elif
1
// for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256
// for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256
constexpr
unsigned
NPerBlock
=
16
;
constexpr
index_t
NPerBlock
=
16
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
64
;
constexpr
unsigned
CPerBlock
=
4
;
constexpr
index_t
CPerBlock
=
4
;
constexpr
unsigned
HoPerBlock
=
2
;
constexpr
index_t
HoPerBlock
=
2
;
constexpr
unsigned
WoPerBlock
=
4
;
constexpr
index_t
WoPerBlock
=
4
;
constexpr
unsigned
NPerThread
=
4
;
constexpr
index_t
NPerThread
=
4
;
constexpr
unsigned
KPerThread
=
16
;
constexpr
index_t
KPerThread
=
16
;
constexpr
unsigned
CPerThread
=
1
;
constexpr
index_t
CPerThread
=
1
;
constexpr
unsigned
HoPerThread
=
1
;
constexpr
index_t
HoPerThread
=
1
;
constexpr
unsigned
WoPerThread
=
1
;
constexpr
index_t
WoPerThread
=
1
;
constexpr
unsigned
WeiBlockCopyThreadPerDim0
=
4
;
constexpr
index_t
WeiBlockCopyThreadPerDim0
=
4
;
constexpr
unsigned
WeiBlockCopyThreadPerDim1
=
32
;
constexpr
index_t
WeiBlockCopyThreadPerDim1
=
32
;
constexpr
unsigned
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#elif 0
#elif 0
// 3x3 58x58, NKC = 16,256,128
// 3x3 58x58, NKC = 16,256,128
constexpr
unsigned
NPerBlock
=
8
;
constexpr
index_t
NPerBlock
=
8
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
64
;
constexpr
unsigned
CPerBlock
=
2
;
constexpr
index_t
CPerBlock
=
2
;
constexpr
unsigned
HoPerBlock
=
4
;
constexpr
index_t
HoPerBlock
=
4
;
constexpr
unsigned
WoPerBlock
=
4
;
constexpr
index_t
WoPerBlock
=
4
;
constexpr
unsigned
NPerThread
=
4
;
constexpr
index_t
NPerThread
=
4
;
constexpr
unsigned
KPerThread
=
16
;
constexpr
index_t
KPerThread
=
16
;
constexpr
unsigned
CPerThread
=
1
;
constexpr
index_t
CPerThread
=
1
;
constexpr
unsigned
HoPerThread
=
1
;
constexpr
index_t
HoPerThread
=
1
;
constexpr
unsigned
WoPerThread
=
1
;
constexpr
index_t
WoPerThread
=
1
;
constexpr
unsigned
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#elif 0
#elif 0
// for 5x5, 36x36
// for 5x5, 36x36
constexpr
unsigned
NPerBlock
=
16
;
constexpr
index_t
NPerBlock
=
16
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
64
;
constexpr
unsigned
CPerBlock
=
2
;
constexpr
index_t
CPerBlock
=
2
;
constexpr
unsigned
HoPerBlock
=
2
;
constexpr
index_t
HoPerBlock
=
2
;
constexpr
unsigned
WoPerBlock
=
4
;
constexpr
index_t
WoPerBlock
=
4
;
constexpr
unsigned
NPerThread
=
4
;
constexpr
index_t
NPerThread
=
4
;
constexpr
unsigned
KPerThread
=
16
;
constexpr
index_t
KPerThread
=
16
;
constexpr
unsigned
CPerThread
=
1
;
constexpr
index_t
CPerThread
=
1
;
constexpr
unsigned
HoPerThread
=
1
;
constexpr
index_t
HoPerThread
=
1
;
constexpr
unsigned
WoPerThread
=
1
;
constexpr
index_t
WoPerThread
=
1
;
constexpr
unsigned
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#elif 0
#elif 0
// for 7x7, 38x38
// for 7x7, 38x38
constexpr
unsigned
NPerBlock
=
8
;
constexpr
index_t
NPerBlock
=
8
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
64
;
constexpr
unsigned
CPerBlock
=
2
;
constexpr
index_t
CPerBlock
=
2
;
constexpr
unsigned
HoPerBlock
=
4
;
constexpr
index_t
HoPerBlock
=
4
;
constexpr
unsigned
WoPerBlock
=
4
;
constexpr
index_t
WoPerBlock
=
4
;
constexpr
unsigned
NPerThread
=
4
;
constexpr
index_t
NPerThread
=
4
;
constexpr
unsigned
KPerThread
=
16
;
constexpr
index_t
KPerThread
=
16
;
constexpr
unsigned
CPerThread
=
1
;
constexpr
index_t
CPerThread
=
1
;
constexpr
unsigned
HoPerThread
=
1
;
constexpr
index_t
HoPerThread
=
1
;
constexpr
unsigned
WoPerThread
=
1
;
constexpr
index_t
WoPerThread
=
1
;
constexpr
unsigned
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#elif 0
#elif 0
// for 3x3, 56x56
// for 3x3, 56x56
constexpr
unsigned
NPerBlock
=
32
;
constexpr
index_t
NPerBlock
=
32
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
64
;
constexpr
unsigned
CPerBlock
=
4
;
constexpr
index_t
CPerBlock
=
4
;
constexpr
unsigned
HoPerBlock
=
2
;
constexpr
index_t
HoPerBlock
=
2
;
constexpr
unsigned
WoPerBlock
=
2
;
constexpr
index_t
WoPerBlock
=
2
;
constexpr
unsigned
NPerThread
=
4
;
constexpr
index_t
NPerThread
=
4
;
constexpr
unsigned
KPerThread
=
16
;
constexpr
index_t
KPerThread
=
16
;
constexpr
unsigned
CPerThread
=
1
;
constexpr
index_t
CPerThread
=
1
;
constexpr
unsigned
HoPerThread
=
1
;
constexpr
index_t
HoPerThread
=
1
;
constexpr
unsigned
WoPerThread
=
1
;
constexpr
index_t
WoPerThread
=
1
;
constexpr
unsigned
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#elif 1
#elif 1
// 3x3 56x56, NKC = 16,256,128, with padding
// 3x3 56x56, NKC = 16,256,128, with padding
// 3x3 28x28, NKC = 16,512,256, with padding
// 3x3 28x28, NKC = 16,512,256, with padding
// 3x3 20x84, NKC = 16,256,256, with padding
// 3x3 20x84, NKC = 16,256,256, with padding
constexpr
unsigned
NPerBlock
=
16
;
constexpr
index_t
NPerBlock
=
16
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
64
;
constexpr
unsigned
CPerBlock
=
2
;
constexpr
index_t
CPerBlock
=
2
;
constexpr
unsigned
HoPerBlock
=
2
;
constexpr
index_t
HoPerBlock
=
2
;
constexpr
unsigned
WoPerBlock
=
4
;
constexpr
index_t
WoPerBlock
=
4
;
constexpr
unsigned
NPerThread
=
4
;
constexpr
index_t
NPerThread
=
4
;
constexpr
unsigned
KPerThread
=
16
;
constexpr
index_t
KPerThread
=
16
;
constexpr
unsigned
CPerThread
=
1
;
constexpr
index_t
CPerThread
=
1
;
constexpr
unsigned
HoPerThread
=
1
;
constexpr
index_t
HoPerThread
=
1
;
constexpr
unsigned
WoPerThread
=
1
;
constexpr
index_t
WoPerThread
=
1
;
constexpr
unsigned
WeiBlockCopyThreadPerDim0
=
2
;
constexpr
index_t
WeiBlockCopyThreadPerDim0
=
2
;
constexpr
unsigned
WeiBlockCopyThreadPerDim1
=
64
;
constexpr
index_t
WeiBlockCopyThreadPerDim1
=
64
;
constexpr
unsigned
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#elif 0
#elif 0
// for 5x5 filter, 20x84 image, 1x1 padding
// for 5x5 filter, 20x84 image, 1x1 padding
constexpr
unsigned
NPerBlock
=
16
;
constexpr
index_t
NPerBlock
=
16
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
64
;
constexpr
unsigned
CPerBlock
=
1
;
constexpr
index_t
CPerBlock
=
1
;
constexpr
unsigned
HoPerBlock
=
2
;
constexpr
index_t
HoPerBlock
=
2
;
constexpr
unsigned
WoPerBlock
=
4
;
constexpr
index_t
WoPerBlock
=
4
;
constexpr
unsigned
NPerThread
=
4
;
constexpr
index_t
NPerThread
=
4
;
constexpr
unsigned
KPerThread
=
16
;
constexpr
index_t
KPerThread
=
16
;
constexpr
unsigned
CPerThread
=
1
;
constexpr
index_t
CPerThread
=
1
;
constexpr
unsigned
HoPerThread
=
1
;
constexpr
index_t
HoPerThread
=
1
;
constexpr
unsigned
WoPerThread
=
1
;
constexpr
index_t
WoPerThread
=
1
;
constexpr
unsigned
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#elif 0
#elif 0
// 5x5 filter, 28x28 image, 2x2 padding
// 5x5 filter, 28x28 image, 2x2 padding
constexpr
unsigned
NPerBlock
=
16
;
constexpr
index_t
NPerBlock
=
16
;
constexpr
unsigned
KPerBlock
=
32
;
constexpr
index_t
KPerBlock
=
32
;
constexpr
unsigned
CPerBlock
=
2
;
constexpr
index_t
CPerBlock
=
2
;
constexpr
unsigned
HoPerBlock
=
4
;
constexpr
index_t
HoPerBlock
=
4
;
constexpr
unsigned
WoPerBlock
=
4
;
constexpr
index_t
WoPerBlock
=
4
;
constexpr
unsigned
NPerThread
=
4
;
constexpr
index_t
NPerThread
=
4
;
constexpr
unsigned
KPerThread
=
16
;
constexpr
index_t
KPerThread
=
16
;
constexpr
unsigned
CPerThread
=
1
;
constexpr
index_t
CPerThread
=
1
;
constexpr
unsigned
HoPerThread
=
1
;
constexpr
index_t
HoPerThread
=
1
;
constexpr
unsigned
WoPerThread
=
1
;
constexpr
index_t
WoPerThread
=
1
;
constexpr
unsigned
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#elif 0
#elif 0
// for 1x1, 28x28
// for 1x1, 28x28
constexpr
unsigned
NPerBlock
=
16
;
constexpr
index_t
NPerBlock
=
16
;
constexpr
unsigned
KPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
128
;
constexpr
unsigned
CPerBlock
=
8
;
constexpr
index_t
CPerBlock
=
8
;
constexpr
unsigned
HoPerBlock
=
2
;
constexpr
index_t
HoPerBlock
=
2
;
constexpr
unsigned
WoPerBlock
=
2
;
constexpr
index_t
WoPerBlock
=
2
;
constexpr
unsigned
NPerThread
=
4
;
constexpr
index_t
NPerThread
=
4
;
constexpr
unsigned
KPerThread
=
16
;
constexpr
index_t
KPerThread
=
16
;
constexpr
unsigned
CPerThread
=
2
;
constexpr
index_t
CPerThread
=
2
;
constexpr
unsigned
HoPerThread
=
1
;
constexpr
index_t
HoPerThread
=
1
;
constexpr
unsigned
WoPerThread
=
1
;
constexpr
index_t
WoPerThread
=
1
;
constexpr
unsigned
WeiBlockCopyThreadPerDim0
=
4
;
constexpr
index_t
WeiBlockCopyThreadPerDim0
=
4
;
constexpr
unsigned
WeiBlockCopyThreadPerDim1
=
32
;
constexpr
index_t
WeiBlockCopyThreadPerDim1
=
32
;
constexpr
unsigned
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#endif
#endif
constexpr
unsigned
GridSize
=
constexpr
index_t
GridSize
=
((
N
+
NPerBlock
-
1
)
/
NPerBlock
)
*
((
K
+
KPerBlock
-
1
)
/
KPerBlock
)
*
((
N
+
NPerBlock
-
1
)
/
NPerBlock
)
*
((
K
+
KPerBlock
-
1
)
/
KPerBlock
)
*
((
Ho
+
HoPerBlock
-
1
)
/
HoPerBlock
)
*
((
Wo
+
WoPerBlock
-
1
)
/
WoPerBlock
);
((
Ho
+
HoPerBlock
-
1
)
/
HoPerBlock
)
*
((
Wo
+
WoPerBlock
-
1
)
/
WoPerBlock
);
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
for
(
unsigned
i
=
0
;
i
<
nrepeat
;
++
i
)
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
{
{
float
time
=
launch_kernel
(
float
time
=
launch_kernel
(
gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded
<
GridSize
,
gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded
<
GridSize
,
...
...
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp
View file @
766b0a9e
...
@@ -11,7 +11,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
...
@@ -11,7 +11,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
const
Tensor
<
T
>&
wei_kcyx
,
const
Tensor
<
T
>&
wei_kcyx
,
OutDesc
,
OutDesc
,
Tensor
<
T
>&
out_nkhw
,
Tensor
<
T
>&
out_nkhw
,
unsigned
nrepeat
)
index_t
nrepeat
)
{
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
@@ -22,19 +22,19 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
...
@@ -22,19 +22,19 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
constexpr
auto
wei_kcyx_desc
=
WeiDesc
{};
constexpr
auto
wei_kcyx_desc
=
WeiDesc
{};
constexpr
auto
out_nkhw_desc
=
OutDesc
{};
constexpr
auto
out_nkhw_desc
=
OutDesc
{};
constexpr
unsigned
N
=
in_nchw_desc
.
GetLength
(
I0
);
constexpr
index_t
N
=
in_nchw_desc
.
GetLength
(
I0
);
constexpr
unsigned
Hi
=
in_nchw_desc
.
GetLength
(
I2
);
constexpr
index_t
Hi
=
in_nchw_desc
.
GetLength
(
I2
);
constexpr
unsigned
Wi
=
in_nchw_desc
.
GetLength
(
I3
);
constexpr
index_t
Wi
=
in_nchw_desc
.
GetLength
(
I3
);
constexpr
unsigned
Ho
=
out_nkhw_desc
.
GetLength
(
I2
);
constexpr
index_t
Ho
=
out_nkhw_desc
.
GetLength
(
I2
);
constexpr
unsigned
Wo
=
out_nkhw_desc
.
GetLength
(
I3
);
constexpr
index_t
Wo
=
out_nkhw_desc
.
GetLength
(
I3
);
constexpr
unsigned
K
=
wei_kcyx_desc
.
GetLength
(
I0
);
constexpr
index_t
K
=
wei_kcyx_desc
.
GetLength
(
I0
);
constexpr
unsigned
C
=
wei_kcyx_desc
.
GetLength
(
I1
);
constexpr
index_t
C
=
wei_kcyx_desc
.
GetLength
(
I1
);
constexpr
unsigned
Y
=
wei_kcyx_desc
.
GetLength
(
I2
);
constexpr
index_t
Y
=
wei_kcyx_desc
.
GetLength
(
I2
);
constexpr
unsigned
X
=
wei_kcyx_desc
.
GetLength
(
I3
);
constexpr
index_t
X
=
wei_kcyx_desc
.
GetLength
(
I3
);
constexpr
unsigned
BGhostRead
=
(
Y
-
1
)
*
Wi
+
(
X
-
1
);
constexpr
index_t
BGhostRead
=
(
Y
-
1
)
*
Wi
+
(
X
-
1
);
// convert in_nchw to in_cnhw
// convert in_nchw to in_cnhw
auto
in_chwn_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
C
,
Hi
,
Wi
,
N
>
{});
auto
in_chwn_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
C
,
Hi
,
Wi
,
N
>
{});
...
@@ -71,128 +71,158 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
...
@@ -71,128 +71,158 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
#if 0
#if 0
// 3x3, 34x34
// 3x3, 34x34
// need to use register double buffer for GEMM
// need to use register double buffer for GEMM
constexpr
unsigned
BPerBlock = 128;
constexpr
index_t
BPerBlock = 128;
constexpr
unsigned
KPerBlock = 64;
constexpr
index_t
KPerBlock = 64;
constexpr
unsigned
CPerBlock = 4;
constexpr
index_t
CPerBlock = 4;
constexpr
unsigned
BPerThread = 8;
constexpr
index_t
BPerThread = 8;
constexpr
unsigned
KPerThread = 8;
constexpr
index_t
KPerThread = 8;
constexpr
unsigned
GemmMPerThreadSubC = 4;
constexpr
index_t
GemmMPerThreadSubC = 4;
constexpr
unsigned
GemmNPerThreadSubC = 4;
constexpr
index_t
GemmNPerThreadSubC = 4;
constexpr
unsigned
GemmMLevel0Cluster = 4;
constexpr
index_t
GemmMLevel0Cluster = 4;
constexpr
unsigned
GemmNLevel0Cluster = 2;
constexpr
index_t
GemmNLevel0Cluster = 2;
constexpr
unsigned
GemmMLevel1Cluster = 2;
constexpr
index_t
GemmMLevel1Cluster = 2;
constexpr
unsigned
GemmNLevel1Cluster = 8;
constexpr
index_t
GemmNLevel1Cluster = 8;
constexpr
unsigned
GemmKPerThreadLoop = 1;
constexpr
index_t
GemmKPerThreadLoop = 1;
constexpr
unsigned
GemmThreadPerColumnPerCluster = 8;
constexpr
index_t
GemmThreadPerColumnPerCluster = 8;
constexpr
unsigned
GemmThreadPerRowPerCluster = 8;
constexpr
index_t
GemmThreadPerRowPerCluster = 8;
constexpr
unsigned
InBlockCopyThreadPerDim0 = 4;
constexpr
index_t
InBlockCopyThreadPerDim0 = 4;
constexpr
unsigned
InBlockCopyThreadPerDim1 = 16;
constexpr
index_t
InBlockCopyThreadPerDim1 = 16;
constexpr
unsigned
WeiBlockCopyThreadPerDim0 = 4;
constexpr
index_t
WeiBlockCopyThreadPerDim0 = 4;
constexpr
unsigned
WeiBlockCopyThreadPerDim1 = 16;
constexpr
index_t
WeiBlockCopyThreadPerDim1 = 16;
constexpr
unsigned
InBlockCopyDataPerRead = 4;
constexpr
index_t
InBlockCopyDataPerRead = 4;
constexpr
unsigned
WeiBlockCopyDataPerRead = 4;
constexpr
index_t
WeiBlockCopyDataPerRead = 4;
constexpr
unsigned
BlockSize = 128;
constexpr
index_t
BlockSize = 128;
#elif
0
#elif
0
// 1x1, 28x28, 64 threads
// 1x1, 28x28, 64 threads
constexpr
unsigned
BPerBlock
=
64
;
constexpr
index_t
BPerBlock
=
64
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
64
;
constexpr
unsigned
CPerBlock
=
8
;
constexpr
index_t
CPerBlock
=
8
;
constexpr
unsigned
BPerThread
=
8
;
constexpr
index_t
BPerThread
=
8
;
constexpr
unsigned
KPerThread
=
8
;
constexpr
index_t
KPerThread
=
8
;
constexpr
unsigned
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
unsigned
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
unsigned
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
unsigned
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
unsigned
GemmMLevel1Cluster
=
2
;
constexpr
index_t
GemmMLevel1Cluster
=
2
;
constexpr
unsigned
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
unsigned
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
unsigned
GemmThreadPerColumnPerCluster
=
8
;
constexpr
index_t
GemmThreadPerColumnPerCluster
=
8
;
constexpr
unsigned
GemmThreadPerRowPerCluster
=
8
;
constexpr
index_t
GemmThreadPerRowPerCluster
=
8
;
constexpr
unsigned
InBlockCopyThreadPerDim0
=
4
;
constexpr
index_t
InBlockCopyThreadPerDim0
=
4
;
constexpr
unsigned
InBlockCopyThreadPerDim1
=
16
;
constexpr
index_t
InBlockCopyThreadPerDim1
=
16
;
constexpr
unsigned
WeiBlockCopyThreadPerDim0
=
4
;
constexpr
index_t
WeiBlockCopyThreadPerDim0
=
4
;
constexpr
unsigned
WeiBlockCopyThreadPerDim1
=
16
;
constexpr
index_t
WeiBlockCopyThreadPerDim1
=
16
;
constexpr
unsigned
InBlockCopyDataPerRead
=
4
;
constexpr
index_t
InBlockCopyDataPerRead
=
4
;
constexpr
unsigned
WeiBlockCopyDataPerRead
=
4
;
constexpr
index_t
WeiBlockCopyDataPerRead
=
4
;
constexpr
unsigned
BlockSize
=
64
;
constexpr
index_t
BlockSize
=
64
;
#elif
1
#elif
0
// 1x1, 28x28, 128 threads, no lds-double-buffer
// 1x1, 28x28, 128 threads, no lds-double-buffer
// 1x1, 28x28, 128 threads, with lds-double-buffer, max_register = 128
// 1x1, 28x28, 128 threads, with lds-double-buffer, max_register = 128
constexpr
unsigned
BPerBlock
=
64
;
constexpr
index_t
BPerBlock
=
64
;
constexpr
unsigned
KPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
128
;
constexpr
unsigned
CPerBlock
=
8
;
constexpr
index_t
CPerBlock
=
8
;
constexpr
unsigned
BPerThread
=
8
;
constexpr
index_t
BPerThread
=
8
;
constexpr
unsigned
KPerThread
=
8
;
constexpr
index_t
KPerThread
=
8
;
constexpr
unsigned
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
unsigned
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
unsigned
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
unsigned
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
unsigned
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
unsigned
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
unsigned
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
unsigned
GemmThreadPerColumnPerCluster
=
8
;
constexpr
index_t
GemmThreadPerColumnPerCluster
=
8
;
constexpr
unsigned
GemmThreadPerRowPerCluster
=
8
;
constexpr
index_t
GemmThreadPerRowPerCluster
=
8
;
constexpr
unsigned
InBlockCopyThreadPerDim0
=
4
;
constexpr
index_t
InBlockCopyThreadPerDim0
=
4
;
constexpr
unsigned
InBlockCopyThreadPerDim1
=
16
;
constexpr
index_t
InBlockCopyThreadPerDim1
=
16
;
constexpr
unsigned
WeiBlockCopyThreadPerDim0
=
4
;
constexpr
index_t
WeiBlockCopyThreadPerDim0
=
4
;
constexpr
unsigned
WeiBlockCopyThreadPerDim1
=
16
;
constexpr
index_t
WeiBlockCopyThreadPerDim1
=
16
;
constexpr
unsigned
InBlockCopyDataPerRead
=
4
;
constexpr
index_t
InBlockCopyDataPerRead
=
4
;
constexpr
unsigned
WeiBlockCopyDataPerRead
=
4
;
constexpr
index_t
WeiBlockCopyDataPerRead
=
4
;
constexpr
unsigned
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#elif 0
#elif 0
// 1x1, 28x28, 256 thread
// 1x1, 28x28, 256 thread
constexpr
unsigned
BPerBlock
=
128
;
constexpr
index_t
BPerBlock
=
128
;
constexpr
unsigned
KPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
128
;
constexpr
unsigned
CPerBlock
=
8
;
constexpr
index_t
CPerBlock
=
8
;
constexpr
index_t
BPerThread
=
8
;
constexpr
index_t
KPerThread
=
8
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmThreadPerColumnPerCluster
=
8
;
constexpr
index_t
GemmThreadPerRowPerCluster
=
8
;
constexpr
index_t
InBlockCopyThreadPerDim0
=
4
;
constexpr
index_t
InBlockCopyThreadPerDim1
=
16
;
constexpr
index_t
WeiBlockCopyThreadPerDim0
=
4
;
constexpr
index_t
WeiBlockCopyThreadPerDim1
=
16
;
constexpr
index_t
InBlockCopyDataPerRead
=
4
;
constexpr
index_t
WeiBlockCopyDataPerRead
=
4
;
constexpr
index_t
BlockSize
=
256
;
#elif 1
// 1x1, 14x14, Vega 10
constexpr
index_t
BPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
128
;
constexpr
index_t
CPerBlock
=
8
;
constexpr
unsigned
BPerThread
=
8
;
constexpr
index_t
BPerThread
=
8
;
constexpr
unsigned
KPerThread
=
8
;
constexpr
index_t
KPerThread
=
8
;
constexpr
unsigned
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
unsigned
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
unsigned
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
unsigned
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
unsigned
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
unsigned
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
unsigned
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
unsigned
GemmThreadPerColumnPerCluster
=
8
;
constexpr
index_t
GemmThreadPerColumnPerCluster
=
8
;
constexpr
unsigned
GemmThreadPerRowPerCluster
=
8
;
constexpr
index_t
GemmThreadPerRowPerCluster
=
8
;
constexpr
unsigned
InBlockCopyThreadPerDim0
=
4
;
constexpr
index_t
InBlockCopyThreadPerDim0
=
4
;
constexpr
unsigned
InBlockCopyThreadPerDim1
=
16
;
constexpr
index_t
InBlockCopyThreadPerDim1
=
16
;
constexpr
unsigned
WeiBlockCopyThreadPerDim0
=
4
;
constexpr
index_t
WeiBlockCopyThreadPerDim0
=
4
;
constexpr
unsigned
WeiBlockCopyThreadPerDim1
=
16
;
constexpr
index_t
WeiBlockCopyThreadPerDim1
=
16
;
constexpr
unsigned
InBlockCopyDataPerRead
=
4
;
constexpr
index_t
InBlockCopyDataPerRead
=
4
;
constexpr
unsigned
WeiBlockCopyDataPerRead
=
4
;
constexpr
index_t
WeiBlockCopyDataPerRead
=
4
;
constexpr
unsigned
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
128
;
#endif
#endif
constexpr
unsigned
GridSize
=
constexpr
index_t
GridSize
=
((
N
*
Hi
*
Wi
+
BPerBlock
-
1
)
/
BPerBlock
)
*
((
K
+
KPerBlock
-
1
)
/
KPerBlock
);
((
N
*
Hi
*
Wi
+
BPerBlock
-
1
)
/
BPerBlock
)
*
((
K
+
KPerBlock
-
1
)
/
KPerBlock
);
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
...
@@ -208,7 +238,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
...
@@ -208,7 +238,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
wei_cyxk_device_buf
.
ToDevice
(
wei_cyxk
.
mData
.
data
());
wei_cyxk_device_buf
.
ToDevice
(
wei_cyxk
.
mData
.
data
());
out_khwn_device_buf
.
ToDevice
(
out_khwn
.
mData
.
data
());
out_khwn_device_buf
.
ToDevice
(
out_khwn
.
mData
.
data
());
for
(
unsigned
i
=
0
;
i
<
nrepeat
;
++
i
)
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
{
{
float
time
=
launch_kernel
(
float
time
=
launch_kernel
(
#if 1
#if 1
...
...
driver/driver.hip.cpp
View file @
766b0a9e
...
@@ -40,11 +40,11 @@ struct GeneratorTensor_Checkboard
...
@@ -40,11 +40,11 @@ struct GeneratorTensor_Checkboard
template
<
class
...
Ts
>
template
<
class
...
Ts
>
double
operator
()(
Ts
...
Xs
)
const
double
operator
()(
Ts
...
Xs
)
const
{
{
std
::
array
<
unsigned
long
,
sizeof
...(
Ts
)
>
dims
=
{{
Xs
...}};
std
::
array
<
index_t
,
sizeof
...(
Ts
)
>
dims
=
{{
Xs
...}};
return
std
::
accumulate
(
dims
.
begin
(),
return
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
dims
.
end
(),
true
,
true
,
[](
bool
init
,
unsigned
long
x
)
->
int
{
return
init
!=
(
x
%
2
);
})
[](
bool
init
,
index_t
x
)
->
int
{
return
init
!=
(
x
%
2
);
})
?
1
?
1
:
-
1
;
:
-
1
;
}
}
...
@@ -80,9 +80,9 @@ auto make_TensorDescriptor(TConstTensorDesc)
...
@@ -80,9 +80,9 @@ auto make_TensorDescriptor(TConstTensorDesc)
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
desc
=
TConstTensorDesc
{};
constexpr
auto
desc
=
TConstTensorDesc
{};
std
::
initializer_list
<
unsigned
>
lengths
=
{
std
::
initializer_list
<
index_t
>
lengths
=
{
desc
.
GetLength
(
I0
),
desc
.
GetLength
(
I1
),
desc
.
GetLength
(
I2
),
desc
.
GetLength
(
I3
)};
desc
.
GetLength
(
I0
),
desc
.
GetLength
(
I1
),
desc
.
GetLength
(
I2
),
desc
.
GetLength
(
I3
)};
std
::
initializer_list
<
unsigned
>
strides
=
{
std
::
initializer_list
<
index_t
>
strides
=
{
desc
.
GetStride
(
I0
),
desc
.
GetStride
(
I1
),
desc
.
GetStride
(
I2
),
desc
.
GetStride
(
I3
)};
desc
.
GetStride
(
I0
),
desc
.
GetStride
(
I1
),
desc
.
GetStride
(
I2
),
desc
.
GetStride
(
I3
)};
return
TensorDescriptor
(
lengths
,
strides
);
return
TensorDescriptor
(
lengths
,
strides
);
...
@@ -95,11 +95,11 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
...
@@ -95,11 +95,11 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
LowerPads
,
LowerPads
,
UpperPads
)
UpperPads
)
{
{
unsigned
h_pad_low
=
LowerPads
{}.
Get
(
Number
<
0
>
{});
index_t
h_pad_low
=
LowerPads
{}.
Get
(
Number
<
0
>
{});
unsigned
w_pad_low
=
LowerPads
{}.
Get
(
Number
<
1
>
{});
index_t
w_pad_low
=
LowerPads
{}.
Get
(
Number
<
1
>
{});
unsigned
h_pad_up
=
UpperPads
{}.
Get
(
Number
<
0
>
{});
index_t
h_pad_up
=
UpperPads
{}.
Get
(
Number
<
0
>
{});
unsigned
w_pad_up
=
UpperPads
{}.
Get
(
Number
<
1
>
{});
index_t
w_pad_up
=
UpperPads
{}.
Get
(
Number
<
1
>
{});
auto
f
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
auto
f
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
double
v
=
0
;
double
v
=
0
;
...
@@ -153,11 +153,11 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
...
@@ -153,11 +153,11 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
std
::
size_t
HO
=
out_nkhw
.
mDesc
.
GetLengths
()[
2
];
std
::
size_t
HO
=
out_nkhw
.
mDesc
.
GetLengths
()[
2
];
std
::
size_t
WO
=
out_nkhw
.
mDesc
.
GetLengths
()[
3
];
std
::
size_t
WO
=
out_nkhw
.
mDesc
.
GetLengths
()[
3
];
unsigned
h_pad_low
=
LowerPads
{}.
Get
(
Number
<
0
>
{});
index_t
h_pad_low
=
LowerPads
{}.
Get
(
Number
<
0
>
{});
unsigned
w_pad_low
=
LowerPads
{}.
Get
(
Number
<
1
>
{});
index_t
w_pad_low
=
LowerPads
{}.
Get
(
Number
<
1
>
{});
unsigned
h_pad_up
=
UpperPads
{}.
Get
(
Number
<
0
>
{});
index_t
h_pad_up
=
UpperPads
{}.
Get
(
Number
<
0
>
{});
unsigned
w_pad_up
=
UpperPads
{}.
Get
(
Number
<
1
>
{});
index_t
w_pad_up
=
UpperPads
{}.
Get
(
Number
<
1
>
{});
std
::
size_t
HiPerTile
=
HoPerTile
+
Y
-
1
;
std
::
size_t
HiPerTile
=
HoPerTile
+
Y
-
1
;
std
::
size_t
WiPerTile
=
WoPerTile
+
X
-
1
;
std
::
size_t
WiPerTile
=
WoPerTile
+
X
-
1
;
...
@@ -399,211 +399,211 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
...
@@ -399,211 +399,211 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
#if 0
#if 0
constexpr
unsigned
N = 1;
constexpr
index_t
N = 1;
constexpr
unsigned
C = 1;
constexpr
index_t
C = 1;
constexpr
unsigned
HI = 28;
constexpr
index_t
HI = 28;
constexpr
unsigned
WI = 28;
constexpr
index_t
WI = 28;
constexpr
unsigned
K = 1;
constexpr
index_t
K = 1;
constexpr
unsigned
Y = 3;
constexpr
index_t
Y = 3;
constexpr
unsigned
X = 3;
constexpr
index_t
X = 3;
constexpr
unsigned
HPad = 0;
constexpr
index_t
HPad = 0;
constexpr
unsigned
WPad = 0;
constexpr
index_t
WPad = 0;
#elif
0
#elif
0
// 3x3, 34x34
// 3x3, 34x34
constexpr
unsigned
N
=
64
;
constexpr
index_t
N
=
64
;
constexpr
unsigned
C
=
256
;
constexpr
index_t
C
=
256
;
constexpr
unsigned
HI
=
34
;
constexpr
index_t
HI
=
34
;
constexpr
unsigned
WI
=
34
;
constexpr
index_t
WI
=
34
;
constexpr
unsigned
K
=
64
;
constexpr
index_t
K
=
64
;
constexpr
unsigned
Y
=
3
;
constexpr
index_t
Y
=
3
;
constexpr
unsigned
X
=
3
;
constexpr
index_t
X
=
3
;
constexpr
unsigned
HPad
=
0
;
constexpr
index_t
HPad
=
0
;
constexpr
unsigned
WPad
=
0
;
constexpr
index_t
WPad
=
0
;
#elif 0
#elif 0
// 3x3, 56x56
// 3x3, 56x56
constexpr
unsigned
N
=
64
;
constexpr
index_t
N
=
64
;
constexpr
unsigned
C
=
64
;
constexpr
index_t
C
=
64
;
constexpr
unsigned
HI
=
56
;
constexpr
index_t
HI
=
56
;
constexpr
unsigned
WI
=
56
;
constexpr
index_t
WI
=
56
;
constexpr
unsigned
K
=
64
;
constexpr
index_t
K
=
64
;
constexpr
unsigned
Y
=
3
;
constexpr
index_t
Y
=
3
;
constexpr
unsigned
X
=
3
;
constexpr
index_t
X
=
3
;
#elif 0
#elif 0
// 3x3, 58x58
// 3x3, 58x58
constexpr
unsigned
N
=
64
;
constexpr
index_t
N
=
64
;
constexpr
unsigned
C
=
64
;
constexpr
index_t
C
=
64
;
constexpr
unsigned
HI
=
58
;
constexpr
index_t
HI
=
58
;
constexpr
unsigned
WI
=
58
;
constexpr
index_t
WI
=
58
;
constexpr
unsigned
K
=
64
;
constexpr
index_t
K
=
64
;
constexpr
unsigned
Y
=
3
;
constexpr
index_t
Y
=
3
;
constexpr
unsigned
X
=
3
;
constexpr
index_t
X
=
3
;
#elif 0
#elif 0
// 5x5, 36x36
// 5x5, 36x36
constexpr
unsigned
N
=
64
;
constexpr
index_t
N
=
64
;
constexpr
unsigned
C
=
256
;
constexpr
index_t
C
=
256
;
constexpr
unsigned
HI
=
36
;
constexpr
index_t
HI
=
36
;
constexpr
unsigned
WI
=
36
;
constexpr
index_t
WI
=
36
;
constexpr
unsigned
K
=
64
;
constexpr
index_t
K
=
64
;
constexpr
unsigned
Y
=
5
;
constexpr
index_t
Y
=
5
;
constexpr
unsigned
X
=
5
;
constexpr
index_t
X
=
5
;
constexpr
unsigned
HPad
=
0
;
constexpr
index_t
HPad
=
0
;
constexpr
unsigned
WPad
=
0
;
constexpr
index_t
WPad
=
0
;
#elif 0
#elif 0
// 7x7, 38x38
// 7x7, 38x38
constexpr
unsigned
N
=
64
;
constexpr
index_t
N
=
64
;
constexpr
unsigned
C
=
256
;
constexpr
index_t
C
=
256
;
constexpr
unsigned
HI
=
38
;
constexpr
index_t
HI
=
38
;
constexpr
unsigned
WI
=
38
;
constexpr
index_t
WI
=
38
;
constexpr
unsigned
K
=
64
;
constexpr
index_t
K
=
64
;
constexpr
unsigned
Y
=
7
;
constexpr
index_t
Y
=
7
;
constexpr
unsigned
X
=
7
;
constexpr
index_t
X
=
7
;
constexpr
unsigned
HPad
=
0
;
constexpr
index_t
HPad
=
0
;
constexpr
unsigned
WPad
=
0
;
constexpr
index_t
WPad
=
0
;
#elif 0
#elif 0
// 3x3, 58x58
// 3x3, 58x58
constexpr
unsigned
N
=
16
;
constexpr
index_t
N
=
16
;
constexpr
unsigned
C
=
128
;
constexpr
index_t
C
=
128
;
constexpr
unsigned
HI
=
58
;
constexpr
index_t
HI
=
58
;
constexpr
unsigned
WI
=
58
;
constexpr
index_t
WI
=
58
;
constexpr
unsigned
K
=
256
;
constexpr
index_t
K
=
256
;
constexpr
unsigned
Y
=
3
;
constexpr
index_t
Y
=
3
;
constexpr
unsigned
X
=
3
;
constexpr
index_t
X
=
3
;
#elif 0
#elif 0
// 3x3 filter, 58x58 image, 0x0 padding
// 3x3 filter, 58x58 image, 0x0 padding
constexpr
unsigned
N
=
16
;
constexpr
index_t
N
=
16
;
constexpr
unsigned
C
=
128
;
constexpr
index_t
C
=
128
;
constexpr
unsigned
HI
=
58
;
constexpr
index_t
HI
=
58
;
constexpr
unsigned
WI
=
58
;
constexpr
index_t
WI
=
58
;
constexpr
unsigned
K
=
256
;
constexpr
index_t
K
=
256
;
constexpr
unsigned
Y
=
3
;
constexpr
index_t
Y
=
3
;
constexpr
unsigned
X
=
3
;
constexpr
index_t
X
=
3
;
constexpr
unsigned
HPad
=
0
;
constexpr
index_t
HPad
=
0
;
constexpr
unsigned
WPad
=
0
;
constexpr
index_t
WPad
=
0
;
#elif 0
#elif 0
// 3x3 filter, 56x56 image, 1x1 padding
// 3x3 filter, 56x56 image, 1x1 padding
constexpr
unsigned
N
=
16
;
constexpr
index_t
N
=
16
;
constexpr
unsigned
C
=
128
;
constexpr
index_t
C
=
128
;
constexpr
unsigned
HI
=
56
;
constexpr
index_t
HI
=
56
;
constexpr
unsigned
WI
=
56
;
constexpr
index_t
WI
=
56
;
constexpr
unsigned
K
=
256
;
constexpr
index_t
K
=
256
;
constexpr
unsigned
Y
=
3
;
constexpr
index_t
Y
=
3
;
constexpr
unsigned
X
=
3
;
constexpr
index_t
X
=
3
;
constexpr
unsigned
HPad
=
1
;
constexpr
index_t
HPad
=
1
;
constexpr
unsigned
WPad
=
1
;
constexpr
index_t
WPad
=
1
;
#elif 0
#elif 0
// 3x3 filter, 28x28 image, 1x1 padding
// 3x3 filter, 28x28 image, 1x1 padding
constexpr
unsigned
N
=
16
;
constexpr
index_t
N
=
16
;
constexpr
unsigned
C
=
256
;
constexpr
index_t
C
=
256
;
constexpr
unsigned
HI
=
28
;
constexpr
index_t
HI
=
28
;
constexpr
unsigned
WI
=
28
;
constexpr
index_t
WI
=
28
;
constexpr
unsigned
K
=
512
;
constexpr
index_t
K
=
512
;
constexpr
unsigned
Y
=
3
;
constexpr
index_t
Y
=
3
;
constexpr
unsigned
X
=
3
;
constexpr
index_t
X
=
3
;
constexpr
unsigned
HPad
=
1
;
constexpr
index_t
HPad
=
1
;
constexpr
unsigned
WPad
=
1
;
constexpr
index_t
WPad
=
1
;
#elif 0
#elif 0
// 1x1 filter, 28x28 image
// 1x1 filter, 28x28 image
constexpr
unsigned
N
=
16
;
constexpr
index_t
N
=
16
;
constexpr
unsigned
C
=
256
;
constexpr
index_t
C
=
256
;
constexpr
unsigned
HI
=
28
;
constexpr
index_t
HI
=
28
;
constexpr
unsigned
WI
=
28
;
constexpr
index_t
WI
=
28
;
constexpr
unsigned
K
=
512
;
constexpr
index_t
K
=
512
;
constexpr
unsigned
Y
=
1
;
constexpr
index_t
Y
=
1
;
constexpr
unsigned
X
=
1
;
constexpr
index_t
X
=
1
;
constexpr
unsigned
HPad
=
0
;
constexpr
index_t
HPad
=
0
;
constexpr
unsigned
WPad
=
0
;
constexpr
index_t
WPad
=
0
;
#elif 0
#elif 0
// 3x3 filter, 20x84 image, 1x1 padding
// 3x3 filter, 20x84 image, 1x1 padding
constexpr
unsigned
N
=
16
;
constexpr
index_t
N
=
16
;
constexpr
unsigned
C
=
256
;
constexpr
index_t
C
=
256
;
constexpr
unsigned
HI
=
20
;
constexpr
index_t
HI
=
20
;
constexpr
unsigned
WI
=
84
;
constexpr
index_t
WI
=
84
;
constexpr
unsigned
K
=
256
;
constexpr
index_t
K
=
256
;
constexpr
unsigned
Y
=
3
;
constexpr
index_t
Y
=
3
;
constexpr
unsigned
X
=
3
;
constexpr
index_t
X
=
3
;
constexpr
unsigned
HPad
=
1
;
constexpr
index_t
HPad
=
1
;
constexpr
unsigned
WPad
=
1
;
constexpr
index_t
WPad
=
1
;
#elif 0
#elif 0
// 3x3 filter, 112x112 image, 1x1 padding
// 3x3 filter, 112x112 image, 1x1 padding
constexpr
unsigned
N
=
16
;
constexpr
index_t
N
=
16
;
constexpr
unsigned
C
=
64
;
constexpr
index_t
C
=
64
;
constexpr
unsigned
HI
=
112
;
constexpr
index_t
HI
=
112
;
constexpr
unsigned
WI
=
112
;
constexpr
index_t
WI
=
112
;
constexpr
unsigned
K
=
128
;
constexpr
index_t
K
=
128
;
constexpr
unsigned
Y
=
3
;
constexpr
index_t
Y
=
3
;
constexpr
unsigned
X
=
3
;
constexpr
index_t
X
=
3
;
constexpr
unsigned
HPad
=
1
;
constexpr
index_t
HPad
=
1
;
constexpr
unsigned
WPad
=
1
;
constexpr
index_t
WPad
=
1
;
#elif 0
#elif 0
// 5x5 filter, 20x86 image, 1x1 padding
// 5x5 filter, 20x86 image, 1x1 padding
constexpr
unsigned
N
=
16
;
constexpr
index_t
N
=
16
;
constexpr
unsigned
C
=
256
;
constexpr
index_t
C
=
256
;
constexpr
unsigned
HI
=
20
;
constexpr
index_t
HI
=
20
;
constexpr
unsigned
WI
=
86
;
constexpr
index_t
WI
=
86
;
constexpr
unsigned
K
=
512
;
constexpr
index_t
K
=
512
;
constexpr
unsigned
Y
=
5
;
constexpr
index_t
Y
=
5
;
constexpr
unsigned
X
=
5
;
constexpr
index_t
X
=
5
;
constexpr
unsigned
HPad
=
1
;
constexpr
index_t
HPad
=
1
;
constexpr
unsigned
WPad
=
1
;
constexpr
index_t
WPad
=
1
;
#elif 0
#elif 0
// 5x5 filter, 28x28 image, 2x2 padding
// 5x5 filter, 28x28 image, 2x2 padding
constexpr
unsigned
N
=
16
;
constexpr
index_t
N
=
16
;
constexpr
unsigned
C
=
192
;
constexpr
index_t
C
=
192
;
constexpr
unsigned
HI
=
28
;
constexpr
index_t
HI
=
28
;
constexpr
unsigned
WI
=
28
;
constexpr
index_t
WI
=
28
;
constexpr
unsigned
K
=
32
;
constexpr
index_t
K
=
32
;
constexpr
unsigned
Y
=
5
;
constexpr
index_t
Y
=
5
;
constexpr
unsigned
X
=
5
;
constexpr
index_t
X
=
5
;
constexpr
unsigned
HPad
=
2
;
constexpr
index_t
HPad
=
2
;
constexpr
unsigned
WPad
=
2
;
constexpr
index_t
WPad
=
2
;
#elif 0
#elif 0
// 1x1 filter, 32x32 image
// 1x1 filter, 32x32 image
constexpr
unsigned
N
=
64
;
constexpr
index_t
N
=
64
;
constexpr
unsigned
C
=
256
;
constexpr
index_t
C
=
256
;
constexpr
unsigned
HI
=
32
;
constexpr
index_t
HI
=
32
;
constexpr
unsigned
WI
=
32
;
constexpr
index_t
WI
=
32
;
constexpr
unsigned
K
=
512
;
constexpr
index_t
K
=
512
;
constexpr
unsigned
Y
=
1
;
constexpr
index_t
Y
=
1
;
constexpr
unsigned
X
=
1
;
constexpr
index_t
X
=
1
;
constexpr
unsigned
HPad
=
0
;
constexpr
index_t
HPad
=
0
;
constexpr
unsigned
WPad
=
0
;
constexpr
index_t
WPad
=
0
;
#elif 0
#elif 0
// 1x1 filter, 14x14 image
// 1x1 filter, 14x14 image
, C = 2048
constexpr
unsigned
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
unsigned
C
=
2048
;
constexpr
index_t
C
=
2048
;
constexpr
unsigned
HI
=
14
;
constexpr
index_t
HI
=
14
;
constexpr
unsigned
WI
=
14
;
constexpr
index_t
WI
=
14
;
constexpr
unsigned
K
=
512
;
constexpr
index_t
K
=
512
;
constexpr
unsigned
Y
=
1
;
constexpr
index_t
Y
=
1
;
constexpr
unsigned
X
=
1
;
constexpr
index_t
X
=
1
;
constexpr
unsigned
HPad
=
0
;
constexpr
index_t
HPad
=
0
;
constexpr
unsigned
WPad
=
0
;
constexpr
index_t
WPad
=
0
;
#elif 1
#elif 1
// 1x1 filter, 14x14 image, C = 512
// 1x1 filter, 14x14 image, C = 512
constexpr
unsigned
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
unsigned
C
=
512
;
constexpr
index_t
C
=
512
;
constexpr
unsigned
HI
=
14
;
constexpr
index_t
HI
=
14
;
constexpr
unsigned
WI
=
14
;
constexpr
index_t
WI
=
14
;
constexpr
unsigned
K
=
512
;
constexpr
index_t
K
=
512
;
constexpr
unsigned
Y
=
1
;
constexpr
index_t
Y
=
1
;
constexpr
unsigned
X
=
1
;
constexpr
index_t
X
=
1
;
constexpr
unsigned
HPad
=
0
;
constexpr
index_t
HPad
=
0
;
constexpr
unsigned
WPad
=
0
;
constexpr
index_t
WPad
=
0
;
#endif
#endif
auto
lower_pads
=
Sequence
<
HPad
,
WPad
>
{};
auto
lower_pads
=
Sequence
<
HPad
,
WPad
>
{};
...
@@ -634,7 +634,7 @@ int main(int argc, char* argv[])
...
@@ -634,7 +634,7 @@ int main(int argc, char* argv[])
}
}
bool
do_verification
=
atoi
(
argv
[
1
]);
bool
do_verification
=
atoi
(
argv
[
1
]);
unsigned
nrepeat
=
atoi
(
argv
[
2
]);
index_t
nrepeat
=
atoi
(
argv
[
2
]);
if
(
do_verification
)
if
(
do_verification
)
{
{
...
...
src/include/Array.hip.hpp
View file @
766b0a9e
#pragma once
#pragma once
template
<
class
TData
,
unsigned
NSize
>
template
<
class
TData
,
index_t
NSize
>
struct
Array
struct
Array
{
{
using
Type
=
Array
<
TData
,
NSize
>
;
using
Type
=
Array
<
TData
,
NSize
>
;
static
constexpr
unsigned
nSize
=
NSize
;
static
constexpr
index_t
nSize
=
NSize
;
unsigned
mData
[
nSize
];
index_t
mData
[
nSize
];
template
<
class
...
Xs
>
template
<
class
...
Xs
>
__host__
__device__
Array
(
Xs
...
xs
)
:
mData
{
static_cast
<
TData
>
(
xs
)...}
__host__
__device__
Array
(
Xs
...
xs
)
:
mData
{
static_cast
<
TData
>
(
xs
)...}
{
{
}
}
__host__
__device__
TData
operator
[](
unsigned
i
)
const
{
return
mData
[
i
];
}
__host__
__device__
TData
operator
[](
index_t
i
)
const
{
return
mData
[
i
];
}
};
};
src/include/ConstantMatrixDescriptor.hip.hpp
View file @
766b0a9e
#pragma once
#pragma once
#include "common.hip.hpp"
#include "common.hip.hpp"
template
<
unsigned
NRow_
,
unsigned
NCol_
,
unsigned
RowStride_
>
template
<
index_t
NRow_
,
index_t
NCol_
,
index_t
RowStride_
>
struct
ConstantMatrixDescriptor
struct
ConstantMatrixDescriptor
{
{
__host__
__device__
constexpr
ConstantMatrixDescriptor
()
__host__
__device__
constexpr
ConstantMatrixDescriptor
()
...
@@ -9,24 +9,28 @@ struct ConstantMatrixDescriptor
...
@@ -9,24 +9,28 @@ struct ConstantMatrixDescriptor
static_assert
(
NCol_
<=
RowStride_
,
"wrong! NCol > RowStride!"
);
static_assert
(
NCol_
<=
RowStride_
,
"wrong! NCol > RowStride!"
);
}
}
__host__
__device__
constexpr
unsigned
NRow
()
const
{
return
NRow_
;
}
__host__
__device__
constexpr
index_t
NRow
()
const
{
return
NRow_
;
}
__host__
__device__
constexpr
unsigned
NCol
()
const
{
return
NCol_
;
}
__host__
__device__
constexpr
index_t
NCol
()
const
{
return
NCol_
;
}
__host__
__device__
constexpr
unsigned
RowStride
()
const
{
return
RowStride_
;
}
__host__
__device__
constexpr
index_t
RowStride
()
const
{
return
RowStride_
;
}
__host__
__device__
constexpr
auto
GetLengths
()
const
{
return
Sequence
<
NRow_
,
NCol_
>
{};
}
__host__
__device__
constexpr
auto
GetLengths
()
const
{
return
Sequence
<
NRow_
,
NCol_
>
{};
}
__host__
__device__
constexpr
unsigned
GetElementSize
()
const
{
return
NRow_
*
NCol_
;
}
__host__
__device__
constexpr
index_t
GetElementSize
()
const
{
return
NRow_
*
NCol_
;
}
__host__
__device__
constexpr
unsigned
GetElementSpace
()
const
{
return
NRow_
*
RowStride_
;
}
__host__
__device__
constexpr
index_t
GetElementSpace
()
const
{
return
NRow_
*
RowStride_
;
}
__host__
__device__
unsigned
Get1dIndex
(
unsigned
irow
,
unsigned
icol
)
const
__host__
__device__
index_t
Get1dIndex
(
index_t
irow
,
index_t
icol
)
const
{
{
#if DEVICE_BACKEND_HIP
return
__mul24
(
irow
,
RowStride_
)
+
icol
;
#else
return
irow
*
RowStride_
+
icol
;
return
irow
*
RowStride_
+
icol
;
#endif
}
}
template
<
unsigned
SubNRow
,
unsigned
SubNCol
>
template
<
index_t
SubNRow
,
index_t
SubNCol
>
__host__
__device__
constexpr
auto
MakeSubMatrixDescriptor
(
Number
<
SubNRow
>
,
__host__
__device__
constexpr
auto
MakeSubMatrixDescriptor
(
Number
<
SubNRow
>
,
Number
<
SubNCol
>
)
const
Number
<
SubNCol
>
)
const
{
{
...
@@ -34,13 +38,13 @@ struct ConstantMatrixDescriptor
...
@@ -34,13 +38,13 @@ struct ConstantMatrixDescriptor
}
}
};
};
template
<
unsigned
NRow
,
unsigned
NCol
>
template
<
index_t
NRow
,
index_t
NCol
>
__host__
__device__
constexpr
auto
make_ConstantMatrixDescriptor
(
Number
<
NRow
>
,
Number
<
NCol
>
)
__host__
__device__
constexpr
auto
make_ConstantMatrixDescriptor
(
Number
<
NRow
>
,
Number
<
NCol
>
)
{
{
return
ConstantMatrixDescriptor
<
NRow
,
NCol
,
NCol
>
{};
return
ConstantMatrixDescriptor
<
NRow
,
NCol
,
NCol
>
{};
}
}
template
<
unsigned
NRow
,
unsigned
NCol
,
unsigned
RowStride
>
template
<
index_t
NRow
,
index_t
NCol
,
index_t
RowStride
>
__host__
__device__
constexpr
auto
__host__
__device__
constexpr
auto
make_ConstantMatrixDescriptor
(
Number
<
NRow
>
,
Number
<
NCol
>
,
Number
<
RowStride
>
)
make_ConstantMatrixDescriptor
(
Number
<
NRow
>
,
Number
<
NCol
>
,
Number
<
RowStride
>
)
{
{
...
...
src/include/ConstantTensorDescriptor.hip.hpp
View file @
766b0a9e
...
@@ -2,35 +2,35 @@
...
@@ -2,35 +2,35 @@
#include "common.hip.hpp"
#include "common.hip.hpp"
// this is ugly, only for 2d
// this is ugly, only for 2d
template
<
unsigned
L0
,
unsigned
L1
>
template
<
index_t
L0
,
index_t
L1
>
__host__
__device__
constexpr
auto
calculate_default_strides
(
Sequence
<
L0
,
L1
>
)
__host__
__device__
constexpr
auto
calculate_default_strides
(
Sequence
<
L0
,
L1
>
)
{
{
return
Sequence
<
L1
,
1
>
{};
return
Sequence
<
L1
,
1
>
{};
}
}
// this is ugly, only for 4d
// this is ugly, only for 4d
template
<
unsigned
L0
,
unsigned
L1
,
unsigned
L2
,
unsigned
L3
>
template
<
index_t
L0
,
index_t
L1
,
index_t
L2
,
index_t
L3
>
__host__
__device__
constexpr
auto
calculate_default_strides
(
Sequence
<
L0
,
L1
,
L2
,
L3
>
)
__host__
__device__
constexpr
auto
calculate_default_strides
(
Sequence
<
L0
,
L1
,
L2
,
L3
>
)
{
{
return
Sequence
<
L1
*
L2
*
L3
,
L2
*
L3
,
L3
,
1
>
{};
return
Sequence
<
L1
*
L2
*
L3
,
L2
*
L3
,
L3
,
1
>
{};
}
}
// this is ugly, only for 6d
// this is ugly, only for 6d
template
<
unsigned
L0
,
unsigned
L1
,
unsigned
L2
,
unsigned
L3
,
unsigned
L4
,
unsigned
L5
>
template
<
index_t
L0
,
index_t
L1
,
index_t
L2
,
index_t
L3
,
index_t
L4
,
index_t
L5
>
__host__
__device__
constexpr
auto
calculate_default_strides
(
Sequence
<
L0
,
L1
,
L2
,
L3
,
L4
,
L5
>
)
__host__
__device__
constexpr
auto
calculate_default_strides
(
Sequence
<
L0
,
L1
,
L2
,
L3
,
L4
,
L5
>
)
{
{
return
Sequence
<
L1
*
L2
*
L3
*
L4
*
L5
,
L2
*
L3
*
L4
*
L5
,
L3
*
L4
*
L5
,
L4
*
L5
,
L5
,
1
>
{};
return
Sequence
<
L1
*
L2
*
L3
*
L4
*
L5
,
L2
*
L3
*
L4
*
L5
,
L3
*
L4
*
L5
,
L4
*
L5
,
L5
,
1
>
{};
}
}
// this is ugly, only for 8d
// this is ugly, only for 8d
template
<
unsigned
L0
,
template
<
index_t
L0
,
unsigned
L1
,
index_t
L1
,
unsigned
L2
,
index_t
L2
,
unsigned
L3
,
index_t
L3
,
unsigned
L4
,
index_t
L4
,
unsigned
L5
,
index_t
L5
,
unsigned
L6
,
index_t
L6
,
unsigned
L7
>
index_t
L7
>
__host__
__device__
constexpr
auto
__host__
__device__
constexpr
auto
calculate_default_strides
(
Sequence
<
L0
,
L1
,
L2
,
L3
,
L4
,
L5
,
L6
,
L7
>
)
calculate_default_strides
(
Sequence
<
L0
,
L1
,
L2
,
L3
,
L4
,
L5
,
L6
,
L7
>
)
{
{
...
@@ -45,48 +45,48 @@ __host__ __device__ constexpr auto
...
@@ -45,48 +45,48 @@ __host__ __device__ constexpr auto
}
}
// this is ugly, only for 2d
// this is ugly, only for 2d
template
<
unsigned
L0
,
unsigned
L1
,
unsigned
Align
>
template
<
index_t
L0
,
index_t
L1
,
index_t
Align
>
__host__
__device__
constexpr
auto
calculate_default_strides_aligned
(
Sequence
<
L0
,
L1
>
,
__host__
__device__
constexpr
auto
calculate_default_strides_aligned
(
Sequence
<
L0
,
L1
>
,
Number
<
Align
>
)
Number
<
Align
>
)
{
{
constexpr
unsigned
L1_align
=
Align
*
((
L1
+
Align
-
1
)
/
Align
);
constexpr
index_t
L1_align
=
Align
*
((
L1
+
Align
-
1
)
/
Align
);
return
Sequence
<
L1_align
,
1
>
{};
return
Sequence
<
L1_align
,
1
>
{};
}
}
// this is ugly, only for 4d
// this is ugly, only for 4d
template
<
unsigned
L0
,
unsigned
L1
,
unsigned
L2
,
unsigned
L3
,
unsigned
Align
>
template
<
index_t
L0
,
index_t
L1
,
index_t
L2
,
index_t
L3
,
index_t
Align
>
__host__
__device__
constexpr
auto
calculate_default_strides_aligned
(
Sequence
<
L0
,
L1
,
L2
,
L3
>
,
__host__
__device__
constexpr
auto
calculate_default_strides_aligned
(
Sequence
<
L0
,
L1
,
L2
,
L3
>
,
Number
<
Align
>
)
Number
<
Align
>
)
{
{
constexpr
unsigned
L3_align
=
Align
*
((
L3
+
Align
-
1
)
/
Align
);
constexpr
index_t
L3_align
=
Align
*
((
L3
+
Align
-
1
)
/
Align
);
return
Sequence
<
L1
*
L2
*
L3_align
,
L2
*
L3_align
,
L3_align
,
1
>
{};
return
Sequence
<
L1
*
L2
*
L3_align
,
L2
*
L3_align
,
L3_align
,
1
>
{};
}
}
template
<
class
Lengths
,
class
Strides
>
template
<
class
Lengths
,
class
Strides
>
struct
ConstantTensorDescriptor
struct
ConstantTensorDescriptor
{
{
using
Type
=
ConstantTensorDescriptor
<
Lengths
,
Strides
>
;
using
Type
=
ConstantTensorDescriptor
<
Lengths
,
Strides
>
;
static
constexpr
unsigned
nDim
=
Lengths
::
nDim
;
static
constexpr
index_t
nDim
=
Lengths
::
nDim
;
__host__
__device__
constexpr
ConstantTensorDescriptor
()
__host__
__device__
constexpr
ConstantTensorDescriptor
()
{
{
static_assert
(
Lengths
::
nDim
==
Strides
::
nDim
,
"nDim not consistent"
);
static_assert
(
Lengths
::
nDim
==
Strides
::
nDim
,
"nDim not consistent"
);
}
}
__host__
__device__
constexpr
unsigned
GetDimension
()
const
{
return
nDim
;
}
__host__
__device__
constexpr
index_t
GetDimension
()
const
{
return
nDim
;
}
__host__
__device__
constexpr
Lengths
GetLengths
()
const
{
return
Lengths
{};
}
__host__
__device__
constexpr
Lengths
GetLengths
()
const
{
return
Lengths
{};
}
__host__
__device__
constexpr
Strides
GetStrides
()
const
{
return
Strides
{};
}
__host__
__device__
constexpr
Strides
GetStrides
()
const
{
return
Strides
{};
}
template
<
unsigned
I
>
template
<
index_t
I
>
__host__
__device__
constexpr
unsigned
GetLength
(
Number
<
I
>
)
const
__host__
__device__
constexpr
index_t
GetLength
(
Number
<
I
>
)
const
{
{
return
Lengths
{}.
Get
(
Number
<
I
>
{});
return
Lengths
{}.
Get
(
Number
<
I
>
{});
}
}
template
<
unsigned
I
>
template
<
index_t
I
>
__host__
__device__
constexpr
unsigned
GetStride
(
Number
<
I
>
)
const
__host__
__device__
constexpr
index_t
GetStride
(
Number
<
I
>
)
const
{
{
return
Strides
{}.
Get
(
Number
<
I
>
{});
return
Strides
{}.
Get
(
Number
<
I
>
{});
}
}
...
@@ -95,18 +95,18 @@ struct ConstantTensorDescriptor
...
@@ -95,18 +95,18 @@ struct ConstantTensorDescriptor
struct
GetElementSize_f
struct
GetElementSize_f
{
{
template
<
class
IDim
>
template
<
class
IDim
>
__host__
__device__
constexpr
unsigned
operator
()(
IDim
idim
)
const
__host__
__device__
constexpr
index_t
operator
()(
IDim
idim
)
const
{
{
return
Type
{}.
GetLength
(
idim
);
return
Type
{}.
GetLength
(
idim
);
}
}
};
};
__host__
__device__
constexpr
unsigned
GetElementSize
()
const
__host__
__device__
constexpr
index_t
GetElementSize
()
const
{
{
// c++14 doesn't support constexpr lambdas, has to use this trick instead
// c++14 doesn't support constexpr lambdas, has to use this trick instead
struct
multiply
struct
multiply
{
{
__host__
__device__
constexpr
unsigned
operator
()(
unsigned
a
,
unsigned
b
)
const
__host__
__device__
constexpr
index_t
operator
()(
index_t
a
,
index_t
b
)
const
{
{
return
a
*
b
;
return
a
*
b
;
}
}
...
@@ -119,19 +119,19 @@ struct ConstantTensorDescriptor
...
@@ -119,19 +119,19 @@ struct ConstantTensorDescriptor
struct
GetElementSpace_f
struct
GetElementSpace_f
{
{
template
<
class
IDim
>
template
<
class
IDim
>
__host__
__device__
constexpr
unsigned
operator
()(
IDim
idim
)
const
__host__
__device__
constexpr
index_t
operator
()(
IDim
idim
)
const
{
{
return
(
Type
{}.
GetLength
(
idim
)
-
1
)
*
Type
{}.
GetStride
(
idim
);
return
(
Type
{}.
GetLength
(
idim
)
-
1
)
*
Type
{}.
GetStride
(
idim
);
}
}
};
};
template
<
class
Align
=
Number
<
1
>
>
template
<
class
Align
=
Number
<
1
>
>
__host__
__device__
constexpr
unsigned
GetElementSpace
(
Align
align
=
Align
{})
const
__host__
__device__
constexpr
index_t
GetElementSpace
(
Align
align
=
Align
{})
const
{
{
// c++14 doesn't support constexpr lambdas, has to use this trick instead
// c++14 doesn't support constexpr lambdas, has to use this trick instead
struct
add
struct
add
{
{
__host__
__device__
constexpr
unsigned
operator
()(
unsigned
a
,
unsigned
b
)
const
__host__
__device__
constexpr
index_t
operator
()(
index_t
a
,
index_t
b
)
const
{
{
return
a
+
b
;
return
a
+
b
;
}
}
...
@@ -141,17 +141,21 @@ struct ConstantTensorDescriptor
...
@@ -141,17 +141,21 @@ struct ConstantTensorDescriptor
}
}
template
<
class
...
Is
>
template
<
class
...
Is
>
__host__
__device__
unsigned
Get1dIndex
(
Is
...
is
)
const
__host__
__device__
index_t
Get1dIndex
(
Is
...
is
)
const
{
{
static_assert
(
sizeof
...(
Is
)
==
nDim
,
"number of multi-index is wrong"
);
static_assert
(
sizeof
...(
Is
)
==
nDim
,
"number of multi-index is wrong"
);
const
auto
multi_id
=
Array
<
unsigned
,
nDim
>
(
is
...);
const
auto
multi_id
=
Array
<
index_t
,
nDim
>
(
is
...);
unsigned
id
=
0
;
index_t
id
=
0
;
static_loop_n
<
nDim
>
{}([
&
](
auto
IDim
)
{
static_loop_n
<
nDim
>
{}([
&
](
auto
IDim
)
{
constexpr
unsigned
idim
=
IDim
.
Get
();
constexpr
index_t
idim
=
IDim
.
Get
();
#if DEVICE_BACKEND_HIP
id
+=
__mul24
(
multi_id
[
idim
],
GetStride
(
IDim
));
#else
id
+=
multi_id
[
idim
]
*
GetStride
(
IDim
);
id
+=
multi_id
[
idim
]
*
GetStride
(
IDim
);
#endif
});
});
return
id
;
return
id
;
...
@@ -163,7 +167,7 @@ struct ConstantTensorDescriptor
...
@@ -163,7 +167,7 @@ struct ConstantTensorDescriptor
return
ConstantTensorDescriptor
<
Lengths
,
decltype
(
default_strides
)
>
{};
return
ConstantTensorDescriptor
<
Lengths
,
decltype
(
default_strides
)
>
{};
}
}
template
<
unsigned
IDim
,
unsigned
NVector
>
template
<
index_t
IDim
,
index_t
NVector
>
__host__
__device__
constexpr
auto
Vectorize
(
Number
<
IDim
>
,
Number
<
NVector
>
)
const
__host__
__device__
constexpr
auto
Vectorize
(
Number
<
IDim
>
,
Number
<
NVector
>
)
const
{
{
assert
(
false
);
// not implemented
assert
(
false
);
// not implemented
...
@@ -183,7 +187,7 @@ __host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Stride
...
@@ -183,7 +187,7 @@ __host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Stride
return
ConstantTensorDescriptor
<
Lengths
,
Strides
>
{};
return
ConstantTensorDescriptor
<
Lengths
,
Strides
>
{};
}
}
template
<
class
Lengths
,
unsigned
Align
>
template
<
class
Lengths
,
index_t
Align
>
__host__
__device__
constexpr
auto
make_ConstantTensorDescriptor_aligned
(
Lengths
,
Number
<
Align
>
)
__host__
__device__
constexpr
auto
make_ConstantTensorDescriptor_aligned
(
Lengths
,
Number
<
Align
>
)
{
{
using
Strides
=
decltype
(
calculate_default_strides_aligned
(
Lengths
{},
Number
<
Align
>
{}));
using
Strides
=
decltype
(
calculate_default_strides_aligned
(
Lengths
{},
Number
<
Align
>
{}));
...
@@ -193,8 +197,8 @@ __host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths
...
@@ -193,8 +197,8 @@ __host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths
template
<
class
TDesc
>
template
<
class
TDesc
>
__host__
__device__
void
print_ConstantTensorDescriptor
(
TDesc
,
const
char
*
s
)
__host__
__device__
void
print_ConstantTensorDescriptor
(
TDesc
,
const
char
*
s
)
{
{
constexpr
auto
desc
=
TDesc
{};
constexpr
auto
desc
=
TDesc
{};
constexpr
unsigned
ndim
=
desc
.
GetDimension
();
constexpr
index_t
ndim
=
desc
.
GetDimension
();
static_assert
(
ndim
>=
2
&&
ndim
<=
8
,
"wrong!"
);
static_assert
(
ndim
>=
2
&&
ndim
<=
8
,
"wrong!"
);
...
...
src/include/Sequence.hip.hpp
View file @
766b0a9e
...
@@ -2,38 +2,38 @@
...
@@ -2,38 +2,38 @@
#include "constant_integral.hip.hpp"
#include "constant_integral.hip.hpp"
#include "functional.hip.hpp"
#include "functional.hip.hpp"
template
<
unsigned
...
Is
>
template
<
index_t
...
Is
>
struct
Sequence
struct
Sequence
{
{
using
Type
=
Sequence
<
Is
...
>
;
using
Type
=
Sequence
<
Is
...
>
;
static
constexpr
unsigned
nDim
=
sizeof
...(
Is
);
static
constexpr
index_t
nDim
=
sizeof
...(
Is
);
const
unsigned
mData
[
nDim
]
=
{
Is
...};
const
index_t
mData
[
nDim
]
=
{
Is
...};
template
<
unsigned
I
>
template
<
index_t
I
>
__host__
__device__
constexpr
unsigned
Get
(
Number
<
I
>
)
const
__host__
__device__
constexpr
index_t
Get
(
Number
<
I
>
)
const
{
{
return
mData
[
I
];
return
mData
[
I
];
}
}
// this is ugly, only for nDIm = 4
// this is ugly, only for nDIm = 4
template
<
unsigned
I0
,
unsigned
I1
,
unsigned
I2
,
unsigned
I3
>
template
<
index_t
I0
,
index_t
I1
,
index_t
I2
,
index_t
I3
>
__host__
__device__
constexpr
auto
ReorderByGetNewFromOld
(
Sequence
<
I0
,
I1
,
I2
,
I3
>
)
const
__host__
__device__
constexpr
auto
ReorderByGetNewFromOld
(
Sequence
<
I0
,
I1
,
I2
,
I3
>
)
const
{
{
static_assert
(
nDim
==
4
,
"nDim != 4"
);
static_assert
(
nDim
==
4
,
"nDim != 4"
);
constexpr
auto
old_sequence
=
Type
{};
constexpr
auto
old_sequence
=
Type
{};
constexpr
unsigned
NR0
=
old_sequence
.
mData
[
I0
];
constexpr
index_t
NR0
=
old_sequence
.
mData
[
I0
];
constexpr
unsigned
NR1
=
old_sequence
.
mData
[
I1
];
constexpr
index_t
NR1
=
old_sequence
.
mData
[
I1
];
constexpr
unsigned
NR2
=
old_sequence
.
mData
[
I2
];
constexpr
index_t
NR2
=
old_sequence
.
mData
[
I2
];
constexpr
unsigned
NR3
=
old_sequence
.
mData
[
I3
];
constexpr
index_t
NR3
=
old_sequence
.
mData
[
I3
];
return
Sequence
<
NR0
,
NR1
,
NR2
,
NR3
>
{};
return
Sequence
<
NR0
,
NR1
,
NR2
,
NR3
>
{};
}
}
template
<
unsigned
I0
,
unsigned
I1
,
unsigned
I2
,
unsigned
I3
>
template
<
index_t
I0
,
index_t
I1
,
index_t
I2
,
index_t
I3
>
__host__
__device__
constexpr
auto
ReorderByPutOldToNew
(
Sequence
<
I0
,
I1
,
I2
,
I3
>
)
const
__host__
__device__
constexpr
auto
ReorderByPutOldToNew
(
Sequence
<
I0
,
I1
,
I2
,
I3
>
)
const
{
{
// don't know how to implement this
// don't know how to implement this
...
@@ -41,7 +41,7 @@ struct Sequence
...
@@ -41,7 +41,7 @@ struct Sequence
assert
(
false
);
assert
(
false
);
}
}
template
<
unsigned
I
>
template
<
index_t
I
>
__host__
__device__
constexpr
auto
PushBack
(
Number
<
I
>
)
const
__host__
__device__
constexpr
auto
PushBack
(
Number
<
I
>
)
const
{
{
return
Sequence
<
Is
...,
I
>
{};
return
Sequence
<
Is
...,
I
>
{};
...
@@ -56,14 +56,14 @@ struct Sequence
...
@@ -56,14 +56,14 @@ struct Sequence
}
}
};
};
template
<
unsigned
...
Is
,
unsigned
I
>
template
<
index_t
...
Is
,
index_t
I
>
__host__
__device__
constexpr
auto
sequence_pop_back
(
Sequence
<
Is
...,
I
>
)
__host__
__device__
constexpr
auto
sequence_pop_back
(
Sequence
<
Is
...,
I
>
)
{
{
static_assert
(
sizeof
...(
Is
)
>=
1
,
"empty Sequence!"
);
static_assert
(
sizeof
...(
Is
)
>=
1
,
"empty Sequence!"
);
return
Sequence
<
Is
...
>
{};
return
Sequence
<
Is
...
>
{};
}
}
template
<
class
F
,
unsigned
...
Xs
,
unsigned
...
Ys
>
template
<
class
F
,
index_t
...
Xs
,
index_t
...
Ys
>
__host__
__device__
constexpr
auto
sequence_sequence_op
(
Sequence
<
Xs
...
>
,
Sequence
<
Ys
...
>
,
F
f
)
__host__
__device__
constexpr
auto
sequence_sequence_op
(
Sequence
<
Xs
...
>
,
Sequence
<
Ys
...
>
,
F
f
)
{
{
static_assert
(
Sequence
<
Xs
...
>::
nDim
==
Sequence
<
Ys
...
>::
nDim
,
"Dim not the same"
);
static_assert
(
Sequence
<
Xs
...
>::
nDim
==
Sequence
<
Ys
...
>::
nDim
,
"Dim not the same"
);
...
@@ -71,12 +71,12 @@ __host__ __device__ constexpr auto sequence_sequence_op(Sequence<Xs...>, Sequenc
...
@@ -71,12 +71,12 @@ __host__ __device__ constexpr auto sequence_sequence_op(Sequence<Xs...>, Sequenc
return
Sequence
<
f
(
Xs
,
Ys
)...
>
{};
return
Sequence
<
f
(
Xs
,
Ys
)...
>
{};
}
}
template
<
unsigned
...
Xs
,
unsigned
...
Ys
>
template
<
index_t
...
Xs
,
index_t
...
Ys
>
__host__
__device__
constexpr
auto
sequence_sequence_add
(
Sequence
<
Xs
...
>
,
Sequence
<
Ys
...
>
)
__host__
__device__
constexpr
auto
sequence_sequence_add
(
Sequence
<
Xs
...
>
,
Sequence
<
Ys
...
>
)
{
{
struct
add
struct
add
{
{
__host__
__device__
constexpr
unsigned
operator
()(
unsigned
x
,
unsigned
y
)
const
__host__
__device__
constexpr
index_t
operator
()(
index_t
x
,
index_t
y
)
const
{
{
return
x
+
y
;
return
x
+
y
;
}
}
...
@@ -85,7 +85,7 @@ __host__ __device__ constexpr auto sequence_sequence_add(Sequence<Xs...>, Sequen
...
@@ -85,7 +85,7 @@ __host__ __device__ constexpr auto sequence_sequence_add(Sequence<Xs...>, Sequen
return
sequence_sequence_op
(
Sequence
<
Xs
...
>
{},
Sequence
<
Ys
...
>
{},
add
{});
return
sequence_sequence_op
(
Sequence
<
Xs
...
>
{},
Sequence
<
Ys
...
>
{},
add
{});
}
}
template
<
unsigned
...
Is
>
template
<
index_t
...
Is
>
__host__
__device__
constexpr
auto
Sequence
<
Is
...
>::
PopBack
()
const
__host__
__device__
constexpr
auto
Sequence
<
Is
...
>::
PopBack
()
const
{
{
return
sequence_pop_back
(
Type
{});
return
sequence_pop_back
(
Type
{});
...
...
src/include/blockwise_2d_tensor_op.hip.hpp
View file @
766b0a9e
This diff is collapsed.
Click to expand it.
src/include/blockwise_4d_tensor_op.hip.hpp
View file @
766b0a9e
This diff is collapsed.
Click to expand it.
src/include/blockwise_batched_gemm.hip.hpp
View file @
766b0a9e
This diff is collapsed.
Click to expand it.
src/include/blockwise_direct_convolution.hip.hpp
View file @
766b0a9e
...
@@ -3,16 +3,16 @@
...
@@ -3,16 +3,16 @@
#include "threadwise_4d_tensor_op.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "threadwise_direct_convolution.hip.hpp"
#include "threadwise_direct_convolution.hip.hpp"
template
<
unsigned
BlockSize
,
template
<
index_t
BlockSize
,
class
Float
,
class
Float
,
class
InBlockDesc
,
class
InBlockDesc
,
class
WeiBlockDesc
,
class
WeiBlockDesc
,
class
OutBlockDesc
,
class
OutBlockDesc
,
unsigned
NPerThread
,
index_t
NPerThread
,
unsigned
KPerThread
,
index_t
KPerThread
,
unsigned
CPerThread
,
index_t
CPerThread
,
unsigned
HoPerThread
,
index_t
HoPerThread
,
unsigned
WoPerThread
>
index_t
WoPerThread
>
__device__
void
blockwise_direct_convolution
(
InBlockDesc
,
__device__
void
blockwise_direct_convolution
(
InBlockDesc
,
Float
*
const
__restrict__
p_in_block
,
Float
*
const
__restrict__
p_in_block
,
WeiBlockDesc
,
WeiBlockDesc
,
...
@@ -29,17 +29,17 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
...
@@ -29,17 +29,17 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
constexpr
auto
wei_block_desc
=
WeiBlockDesc
{};
constexpr
auto
wei_block_desc
=
WeiBlockDesc
{};
constexpr
auto
out_block_desc
=
OutBlockDesc
{};
constexpr
auto
out_block_desc
=
OutBlockDesc
{};
constexpr
unsigned
Y
=
wei_block_desc
.
GetLength
(
I2
);
constexpr
index_t
Y
=
wei_block_desc
.
GetLength
(
I2
);
constexpr
unsigned
X
=
wei_block_desc
.
GetLength
(
I3
);
constexpr
index_t
X
=
wei_block_desc
.
GetLength
(
I3
);
constexpr
unsigned
InTileSizeH
=
HoPerThread
+
Y
-
1
;
constexpr
index_t
InTileSizeH
=
HoPerThread
+
Y
-
1
;
constexpr
unsigned
InTileSizeW
=
WoPerThread
+
X
-
1
;
constexpr
index_t
InTileSizeW
=
WoPerThread
+
X
-
1
;
// divide thread work
// divide thread work
constexpr
unsigned
NThreadWork
=
(
out_block_desc
.
GetLength
(
I0
)
+
NPerThread
-
1
)
/
NPerThread
;
constexpr
index_t
NThreadWork
=
(
out_block_desc
.
GetLength
(
I0
)
+
NPerThread
-
1
)
/
NPerThread
;
constexpr
unsigned
KThreadWork
=
(
out_block_desc
.
GetLength
(
I1
)
+
KPerThread
-
1
)
/
KPerThread
;
constexpr
index_t
KThreadWork
=
(
out_block_desc
.
GetLength
(
I1
)
+
KPerThread
-
1
)
/
KPerThread
;
constexpr
unsigned
YThreadWork
=
(
out_block_desc
.
GetLength
(
I2
)
+
HoPerThread
-
1
)
/
HoPerThread
;
constexpr
index_t
YThreadWork
=
(
out_block_desc
.
GetLength
(
I2
)
+
HoPerThread
-
1
)
/
HoPerThread
;
constexpr
unsigned
XThreadWork
=
(
out_block_desc
.
GetLength
(
I3
)
+
WoPerThread
-
1
)
/
WoPerThread
;
constexpr
index_t
XThreadWork
=
(
out_block_desc
.
GetLength
(
I3
)
+
WoPerThread
-
1
)
/
WoPerThread
;
#if 0
#if 0
if(threadIdx.x == 0)
if(threadIdx.x == 0)
...
@@ -68,27 +68,27 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
...
@@ -68,27 +68,27 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
constexpr
auto
out_thread_block_desc
=
constexpr
auto
out_thread_block_desc
=
make_ConstantTensorDescriptor
(
out_thread_desc
.
GetLengths
(),
out_block_desc
.
GetStrides
());
make_ConstantTensorDescriptor
(
out_thread_desc
.
GetLengths
(),
out_block_desc
.
GetStrides
());
const
unsigned
thread_id
=
threadIdx
.
x
;
const
index_t
thread_id
=
threadIdx
.
x
;
for
(
unsigned
thread_work_id
=
thread_id
;
for
(
index_t
thread_work_id
=
thread_id
;
thread_work_id
<
NThreadWork
*
KThreadWork
*
YThreadWork
*
XThreadWork
;
thread_work_id
<
NThreadWork
*
KThreadWork
*
YThreadWork
*
XThreadWork
;
thread_work_id
+=
BlockSize
)
thread_work_id
+=
BlockSize
)
{
{
unsigned
itmp
=
thread_work_id
;
index_t
itmp
=
thread_work_id
;
unsigned
n_thread_work_id
=
itmp
/
(
KThreadWork
*
YThreadWork
*
XThreadWork
);
index_t
n_thread_work_id
=
itmp
/
(
KThreadWork
*
YThreadWork
*
XThreadWork
);
itmp
-=
n_thread_work_id
*
(
KThreadWork
*
YThreadWork
*
XThreadWork
);
itmp
-=
n_thread_work_id
*
(
KThreadWork
*
YThreadWork
*
XThreadWork
);
unsigned
k_thread_work_id
=
itmp
/
(
YThreadWork
*
XThreadWork
);
index_t
k_thread_work_id
=
itmp
/
(
YThreadWork
*
XThreadWork
);
itmp
-=
k_thread_work_id
*
(
YThreadWork
*
XThreadWork
);
itmp
-=
k_thread_work_id
*
(
YThreadWork
*
XThreadWork
);
unsigned
y_thread_work_id
=
itmp
/
XThreadWork
;
index_t
y_thread_work_id
=
itmp
/
XThreadWork
;
unsigned
x_thread_work_id
=
itmp
-
y_thread_work_id
*
XThreadWork
;
index_t
x_thread_work_id
=
itmp
-
y_thread_work_id
*
XThreadWork
;
unsigned
n_thread_data_begin
=
n_thread_work_id
*
NPerThread
;
index_t
n_thread_data_begin
=
n_thread_work_id
*
NPerThread
;
unsigned
k_thread_data_begin
=
k_thread_work_id
*
KPerThread
;
index_t
k_thread_data_begin
=
k_thread_work_id
*
KPerThread
;
unsigned
ho_thread_data_begin
=
y_thread_work_id
*
HoPerThread
;
index_t
ho_thread_data_begin
=
y_thread_work_id
*
HoPerThread
;
unsigned
wo_thread_data_begin
=
x_thread_work_id
*
WoPerThread
;
index_t
wo_thread_data_begin
=
x_thread_work_id
*
WoPerThread
;
unsigned
hi_thread_data_begin
=
ho_thread_data_begin
;
// minus padding
index_t
hi_thread_data_begin
=
ho_thread_data_begin
;
// minus padding
unsigned
wi_thread_data_begin
=
wo_thread_data_begin
;
// minus padding
index_t
wi_thread_data_begin
=
wo_thread_data_begin
;
// minus padding
Float
p_out_thread
[
out_thread_desc
.
GetElementSpace
()];
Float
p_out_thread
[
out_thread_desc
.
GetElementSpace
()];
...
@@ -102,7 +102,7 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
...
@@ -102,7 +102,7 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
p_out_thread
,
p_out_thread
,
out_thread_desc
.
GetLengths
());
out_thread_desc
.
GetLengths
());
for
(
unsigned
c_thread_data_begin
=
0
;
c_thread_data_begin
<
in_block_desc
.
GetLength
(
I1
);
for
(
index_t
c_thread_data_begin
=
0
;
c_thread_data_begin
<
in_block_desc
.
GetLength
(
I1
);
c_thread_data_begin
+=
CPerThread
)
c_thread_data_begin
+=
CPerThread
)
{
{
// threadwise convolution
// threadwise convolution
...
...
src/include/blockwise_gemm.hip.hpp
View file @
766b0a9e
This diff is collapsed.
Click to expand it.
src/include/common.hip.hpp
View file @
766b0a9e
...
@@ -5,9 +5,9 @@
...
@@ -5,9 +5,9 @@
#include "Array.hip.hpp"
#include "Array.hip.hpp"
#include "functional.hip.hpp"
#include "functional.hip.hpp"
__device__
unsigned
get_thread_local_1d_id
()
{
return
threadIdx
.
x
;
}
__device__
index_t
get_thread_local_1d_id
()
{
return
threadIdx
.
x
;
}
__device__
unsigned
get_block_1d_id
()
{
return
blockIdx
.
x
;
}
__device__
index_t
get_block_1d_id
()
{
return
blockIdx
.
x
;
}
template
<
class
T1
,
class
T2
>
template
<
class
T1
,
class
T2
>
struct
is_same
struct
is_same
...
@@ -35,7 +35,7 @@ __host__ __device__ constexpr T min(T a, T b)
...
@@ -35,7 +35,7 @@ __host__ __device__ constexpr T min(T a, T b)
}
}
#endif
#endif
__host__
__device__
constexpr
unsigned
integer_divide_ceil
(
unsigned
a
,
unsigned
b
)
__host__
__device__
constexpr
index_t
integer_divide_ceil
(
index_t
a
,
index_t
b
)
{
{
return
(
a
+
b
-
1
)
/
b
;
return
(
a
+
b
-
1
)
/
b
;
}
}
src/include/config.h.in
View file @
766b0a9e
...
@@ -11,3 +11,5 @@
...
@@ -11,3 +11,5 @@
#include "nvToolsExt.h"
#include "nvToolsExt.h"
#include "helper_cuda.h"
#include "helper_cuda.h"
#endif
#endif
using index_t = uint32_t;
src/include/constant_integral.hip.hpp
View file @
766b0a9e
...
@@ -8,5 +8,5 @@ struct integral_constant
...
@@ -8,5 +8,5 @@ struct integral_constant
__host__
__device__
constexpr
T
Get
()
const
{
return
value
;
}
__host__
__device__
constexpr
T
Get
()
const
{
return
value
;
}
};
};
template
<
unsigned
N
>
template
<
index_t
N
>
using
Number
=
integral_constant
<
unsigned
,
N
>
;
using
Number
=
integral_constant
<
index_t
,
N
>
;
src/include/data_type.hip.hpp
View file @
766b0a9e
#pragma once
#pragma once
#include "config.h"
#include "config.h"
template
<
class
T
,
unsigned
N
>
template
<
class
T
,
index_t
N
>
struct
vector_type
struct
vector_type
{
{
};
};
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment