yangql / composable_kernel-1 · Commits

Commit 5ce19234
authored Apr 19, 2019 by Chao Liu

added GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn

parent 19f17df4
Showing 7 changed files with 134 additions and 405 deletions

driver/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp    +1 -1
driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp    +18 -280
driver/driver.hip.cpp    +1 -1
src/include/blockwise_3d_tensor_op.hip.hpp    +1 -1
src/include/blockwise_4d_tensor_op.hip.hpp    +1 -22
src/include/gridwise_convolution_implicit_gemm_v1r2_nchw_cyxk_khwn.hip.hpp    +46 -88
src/include/threadwise_4d_tensor_op.hip.hpp    +66 -12
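For orientation only (not part of the commit): the file names encode the tensor layouts these kernels work in, with the input in NCHW order, the weights in CYXK order, and the output in KHWN order. The naive reference loop below, assuming unit stride and no padding, states the math that the implicit-GEMM kernels in this diff reorganize into blockwise GEMMs; it is a standalone sketch, not code from this repository.

    #include <vector>

    // Naive reference for the layouts in this commit's file names (NCHW input,
    // CYXK weights, KHWN output), assuming unit stride and no padding.
    // N batch, C input channels, K output channels, Y/X filter size, Ho/Wo output size.
    void convolution_reference(const std::vector<float>& in,  // [N][C][Hi][Wi]
                               const std::vector<float>& wei, // [C][Y][X][K]
                               std::vector<float>& out,       // [K][Ho][Wo][N]
                               int N, int C, int K, int Hi, int Wi, int Y, int X)
    {
        const int Ho = Hi - Y + 1;
        const int Wo = Wi - X + 1;

        for(int k = 0; k < K; ++k)
            for(int ho = 0; ho < Ho; ++ho)
                for(int wo = 0; wo < Wo; ++wo)
                    for(int n = 0; n < N; ++n)
                    {
                        float acc = 0;
                        for(int c = 0; c < C; ++c)
                            for(int y = 0; y < Y; ++y)
                                for(int x = 0; x < X; ++x)
                                    acc += wei[((c * Y + y) * X + x) * K + k] *
                                           in[((n * C + c) * Hi + (ho + y)) * Wi + (wo + x)];
                        out[((k * Ho + ho) * Wo + wo) * N + n] = acc;
                    }
    }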
driver/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp
...
@@ -111,7 +111,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t BlockSize = 128;
#elif 0
#elif 1
// for 3x3, 34x34, v1r2, Pascal, in-block-copy1
constexpr index_t NPerBlock = 4;
constexpr index_t KPerBlock = 64;
...
driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp
...
@@ -62,208 +62,10 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
out_khwn_device_buf.ToDevice(out_khwn.mData.data());
#if 0
#if 1
// for 3x3, 34x34, v1r1, Pascal
// for 3x3, 28x28, v1r2, Pascal
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopy_ThreadPerDimC = 4;
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t BlockSize = 128;
#elif 0
// for 3x3, 34x34, v1r2, Pascal, in-block-copy1
constexpr index_t NPerBlock = 4;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 8;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopy_ThreadPerDimC = 4;
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 1;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t BlockSize = 128;
#elif 0
// for 3x3, 34x34, v1r1, Vega 20
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t InBlockCopy_ThreadPerDimC = 4;
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 2;
constexpr index_t OutThreadCopyDataPerWrite = 4;
constexpr index_t BlockSize = 256;
#elif 0
// for 3x3, 56x56, v1, Pascal
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopy_ThreadPerDimC = 1;
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t BlockSize = 128;
#elif 0
// for 3x3, 56x56, v1r2, Pascal
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 1;
constexpr index_t GemmDataPerReadB = 1;
constexpr index_t InBlockCopy_ThreadPerDimC = 1;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t OutThreadCopyDataPerWrite = 4;
constexpr index_t BlockSize = 128;
#elif 0
// for 3x3, 28x28, v1r1, Pacal
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopy_ThreadPerDimC = 1;
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t BlockSize = 128;
#elif 1
// for 3x3, 28x28, v1r2, Pascal
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
...
@@ -275,14 +77,6 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopy_ThreadPerDimC = 8;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
...
@@ -293,73 +87,16 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t OutThreadCopyDataPerWrite = 2;
using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 2>;
using InBlockReorderSrcClusterLengths_NCHW = Sequence<4, 8, 2, 2>;
constexpr index_t BlockSize = 128;
using InBlockReorderMapThreadCluster2SrcCluster = Sequence<1, 2, 3, 0>;
#elif 0
constexpr index_t InBlockReorderDataPerRead_W = 2;
// for 1x1, 28x28
constexpr index_t InBlockReorderDataPerWrite_N = 4;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t InBlockCopy_ThreadPerDimC = 8;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t BlockSize = 128;
using WeiBlockCopyClusterLengths = Sequence<4, 1, 32>;
#elif 1
constexpr index_t WeiBlockCopyDataPerRead_C = 4;
// for 1x1, 14x14, Pascal
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t InBlockCopy_ThreadPerDimC = 8;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t BlockSize = 128;
#endif
constexpr index_t GridSize =
...
@@ -398,13 +135,14 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
        GemmKPerThreadLoop,
        GemmDataPerReadA,
        GemmDataPerReadB,
        Sequence<InBlockCopy_ThreadPerDimN,
        InBlockReorderSrcSubLengths_NCHW,
                 InBlockCopy_ThreadPerDimC,
        InBlockReorderSrcClusterLengths_NCHW,
                 InBlockCopy_ThreadPerDimH,
        InBlockReorderMapThreadCluster2SrcCluster,
                 InBlockCopy_ThreadPerDimW>,
        InBlockReorderDataPerRead_W,
        InBlockCopyDataPerRead,
        InBlockReorderDataPerWrite_N,
        WeiBlockCopyDataPerRead,
        WeiBlockCopyClusterLengths,
        OutThreadCopyDataPerWrite>{};
        WeiBlockCopyDataPerRead_C,
        OutThreadCopyDataPerWrite_N>{};

    float time = launch_kernel(run_gridwise_convolution<decltype(gridwise_conv), T>,
                               dim3(GridSize),
...
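One relation appears to hold for every tuning block in this file (inferred from the numbers above, not stated anywhere in the source): each thread owns an NPerThread x KPerThread x HoPerThread x WoPerThread sub-tile, so the thread count needed to cover the block tile is the product of the per-dimension ratios, and that product equals BlockSize in each configuration shown. A standalone sketch of that sanity check, using the "3x3, 28x28, v1r2, Pascal" numbers from this diff:

    // Hypothetical consistency check; the relation is inferred from the configs
    // in this commit, not documented by the source.
    int main()
    {
        constexpr int NPerBlock = 16, KPerBlock = 128, HoPerBlock = 2, WoPerBlock = 2;
        constexpr int NPerThread = 4, KPerThread = 8, HoPerThread = 1, WoPerThread = 2;
        constexpr int BlockSize = 128;

        constexpr int threads_needed = (NPerBlock / NPerThread) * (KPerBlock / KPerThread) *
                                       (HoPerBlock / HoPerThread) * (WoPerBlock / WoPerThread);

        static_assert(threads_needed == BlockSize,
                      "block tile is not exactly covered by the per-thread tiles");
        return 0;
    }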
driver/driver.hip.cpp
...
@@ -673,7 +673,7 @@ int main(int argc, char* argv[])
    device_direct_convolution_2_nchw_kcyx_nkhw
#elif 0
    device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
#elif 1
#elif 0
    device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
#elif 1
    device_convolution_implicit_gemm_v1_nchw_cyxk_khwn
...
src/include/blockwise_3d_tensor_op.hip.hpp
...
@@ -231,7 +231,7 @@ struct Blockwise3dTensorCopy3
        }
    }

    __device__ constexpr index_t GetRegisterClipboardSize() const
    __device__ static constexpr index_t GetRegisterClipboardSize()
    {
        static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
...
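The only change here turns GetRegisterClipboardSize() from a constexpr const member into a static constexpr member, presumably so its result can be used as an array bound even when the copy object itself is not a constant expression, which is how the gridwise kernel later in this commit sizes its per-thread clipboard buffers. A minimal standalone illustration with a hypothetical Copier type (not the library's class):

    // A static constexpr member can size a stack array even when it is
    // called through an ordinary (non-constexpr) object.
    struct Copier
    {
        static constexpr int GetRegisterClipboardSize() { return 8; }
    };

    int main()
    {
        Copier copier;
        float clipboard[copier.GetRegisterClipboardSize()] = {}; // constant expression, OK
        return static_cast<int>(clipboard[0]);
    }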
src/include/blockwise_4d_tensor_op.hip.hpp
...
@@ -761,7 +761,6 @@ struct Blockwise4dTensorCopyReorder1
    }
};

#if 1
template <index_t BlockSize,
          class Float,
          class SrcDesc,
...
@@ -1070,36 +1069,17 @@ struct Blockwise4dTensorCopyReorder3
                    }
#endif

#if 1
                    threadwise_4d_tensor_copy_reorder_given_dst2src_v2(
                    threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(
                        thread_tensor_desc,
                        p_clipboard + clipboard_offset,
                        DstDesc{},
                        p_dst + dst_offset + mDstMyThreadOffset,
                        thread_sub_tensor_lengths,
                        MapDst2Src{});
#endif
                }
            }
        }
    }

#if 0
        if(get_block_1d_id() == 0)
        {
            printf("tid %5u, "
                   "data: %f %f %f %f %f %f %f %f\n",
                   get_thread_local_1d_id(),
                   p_clipboard[0],
                   p_clipboard[1],
                   p_clipboard[2],
                   p_clipboard[3],
                   p_clipboard[4],
                   p_clipboard[5],
                   p_clipboard[6],
                   p_clipboard[7]);
        }
#endif
    }

    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
...
@@ -1110,4 +1090,3 @@ struct Blockwise4dTensorCopyReorder3
        RunStoreRegisterClipboard(p_clipboard, p_dst);
    }
};
#endif
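Blockwise4dTensorCopyReorder3, whose store path now goes through threadwise_4d_tensor_copy_reorder_given_dst2src_v2, splits a blockwise copy into RunLoadRegisterClipboard (global memory into a small per-thread buffer) and RunStoreRegisterClipboard (buffer into the destination), and the gridwise kernel below exercises exactly that pair in its #else branch. A standalone sketch of the two-stage pattern, with hypothetical names and a fixed buffer size standing in for GetRegisterClipboardSize():

    #include <cstddef>

    constexpr std::size_t kClipboardSize = 8; // stand-in for GetRegisterClipboardSize()

    // Stage data in a per-thread buffer; in the real kernel these are strided global reads.
    void run_load_register_clipboard(const float* src, float* clipboard)
    {
        for(std::size_t i = 0; i < kClipboardSize; ++i)
            clipboard[i] = src[i];
    }

    // Write the buffer out; in the real kernel these are (possibly reordered) writes into LDS.
    void run_store_register_clipboard(const float* clipboard, float* dst)
    {
        for(std::size_t i = 0; i < kClipboardSize; ++i)
            dst[i] = clipboard[i];
    }

    int main()
    {
        float src[kClipboardSize] = {1, 2, 3, 4, 5, 6, 7, 8};
        float dst[kClipboardSize] = {};
        float clipboard[kClipboardSize];

        run_load_register_clipboard(src, clipboard);
        run_store_register_clipboard(clipboard, dst);
        return static_cast<int>(dst[7]); // 8
    }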
src/include/gridwise_convolution_implicit_gemm_v1r2_nchw_cyxk_khwn.hip.hpp
...
@@ -33,10 +33,14 @@ template <index_t GridSize,
          index_t GemmKPerThreadLoop,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB,
          class InBlockCopyThreadPerDims,
          class InBlockReorderSrcSubLengths_NCHW,
          index_t InBlockCopyDataPerRead,
          class InBlockReorderSrcClusterLengths_NCHW,
          index_t WeiBlockCopyDataPerRead,
          class InBlockReorderMapThreadCluster2SrcCluster,
          index_t OutThreadCopyDataPerWrite>
          index_t InBlockReorderDataPerRead_W,
          index_t InBlockReorderDataPerWrite_N,
          class WeiBlockCopyClusterLengths_KXC,
          index_t WeiBlockCopyDataPerRead_C,
          index_t OutThreadCopyDataPerWrite_N>
struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
...
@@ -101,8 +105,10 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
        // LDS tensor view
        // be careful of alignment
        constexpr index_t max_align = mod_conv::max(InBlockReorderDataPerWrite_N,
                                                    InBlockCopyDataPerRead,
                                                    WeiBlockCopyDataPerRead,
                                                    GemmDataPerReadA, GemmDataPerReadB);
                                                    WeiBlockCopyDataPerRead_C,
                                                    GemmDataPerReadA, GemmDataPerReadB);

        constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, HoPerBlock, WiPerBlock, NPerBlock>{}, Number<max_align>{});
...
@@ -117,68 +123,38 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
        // blockwise copy
        // input: format is [N, C, Hi, Wi] to [C, Hi, Wi, N]
        auto map_chwn2nchw = Sequence<1, 2, 3, 0>{};
#if 0
        const auto blockwise_in_copy_reorder =
            Blockwise4dTensorCopyReorder1<BlockSize,
                                          Float,
                                          decltype(in_n_c_h_w_global_desc),
                                          decltype(in_c_h_w_n_block_desc),
                                          Sequence<NPerBlock, CPerBlock, HoPerBlock, WiPerBlock>,
                                          decltype(map_chwn2nchw)>{};
#else
        auto map_thread_cluster_2_src_cluster = Sequence<1, 2, 0, 3>{};

        const auto blockwise_in_copy_reorder =
            Blockwise4dTensorCopyReorder3<BlockSize,
                                          Float,
                                          decltype(in_n_c_h_w_global_desc),
                                          decltype(in_c_h_w_n_block_desc),
                                          Sequence<NPerBlock, CPerBlock, HoPerBlock, WiPerBlock>,
                                          Sequence<4, 1, 1, 2>,
                                          InBlockReorderSrcSubLengths_NCHW,
                                          Sequence<4, 8, 2, 2>,
                                          InBlockReorderSrcClusterLengths_NCHW,
                                          decltype(map_chwn2nchw),
                                          decltype(map_thread_cluster_2_src_cluster),
                                          InBlockReorderMapThreadCluster2SrcCluster,
                                          2,
                                          InBlockReorderDataPerRead_W,
                                          4>{};
                                          InBlockReorderDataPerWrite_N>{};
#if 0
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
        {
            printf("size %u\n", blockwise_in_copy_reorder.GetRegisterClipboardSize());
        }
#endif
#endif
        // blockwise wei copy
        // format is [CPerBlock, X * KPerBlock]
        const auto blockwise_wei_copy =
#if 0
            Blockwise3dTensorCopy1<BlockSize,
                                   Float,
                                   decltype(wei_c_x_k_global_desc),
                                   decltype(wei_c_x_k_block_desc),
                                   decltype(wei_c_x_k_block_desc.GetLengths()),
                                   WeiBlockCopyDataPerRead>{};
#else
            Blockwise3dTensorCopy3<BlockSize,
                                   Float,
                                   decltype(wei_c_x_k_global_desc),
                                   decltype(wei_c_x_k_block_desc),
                                   decltype(wei_c_x_k_block_desc.GetLengths()),
                                   Sequence<4, 1, 32>,
                                   WeiBlockCopyDataPerRead>{};
                                   WeiBlockCopyDataPerRead_C>{};
#endif
        // a series of blockwise batched GEMM
        // C_matrix += transpose(A_matrix) * B_matrix
        // A_matrix and B_matrix saved in LDS, C_matrix saved in register
        //   A_matrix[C,K] is a sub-matrix of wei_block[C,K]
        //   B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
        //   C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
        constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
            Number<CPerBlock>{},
            Number<KPerBlock>{},
            Number<wei_c_x_k_block_desc.GetStride(I0)>{});

        constexpr auto b_c_wn_block_mtx_desc =
            make_ConstantMatrixDescriptor(Number<CPerBlock>{},
...
@@ -252,6 +228,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
        {
            for(index_t y = 0; y < Y; ++y)
            {
#if 1
                blockwise_in_copy_reorder.Run(p_in_global_block_offset +
                                                  in_n_c_h_w_global_desc.Get1dIndex(0, 0, y, 0),
                                              p_in_block);
...
@@ -259,6 +236,23 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
                blockwise_wei_copy.Run(p_wei_global_block_offset +
                                           wei_c_y_x_k_global_desc.Get1dIndex(0, y, 0, 0),
                                       p_wei_block);
#else
                Float p_in_clipboard[blockwise_in_copy_reorder.GetRegisterClipboardSize()];
                Float p_wei_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

                blockwise_in_copy_reorder.RunLoadRegisterClipboard(
                    p_in_global_block_offset + in_n_c_h_w_global_desc.Get1dIndex(0, 0, y, 0),
                    p_in_clipboard);

                blockwise_wei_copy.RunLoadRegisterClipboard(
                    p_wei_global_block_offset + wei_c_y_x_k_global_desc.Get1dIndex(0, y, 0, 0),
                    p_wei_clipboard);

                blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_clipboard, p_wei_block);
                blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_clipboard, p_in_block);
#endif
                __syncthreads();
...
@@ -274,42 +268,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
            }
        }

        // output: register to global mem,
#if 0
        const auto c_thread_mtx_begin =
            blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        for(index_t k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k)
        {
            for(index_t ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho)
            {
                for(index_t wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo)
                {
                    for(index_t n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n)
                    {
                        const index_t b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n);

                        const auto c_thread_mtx_distance =
                            blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b);

                        const index_t ho_thread =
                            c_thread_mtx_begin.batch + c_thread_mtx_distance.batch;
                        const index_t k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row;
                        const index_t b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col;

                        const index_t wo_thread = b_thread / NPerBlock;
                        const index_t n_thread = b_thread % NPerBlock;

                        p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread,
                                                                     ho_block_data_begin + ho_thread,
                                                                     wo_block_data_begin + wo_thread,
                                                                     n_block_data_begin + n_thread)] =
                            p_out_thread[out_khwn_thread_desc.Get1dIndex(k, ho, wo, n)];
                    }
                }
            }
        }
#elif 1
        const auto c_thread_mtx_begin =
            blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
...
@@ -356,7 +315,6 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
                                  wo_block_data_begin + wo_thread_data_begin,
                                  n_block_data_begin + n_thread_data_begin),
            out_10d_thread_desc.GetLengths(),
            Number<OutThreadCopyDataPerWrite>{});
            Number<OutThreadCopyDataPerWrite_N>{});
#endif
    }
};
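Reading the new template parameters together with the driver values in this commit, the in-block reorder copy appears to decompose the [C, Ho, Wi, N] block tile as per-thread sub-lengths (InBlockReorderSrcSubLengths_NCHW) times a thread cluster (InBlockReorderSrcClusterLengths_NCHW), with the cluster accounting for every thread in the block. This is an inference from the numbers, not something the source states. A standalone check using the Sequence<4, 1, 1, 2> and Sequence<4, 8, 2, 2> values from the driver:

    // Both arrays are in NCHW order, matching the _NCHW suffix of the parameters.
    int main()
    {
        constexpr int sub[4]     = {4, 1, 1, 2}; // InBlockReorderSrcSubLengths_NCHW
        constexpr int cluster[4] = {4, 8, 2, 2}; // InBlockReorderSrcClusterLengths_NCHW

        constexpr int tile_n = sub[0] * cluster[0]; // 16 == NPerBlock
        constexpr int tile_c = sub[1] * cluster[1]; //  8 == CPerBlock
        constexpr int tile_h = sub[2] * cluster[2]; //  2 == HoPerBlock
        constexpr int tile_w = sub[3] * cluster[3]; //  4 == WiPerBlock, presumably WoPerBlock + X - 1 for X = 3

        constexpr int threads = cluster[0] * cluster[1] * cluster[2] * cluster[3]; // 128 == BlockSize

        static_assert(tile_n == 16 && tile_c == 8 && tile_h == 2 && tile_w == 4,
                      "unexpected block tile");
        static_assert(threads == 128, "cluster does not cover the thread block");
        return 0;
    }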
src/include/threadwise_4d_tensor_op.hip.hpp
...
@@ -44,7 +44,7 @@ template <class SrcData,
          class SrcOpLengths,
          class MapDst2Src,
          class F>
__device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
__device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_given_dst2src(
    SrcDesc,
    const SrcData* __restrict__ p_src,
    DstDesc,
...
@@ -82,9 +82,9 @@ __device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_d
        const index_t bindex =
            dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);

#if 1
        f(p_src[aindex], p_dst[bindex]);
#else
#if 0
        if(get_block_1d_id() == 0)
        {
            printf("tid %5u, "
...
@@ -126,17 +126,16 @@ template <class SrcData,
          class DstDesc,
          class SrcOpLengths,
          class MapDst2Src>
__device__ void
threadwise_4d_tensor_copy_reorder_given_dst2src(SrcDesc,
threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
                                                const SrcData* __restrict__ p_src,
                                                DstDesc,
                                                DstData* __restrict__ p_dst,
                                                SrcOpLengths,
                                                MapDst2Src)
{
    auto f_copy = [](const SrcData& src, DstData& dst) { dst = static_cast<DstData>(src); };

    threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
    threadwise_4d_tensor_pointwise_operation_binary_reorder_given_dst2src(
        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
}
...
@@ -146,7 +145,7 @@ __device__ void threadwise_4d_tensor_copy(
{
    auto dst_from_src_reorder = Sequence<0, 1, 2, 3>{};

    threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(
    threadwise_4d_tensor_copy_reorder_given_dst2src(
        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, dst_from_src_reorder);
}
...
@@ -212,6 +211,61 @@ __device__ void threadwise_4d_tensor_copy_v2(SrcDesc,
    }
}
template <class SrcData,
          class DstData,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class MapDst2Src>
__device__ void threadwise_4d_tensor_copy_reorder_given_dst2src_v2(SrcDesc,
                                                                   const SrcData* __restrict__ p_src,
                                                                   DstDesc,
                                                                   DstData* __restrict__ p_dst,
                                                                   SrcOpLengths,
                                                                   MapDst2Src)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr index_t IR0 = MapDst2Src{}.Get(I0);
    constexpr index_t IR1 = MapDst2Src{}.Get(I1);
    constexpr index_t IR2 = MapDst2Src{}.Get(I2);
    constexpr index_t IR3 = MapDst2Src{}.Get(I3);

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};

    // ref_desc has dst_desc's ordering
    constexpr auto ref_desc =
        make_ConstantTensorDescriptor(SrcOpLengths{}.ReorderGivenNew2Old(MapDst2Src{}));

    for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
    {
        for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
        {
            for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
            {
                for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
                {
                    const auto dst_multi_id = Array<index_t, 4>{did0, did1, did2, did3};

                    const auto src_multi_id =
                        reorder_array_given_old2new(dst_multi_id, MapDst2Src{});

                    const index_t dst_index = dst_desc.Get1dIndex(dst_multi_id);
                    const index_t src_index = src_desc.Get1dIndex(src_multi_id);

                    p_dst[dst_index] = p_src[src_index];
                }
            }
        }
    }
}
template <class Float, class Desc, class IDim, class NShift>
__device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDim, NShift)
{
...
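The new threadwise_4d_tensor_copy_reorder_given_dst2src_v2 above walks the destination index space and, for each destination multi-index, reads the source element obtained by applying MapDst2Src to the dimensions; with map Sequence<1, 2, 3, 0> this is how the gridwise kernel reorders NCHW input into its CHWN block layout. A standalone sketch of the same index mapping with plain arrays and hypothetical sizes, under my reading that the map gives, for each destination dimension, the source dimension it comes from:

    #include <cassert>

    int main()
    {
        constexpr int N = 2, C = 3, H = 4, W = 5;

        float src[N][C][H][W]; // NCHW source
        float dst[C][H][W][N]; // CHWN destination

        // Fill the source with distinct values.
        for(int n = 0; n < N; ++n)
            for(int c = 0; c < C; ++c)
                for(int h = 0; h < H; ++h)
                    for(int w = 0; w < W; ++w)
                        src[n][c][h][w] = static_cast<float>(((n * C + c) * H + h) * W + w);

        // MapDst2Src = {1, 2, 3, 0}: destination dimension i is read from source dimension map[i].
        for(int c = 0; c < C; ++c)
            for(int h = 0; h < H; ++h)
                for(int w = 0; w < W; ++w)
                    for(int n = 0; n < N; ++n)
                        dst[c][h][w][n] = src[n][c][h][w];

        assert(dst[2][3][4][1] == src[1][2][3][4]);
        return 0;
    }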