Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
aa0199a3
Commit
aa0199a3
authored
Jan 14, 2019
by
Chao Liu
Browse files
adding implicit gemm
parent
dc60d169
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
70 additions
and
109 deletions
+70
-109
driver/conv.cu
driver/conv.cu
+3
-3
driver/device_implicit_gemm_convolution.cuh
driver/device_implicit_gemm_convolution.cuh
+26
-56
src/include/gridwise_implicit_gemm_convolution.cuh
src/include/gridwise_implicit_gemm_convolution.cuh
+41
-50
No files found.
driver/conv.cu
View file @
aa0199a3
...
@@ -8,7 +8,7 @@
...
@@ -8,7 +8,7 @@
#include "conv_common.cuh"
#include "conv_common.cuh"
#include "device_direct_convolution_1.cuh"
#include "device_direct_convolution_1.cuh"
#include "device_direct_convolution_2.cuh"
#include "device_direct_convolution_2.cuh"
//
#include "device_implicit_gemm_convolution.cuh"
#include "device_implicit_gemm_convolution.cuh"
//#include "device_winograd_convolution.cuh"
//#include "device_winograd_convolution.cuh"
struct
GeneratorTensor_1
struct
GeneratorTensor_1
...
@@ -393,9 +393,9 @@ int main()
...
@@ -393,9 +393,9 @@ int main()
{
{
#if 0
#if 0
device_direct_convolution_1(in_desc, in, wei_desc, wei, out_desc, out_device);
device_direct_convolution_1(in_desc, in, wei_desc, wei, out_desc, out_device);
#elif
1
device_direct_convolution_2
(
in_desc
,
in
,
wei_desc
,
wei
,
out_desc
,
out_device
);
#elif
0
#elif
0
device_direct_convolution_2
(
in_desc
,
in
,
wei_desc
,
wei
,
out_desc
,
out_device
);
#elif 1
device_implicit_gemm_convolution
(
in_desc
,
in
,
wei_desc
,
wei
,
out_desc
,
out_device
);
device_implicit_gemm_convolution
(
in_desc
,
in
,
wei_desc
,
wei
,
out_desc
,
out_device
);
#elif 0
#elif 0
device_winograd_convolution
(
in_desc
,
in
,
wei_desc
,
wei
,
out_desc
,
out_device
);
device_winograd_convolution
(
in_desc
,
in
,
wei_desc
,
wei
,
out_desc
,
out_device
);
...
...
driver/device_implicit_gemm_convolution.cuh
View file @
aa0199a3
...
@@ -26,53 +26,24 @@ void device_implicit_gemm_convolution(
...
@@ -26,53 +26,24 @@ void device_implicit_gemm_convolution(
constexpr
auto
out_desc
=
OutDesc
{};
constexpr
auto
out_desc
=
OutDesc
{};
#if 1
#if 1
constexpr
unsigned
OutTileSizeH
=
2
;
constexpr
unsigned
NPerBlock
=
2
;
constexpr
unsigned
OutTileSizeW
=
2
;
constexpr
unsigned
KPerBlock
=
128
;
constexpr
unsigned
NPerBlock
=
2
;
constexpr
unsigned
CPerBlock
=
4
;
constexpr
unsigned
KPerBlock
=
32
;
constexpr
unsigned
HoPerBlock
=
2
;
constexpr
unsigned
CPerBlock
=
4
;
constexpr
unsigned
WoPerBlock
=
32
;
constexpr
unsigned
YPerBlock
=
1
;
constexpr
unsigned
XPerBlock
=
16
;
constexpr
unsigned
NPerThread
=
2
;
constexpr
unsigned
NPerThread
=
2
;
constexpr
unsigned
KPerThread
=
4
;
constexpr
unsigned
KPerThread
=
8
;
constexpr
unsigned
CPerThread
=
2
;
constexpr
unsigned
BlockSize
=
128
;
#elif 0
constexpr
unsigned
OutTileSizeH
=
2
;
constexpr
unsigned
OutTileSizeW
=
2
;
constexpr
unsigned
NPerBlock
=
2
;
constexpr
unsigned
KPerBlock
=
32
;
constexpr
unsigned
CPerBlock
=
4
;
constexpr
unsigned
YPerBlock
=
1
;
constexpr
unsigned
XPerBlock
=
27
;
constexpr
unsigned
NPerThread
=
2
;
constexpr
unsigned
KPerThread
=
4
;
constexpr
unsigned
CPerThread
=
2
;
constexpr
unsigned
BlockSize
=
216
;
#elif 0
constexpr
unsigned
OutTileSizeH
=
2
;
constexpr
unsigned
OutTileSizeW
=
2
;
constexpr
unsigned
NPerBlock
=
2
;
constexpr
unsigned
KPerBlock
=
32
;
constexpr
unsigned
CPerBlock
=
4
;
constexpr
unsigned
YPerBlock
=
1
;
constexpr
unsigned
XPerBlock
=
32
;
constexpr
unsigned
NPerThread
=
2
;
constexpr
unsigned
KPerThread
=
4
;
constexpr
unsigned
CPerThread
=
2
;
constexpr
unsigned
CPerThread
=
2
;
constexpr
unsigned
HoPerThread
=
1
;
constexpr
unsigned
WoPerThread
=
4
;
constexpr
unsigned
BlockSize
=
256
;
constexpr
unsigned
BlockSize
=
256
;
#endif
#endif
constexpr
unsigned
GridSize
=
(
out_desc
.
GetLength
(
I0
)
/
NPerBlock
)
*
constexpr
unsigned
GridSize
=
(
out_desc
.
GetLength
(
I1
)
/
KPerBlock
)
*
(
out_desc
.
GetLength
(
I0
)
/
NPerBlock
)
*
(
out_desc
.
GetLength
(
I1
)
/
KPerBlock
)
*
(
out_desc
.
GetLength
(
I2
)
/
(
OutTileSizeH
*
YPerBlock
))
*
(
out_desc
.
GetLength
(
I2
)
/
HoPerBlock
)
*
(
out_desc
.
GetLength
(
I3
)
/
WoPerBlock
);
(
out_desc
.
GetLength
(
I3
)
/
(
OutTileSizeW
*
XPerBlock
));
dim3
block_dim
(
BlockSize
);
dim3
block_dim
(
BlockSize
);
dim3
grid_dim
(
GridSize
);
dim3
grid_dim
(
GridSize
);
...
@@ -85,22 +56,21 @@ void device_implicit_gemm_convolution(
...
@@ -85,22 +56,21 @@ void device_implicit_gemm_convolution(
cudaEventCreate
(
&
start
);
cudaEventCreate
(
&
start
);
cudaEventRecord
(
start
,
0
);
cudaEventRecord
(
start
,
0
);
gridwise_implicit_gemm_convolution
<
T
,
gridwise_implicit_gemm_convolution_nchw_kcsr
<
GridSize
,
InDesc
,
BlockSize
,
WeiDesc
,
T
,
OutDesc
,
InDesc
,
OutTileSizeH
,
WeiDesc
,
OutTileSizeW
,
OutDesc
,
NPerBlock
,
NPerBlock
,
KPerBlock
,
KPerBlock
,
CPerBlock
,
CPerBlock
,
YPerBlock
,
HoPerBlock
,
XPerBlock
,
WoPerBlock
,
NPerThread
,
KPerThread
,
KPerThread
,
CPerThread
,
CPerThread
,
HoPerThread
,
BlockSize
,
WoPerThread
>
GridSize
>
<<<
grid_dim
,
block_dim
>>>
(
InDesc
{},
<<<
grid_dim
,
block_dim
>>>
(
InDesc
{},
static_cast
<
T
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
in_device_buf
.
GetDeviceBuffer
()),
WeiDesc
{},
WeiDesc
{},
...
...
src/include/gridwise_implicit_gemm_convolution.cuh
View file @
aa0199a3
...
@@ -35,9 +35,6 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
...
@@ -35,9 +35,6 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
True
=
Constant
<
bool
,
true
>
;
constexpr
auto
False
=
Constant
<
bool
,
false
>
;
constexpr
auto
in_nchw_global_desc
=
InGlobalDesc
{};
constexpr
auto
in_nchw_global_desc
=
InGlobalDesc
{};
constexpr
auto
wei_kcsr_global_desc
=
WeiGlobalDesc
{};
constexpr
auto
wei_kcsr_global_desc
=
WeiGlobalDesc
{};
constexpr
auto
out_nkhw_global_desc
=
OutGlobalDesc
{};
constexpr
auto
out_nkhw_global_desc
=
OutGlobalDesc
{};
...
@@ -48,13 +45,20 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
...
@@ -48,13 +45,20 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
constexpr
unsigned
HiPerBlock
=
HoPerBlock
+
S
-
1
;
constexpr
unsigned
HiPerBlock
=
HoPerBlock
+
S
-
1
;
constexpr
unsigned
WiPerBlock
=
WoPerBlock
+
R
-
1
;
constexpr
unsigned
WiPerBlock
=
WoPerBlock
+
R
-
1
;
//
block
//
tensor view of blockwise input and weight in LDS
constexpr
auto
in_chwn_block_desc
=
constexpr
auto
in_chwn_block_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
CPerBlock
,
HiPerBlock
,
WiPerBlock
,
NPerBlock
>
{});
make_ConstantTensorDescriptor
(
Sequence
<
CPerBlock
,
HiPerBlock
,
WiPerBlock
,
NPerBlock
>
{});
constexpr
auto
wei_srck_block_desc
=
constexpr
auto
wei_srck_block_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
S
,
R
,
CPerBlock
,
KPerBlock
>
{});
make_ConstantTensorDescriptor
(
Sequence
<
S
,
R
,
CPerBlock
,
KPerBlock
>
{});
// matrix view of blockwise input and weight in LDS
constexpr
auto
in_cxhwn_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
CPerBlock
>
,
Number
<
HiPerBlock
*
WiPerBlock
*
NPerBlock
>
);
constexpr
auto
wei_srcxk_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
S
*
R
*
CPerBlock
>
,
Number
<
KPerBlock
>
);
// LDS
// LDS
constexpr
unsigned
in_block_size
=
in_chwn_block_desc
.
GetElementSpace
();
constexpr
unsigned
in_block_size
=
in_chwn_block_desc
.
GetElementSpace
();
constexpr
unsigned
wei_block_size
=
wei_srck_block_desc
.
GetElementSpace
();
constexpr
unsigned
wei_block_size
=
wei_srck_block_desc
.
GetElementSpace
();
...
@@ -62,8 +66,38 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
...
@@ -62,8 +66,38 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
__shared__
Float
p_in_block
[
in_block_size
];
__shared__
Float
p_in_block
[
in_block_size
];
__shared__
Float
p_wei_block
[
wei_block_size
];
__shared__
Float
p_wei_block
[
wei_block_size
];
// thread
// a series of batched GEMM
constexpr
auto
out_hkwn_thread_desc
=
xxxxxx
();
// blockwise batched GEMM, C_matrix += transpose(A_matrix) * B_matrix
// A_matrix and B_matrix saved in LDS, c_matrix saved in register
// A_matrix[C,K] is a sub-matrix of wei_matrix[S*R*C,K]
// B_matrix[C,Wo*N] is a sub-matrix of in_matrix[C,Hi*Wi*N]
// C_matrix[K,Wo*N] is a sub-matrix of out_matrix[Ho*K,Wo*N]
constexpr
auto
a_block_mtx_desc
=
wei_srcxk_block_mtx_desc
.
MakeSubMatrixDescriptor
(
Number
<
CPerBlock
>
{},
Number
<
KPerBlock
>
{});
constexpr
auto
b_block_mtx_desc
=
in_cxhwn_block_mtx_desc
.
MakeSubMatrixDescriptor
(
Number
<
CPerBlock
>
{},
Number
<
WoPerBlock
*
NPerBlock
>
{});
auto
f_accum
=
(
auto
&
c
,
auto
&
v
)
{
c
+=
v
;
};
const
auto
blockwise_batch_gemm
=
blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c
<
BlockSize
,
a_block_mtx_desc
,
b_block_mtx_desc
,
true
,
false
,
HoPerBlock
,
0
,
xxx_b_matrix_stride
,
HoPerThread
,
KPerThread
,
NPerThread
*
WoPerThread
,
CPerTrhead
,
decltype
(
f_accum
)
>
{};
// tensor view of threadwise output in register
constexpr
auto
out_hkwn_thread_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
HoPerThread
,
KPerThread
,
WoPerThread
,
NPerThread
>
{});
// register
// register
Float
p_out_thread
[
out_hkwn_thread_desc
.
GetElementSpace
()];
Float
p_out_thread
[
out_hkwn_thread_desc
.
GetElementSpace
()];
...
@@ -85,14 +119,6 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
...
@@ -85,14 +119,6 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
in_chwn_block_desc
,
in_chwn_block_desc
,
reorder_nchw2chwn
);
reorder_nchw2chwn
);
// matrix view of input
constexpr
unsigned
in_row
=
in_chwn_block_desc
.
GetLength
(
I0
);
constexpr
unsigned
in_col
=
in_chwn_block_desc
.
GetLength
(
I1
)
*
in_chwn_block_desc
.
GetLength
(
I2
)
*
in_chwn_block_desc
.
GetLength
(
I3
);
constexpr
auto
in_cxhwn_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
in_row
>
,
Number
<
in_col
>
,
Number
<
in_col
>
);
// weight: global mem to LDS,
// weight: global mem to LDS,
// convert 4d-tensor wei[K,C,S,R] to matrix wei_matrix[S*R*C,K]
// convert 4d-tensor wei[K,C,S,R] to matrix wei_matrix[S*R*C,K]
constexpr
auto
reorder_kcsr2srck
=
Sequence
<
3
,
2
,
0
,
1
>
{};
constexpr
auto
reorder_kcsr2srck
=
Sequence
<
3
,
2
,
0
,
1
>
{};
...
@@ -104,44 +130,8 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
...
@@ -104,44 +130,8 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
wei_csrk_block_desc
,
wei_csrk_block_desc
,
reorder_kcsr2csrk
);
reorder_kcsr2csrk
);
// matrix view of wei
constexpr
unsigned
wei_row
=
wei_srck_block_desc
.
GetLength
(
I0
)
*
wei_srck_block_desc
.
GetLength
(
I1
)
*
wei_srck_block_desc
.
GetLength
(
I2
);
constexpr
unsigned
wei_col
=
wei_srck_block_desc
.
GetLength
(
I3
);
constexpr
auto
wei_srcxk_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
wei_row
>
,
Number
<
wei_col
>
,
Number
<
wei_col
>
);
__syncthreads
();
__syncthreads
();
// a series of batched GEMM
// blockwise batched GEMM, C_matrix += transpose(A_matrix) * B_matrix
// A_matrix and B_matrix saved in LDS, c_matrix saved in register
// A_matrix[C,K] is a sub-matrix of wei_matrix[S*R*C,K]
// B_matrix[C,Wo*N] is a sub-matrix of in_matrix[C,Hi*Wi*N]
// C_matrix[K,Wo*N] is a sub-matrix of out_matrix[Ho*K,Wo*N]
constexpr
auto
a_block_mtx_desc
=
wei_srcxk_block_mtx_desc
.
MakeSubMatrixDescriptor
(
Number
<
CPerBlock
>
{},
Number
<
KPerBlock
>
{});
constexpr
auto
b_block_mtx_desc
=
in_cxhwn_block_mtx_desc
.
MakeSubMatrixDescriptor
(
Number
<
CPerBlock
>
{},
Number
<
WoPerBlock
*
NPerBlock
>
{});
auto
f_accum
=
(
auto
&
c
,
auto
&
v
)
{
c
+=
v
;
};
const
auto
blockwise_batch_gemm
=
blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c
<
BlockSize
,
a_block_mtx_desc
,
b_block_mtx_desc
,
true
,
false
,
HoPerBlock
,
0
,
xxx_b_matrix_stride
,
HoPerThread
,
KPerThread
,
NPerThread
*
WoPerThread
,
CPerTrhead
,
decltype
(
f_accum
)
>
{};
// loop over filter point
// loop over filter point
for
(
unsigned
s
=
0
;
s
<
S
;
++
s
)
for
(
unsigned
s
=
0
;
s
<
S
;
++
s
)
{
{
...
@@ -165,6 +155,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
...
@@ -165,6 +155,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
// output: register to global mem,
// output: register to global mem,
// convert matrix out_matrix[Ho*K,Wo*N] to 4d-tensor out[N,K,Ho,Wo]
// convert matrix out_matrix[Ho*K,Wo*N] to 4d-tensor out[N,K,Ho,Wo]
constexpr
auto
reorder_hkwn2nkhw
=
Sequence
<
2
,
1
,
3
,
0
>
{};
constexpr
auto
reorder_hkwn2nkhw
=
Sequence
<
2
,
1
,
3
,
0
>
{};
threadwise_4d_tensor_copy_reorder
(
threadwise_4d_tensor_copy_reorder
(
out_hkwn_thread_desc
,
out_hkwn_thread_desc
,
p_out_thread
,
p_out_thread
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment