Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
8bd6ea1a
Commit
8bd6ea1a
authored
Jan 24, 2019
by
Chao Liu
Browse files
improve implicit gemm NCHW, SRCK, NKHW, and tuned
parent
1de6fd07
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
59 additions
and
21 deletions
+59
-21
driver/conv.cu
driver/conv.cu
+11
-11
driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh
driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh
+17
-2
driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh
driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh
+23
-5
src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh
...e/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh
+8
-3
No files found.
driver/conv.cu
View file @
8bd6ea1a
...
@@ -361,7 +361,7 @@ int main()
...
@@ -361,7 +361,7 @@ int main()
constexpr unsigned K = 1;
constexpr unsigned K = 1;
constexpr unsigned S = 3;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
constexpr unsigned R = 3;
#elif
1
#elif
0
// 3x3, 34x34
// 3x3, 34x34
constexpr
unsigned
N
=
64
;
constexpr
unsigned
N
=
64
;
constexpr
unsigned
C
=
256
;
constexpr
unsigned
C
=
256
;
...
@@ -370,15 +370,6 @@ int main()
...
@@ -370,15 +370,6 @@ int main()
constexpr
unsigned
K
=
64
;
constexpr
unsigned
K
=
64
;
constexpr
unsigned
S
=
3
;
constexpr
unsigned
S
=
3
;
constexpr
unsigned
R
=
3
;
constexpr
unsigned
R
=
3
;
#elif 0
// 3x3, 54x54
constexpr
unsigned
N
=
64
;
constexpr
unsigned
C
=
64
;
constexpr
unsigned
HI
=
54
;
constexpr
unsigned
WI
=
54
;
constexpr
unsigned
K
=
64
;
constexpr
unsigned
S
=
3
;
constexpr
unsigned
R
=
3
;
#elif 0
#elif 0
// 3x3, 56x56
// 3x3, 56x56
constexpr
unsigned
N
=
64
;
constexpr
unsigned
N
=
64
;
...
@@ -415,6 +406,15 @@ int main()
...
@@ -415,6 +406,15 @@ int main()
constexpr
unsigned
K
=
64
;
constexpr
unsigned
K
=
64
;
constexpr
unsigned
S
=
7
;
constexpr
unsigned
S
=
7
;
constexpr
unsigned
R
=
7
;
constexpr
unsigned
R
=
7
;
#elif 1
// 3x3, 58x58
constexpr
unsigned
N
=
16
;
constexpr
unsigned
C
=
128
;
constexpr
unsigned
HI
=
58
;
constexpr
unsigned
WI
=
58
;
constexpr
unsigned
K
=
256
;
constexpr
unsigned
S
=
3
;
constexpr
unsigned
R
=
3
;
#endif
#endif
auto
in_nchw_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
N
,
C
,
HI
,
WI
>
{});
auto
in_nchw_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
N
,
C
,
HI
,
WI
>
{});
...
@@ -449,7 +449,7 @@ int main()
...
@@ -449,7 +449,7 @@ int main()
device_direct_convolution_2
device_direct_convolution_2
#elif 0
#elif 0
device_implicit_gemm_convolution_1_nchw_kcsr
device_implicit_gemm_convolution_1_nchw_kcsr
#elif
1
#elif
0
device_implicit_gemm_convolution_1_nchw_srck_nkhw
device_implicit_gemm_convolution_1_nchw_srck_nkhw
#elif 1
#elif 1
device_implicit_gemm_convolution_1_chwn_csrk_khwn
device_implicit_gemm_convolution_1_chwn_csrk_khwn
...
...
driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh
View file @
8bd6ea1a
...
@@ -87,8 +87,8 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
...
@@ -87,8 +87,8 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
constexpr unsigned WoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 8;
constexpr unsigned BlockSize = 8;
#elif
1
#elif
0
// for 3x3, 34x34 | 3x3 58x58
// for 3x3, 34x34 | 3x3 58x58
, NKC = 64, 64, 256
constexpr
unsigned
NPerBlock
=
16
;
constexpr
unsigned
NPerBlock
=
16
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
unsigned
CPerBlock
=
4
;
constexpr
unsigned
CPerBlock
=
4
;
...
@@ -101,6 +101,21 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
...
@@ -101,6 +101,21 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
constexpr
unsigned
HoPerThread
=
1
;
constexpr
unsigned
HoPerThread
=
1
;
constexpr
unsigned
WoPerThread
=
1
;
constexpr
unsigned
WoPerThread
=
1
;
constexpr
unsigned
BlockSize
=
128
;
#elif 1
// 3x3 58x58, NKC = 16,256,128
constexpr
unsigned
NPerBlock
=
8
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
unsigned
CPerBlock
=
2
;
constexpr
unsigned
HoPerBlock
=
4
;
constexpr
unsigned
WoPerBlock
=
4
;
constexpr
unsigned
NPerThread
=
4
;
constexpr
unsigned
KPerThread
=
16
;
constexpr
unsigned
CPerThread
=
1
;
constexpr
unsigned
HoPerThread
=
1
;
constexpr
unsigned
WoPerThread
=
1
;
constexpr
unsigned
BlockSize
=
128
;
constexpr
unsigned
BlockSize
=
128
;
#elif 0
#elif 0
// for 5x5, 36x36
// for 5x5, 36x36
...
...
driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh
View file @
8bd6ea1a
...
@@ -65,7 +65,7 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
...
@@ -65,7 +65,7 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
constexpr unsigned WoPerThread = 2;
constexpr unsigned WoPerThread = 2;
constexpr unsigned BlockSize = 16;
constexpr unsigned BlockSize = 16;
#elif
1
#elif
0
// for 3x3, 34x34
// for 3x3, 34x34
constexpr
unsigned
NPerBlock
=
1
;
constexpr
unsigned
NPerBlock
=
1
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
unsigned
KPerBlock
=
64
;
...
@@ -73,6 +73,7 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
...
@@ -73,6 +73,7 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
constexpr
unsigned
HoPerBlock
=
4
;
constexpr
unsigned
HoPerBlock
=
4
;
constexpr
unsigned
WoPerBlock
=
32
;
constexpr
unsigned
WoPerBlock
=
32
;
constexpr
unsigned
NPerThread
=
1
;
constexpr
unsigned
KPerThread
=
16
;
constexpr
unsigned
KPerThread
=
16
;
constexpr
unsigned
CPerThread
=
1
;
constexpr
unsigned
CPerThread
=
1
;
constexpr
unsigned
HoPerThread
=
2
;
constexpr
unsigned
HoPerThread
=
2
;
...
@@ -80,16 +81,32 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
...
@@ -80,16 +81,32 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
constexpr
unsigned
BlockSize
=
128
;
constexpr
unsigned
BlockSize
=
128
;
#elif 0
#elif 0
// for 3x3,
34x34
// for 3x3,
58x58
constexpr
unsigned
NPerBlock
=
2
;
constexpr
unsigned
NPerBlock
=
4
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
unsigned
CPerBlock
=
2
;
constexpr
unsigned
CPerBlock
=
2
;
constexpr
unsigned
HoPerBlock
=
4
;
constexpr
unsigned
WoPerBlock
=
8
;
constexpr
unsigned
NPerThread
=
4
;
constexpr
unsigned
KPerThread
=
16
;
constexpr
unsigned
CPerThread
=
1
;
constexpr
unsigned
HoPerThread
=
1
;
constexpr
unsigned
WoPerThread
=
1
;
constexpr
unsigned
BlockSize
=
128
;
#elif 1
// for 3x3, 56x56
constexpr
unsigned
NPerBlock
=
32
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
unsigned
CPerBlock
=
4
;
constexpr
unsigned
HoPerBlock
=
2
;
constexpr
unsigned
HoPerBlock
=
2
;
constexpr
unsigned
WoPerBlock
=
3
2
;
constexpr
unsigned
WoPerBlock
=
2
;
constexpr
unsigned
NPerThread
=
4
;
constexpr
unsigned
KPerThread
=
16
;
constexpr
unsigned
KPerThread
=
16
;
constexpr
unsigned
CPerThread
=
1
;
constexpr
unsigned
CPerThread
=
1
;
constexpr
unsigned
HoPerThread
=
2
;
constexpr
unsigned
HoPerThread
=
1
;
constexpr
unsigned
WoPerThread
=
1
;
constexpr
unsigned
WoPerThread
=
1
;
constexpr
unsigned
BlockSize
=
128
;
constexpr
unsigned
BlockSize
=
128
;
...
@@ -123,6 +140,7 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
...
@@ -123,6 +140,7 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
CPerBlock
,
CPerBlock
,
HoPerBlock
,
HoPerBlock
,
WoPerBlock
,
WoPerBlock
,
NPerThread
,
KPerThread
,
KPerThread
,
CPerThread
,
CPerThread
,
HoPerThread
,
HoPerThread
,
...
...
src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh
View file @
8bd6ea1a
...
@@ -17,6 +17,7 @@ template <unsigned GridSize,
...
@@ -17,6 +17,7 @@ template <unsigned GridSize,
unsigned
CPerBlock
,
unsigned
CPerBlock
,
unsigned
HoPerBlock
,
unsigned
HoPerBlock
,
unsigned
WoPerBlock
,
unsigned
WoPerBlock
,
unsigned
NPerThread
,
unsigned
KPerThread
,
unsigned
KPerThread
,
unsigned
CPerThread
,
unsigned
CPerThread
,
unsigned
HoPerThread
,
unsigned
HoPerThread
,
...
@@ -32,7 +33,9 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
...
@@ -32,7 +33,9 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
// NPerThread == NPerBlock, because the format of input in LDS [C,Hi,Wi,N]
// NPerThread == NPerBlock, because the format of input in LDS [C,Hi,Wi,N]
// for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N"
// for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N"
// if we use [C,Hi,N,Wi,N] in LDS, then NPerThread can be different from NPerBlock
// if we use [C,Hi,N,Wi,N] in LDS, then NPerThread can be different from NPerBlock
constexpr
unsigned
NPerThread
=
NPerBlock
;
static_assert
(
NPerBlock
%
NPerThread
==
0
,
"wrong! NPerBlock % NPerThread !=0"
);
static_assert
((
NPerThread
<
NPerBlock
&&
WoPerThread
==
1
)
||
NPerThread
==
NPerBlock
,
"wrong!"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
@@ -207,7 +210,9 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
...
@@ -207,7 +210,9 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
const
unsigned
ho_thread_data_begin
=
matrix_c_index
.
batch_begin
;
const
unsigned
ho_thread_data_begin
=
matrix_c_index
.
batch_begin
;
const
unsigned
k_thread_data_begin
=
matrix_c_index
.
row_begin
;
const
unsigned
k_thread_data_begin
=
matrix_c_index
.
row_begin
;
const
unsigned
wo_thread_data_begin
=
matrix_c_index
.
col_begin
/
NPerThread
;
const
unsigned
wo_thread_data_begin
=
matrix_c_index
.
col_begin
/
NPerBlock
;
const
unsigned
n_thread_data_begin
=
matrix_c_index
.
col_begin
-
wo_thread_data_begin
*
NPerBlock
;
// output: register to global mem,
// output: register to global mem,
// convert out_thread[Ho,K,Wo,N] to out_global[N,K,Ho,Wo]
// convert out_thread[Ho,K,Wo,N] to out_global[N,K,Ho,Wo]
...
@@ -217,7 +222,7 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
...
@@ -217,7 +222,7 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
out_hkwn_thread_desc
,
out_hkwn_thread_desc
,
p_out_thread
,
p_out_thread
,
out_nkhw_global_desc
,
out_nkhw_global_desc
,
p_out_global
+
out_nkhw_global_desc
.
Get1dIndex
(
n_block_data_begin
,
p_out_global
+
out_nkhw_global_desc
.
Get1dIndex
(
n_block_data_begin
+
n_thread_data_begin
,
k_block_data_begin
+
k_thread_data_begin
,
k_block_data_begin
+
k_thread_data_begin
,
ho_block_data_begin
+
ho_thread_data_begin
,
ho_block_data_begin
+
ho_thread_data_begin
,
wo_block_data_begin
+
wo_thread_data_begin
),
wo_block_data_begin
+
wo_thread_data_begin
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment