composable_kernel · Commit c64f63d5
Authored Jan 21, 2019 by Chao Liu
refactor
Parent: 20968472

Showing 5 changed files with 77 additions and 50 deletions
driver/conv.cu                                                          +3   -3
driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh           +58  -37
driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh           +4   -1
src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh    +1   -1
src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh    +11  -8
driver/conv.cu

@@ -9,7 +9,7 @@
 #include "device_direct_convolution_1.cuh"
 #include "device_direct_convolution_2.cuh"
 #include "device_implicit_gemm_convolution_1_nchw_kcsr.cuh"
-#include "device_implicit_gemm_convolution_1_nchw_srck.cuh"
+#include "device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh"
 #include "device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh"
 //#include "device_winograd_convolution.cuh"

@@ -418,8 +418,8 @@ int main()
     device_direct_convolution_2
 #elif 0
     device_implicit_gemm_convolution_1_nchw_kcsr
-#elif 0
-    device_implicit_gemm_convolution_1_nchw_srck
+#elif 1
+    device_implicit_gemm_convolution_1_nchw_srck_nkhw
 #elif 1
     device_implicit_gemm_convolution_2_cnhw_srck_knhw
 #elif 0
driver/device_implicit_gemm_convolution_1_nchw_srck.cuh → driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh

 #pragma once
-#include "gridwise_implicit_gemm_convolution_1_nchw_srck.cuh"
+#include "gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh"
+#include <unistd.h>

 template <class T, class InDesc, class WeiDesc, class OutDesc>
-void device_implicit_gemm_convolution_1_nchw_srck(InDesc,
+void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
                                                   const Tensor<T>& in_nchw,
                                                   WeiDesc,
                                                   const Tensor<T>& wei_kcsr,
                                                   OutDesc,
-                                                  Tensor<T>& out_nkhw)
+                                                  Tensor<T>& out_nkhw,
+                                                  unsigned nrepeat)
 {
     constexpr auto I0 = Number<0>{};
     constexpr auto I1 = Number<1>{};
@@ -101,6 +103,19 @@ void device_implicit_gemm_convolution_1_nchw_srck(InDesc,
     constexpr unsigned HoPerThread = 2;
     constexpr unsigned WoPerThread = 1;

     constexpr unsigned BlockSize = 128;
+#elif 1
+    constexpr unsigned NPerBlock  = 2;
+    constexpr unsigned KPerBlock  = 32;
+    constexpr unsigned CPerBlock  = 4;
+    constexpr unsigned HoPerBlock = 2;
+    constexpr unsigned WoPerBlock = 32;
+
+    constexpr unsigned KPerThread  = 4;
+    constexpr unsigned CPerThread  = 2;
+    constexpr unsigned HoPerThread = 2;
+    constexpr unsigned WoPerThread = 2;
+
+    constexpr unsigned BlockSize = 128;
 #endif
@@ -113,40 +128,46 @@ void device_implicit_gemm_convolution_1_nchw_srck(InDesc,
     printf("%s: BlockSize %u, GridSize %u\n", __func__, BlockSize, GridSize);

-    cudaEvent_t start, stop;
-    float elapsedTime;
-
-    cudaEventCreate(&start);
-    cudaEventRecord(start, 0);
-
-    gridwise_implicit_gemm_convolution_1_nchw_srck<GridSize,
-                                                   BlockSize,
-                                                   T,
-                                                   decltype(in_nchw_desc),
-                                                   decltype(wei_srck_desc),
-                                                   decltype(out_nkhw_desc),
-                                                   NPerBlock,
-                                                   KPerBlock,
-                                                   CPerBlock,
-                                                   HoPerBlock,
-                                                   WoPerBlock,
-                                                   KPerThread,
-                                                   CPerThread,
-                                                   HoPerThread,
-                                                   WoPerThread>
-        <<<grid_dim, block_dim>>>(in_nchw_desc,
-                                  static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
-                                  wei_srck_desc,
-                                  static_cast<T*>(wei_srck_device_buf.GetDeviceBuffer()),
-                                  out_nkhw_desc,
-                                  static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
-
-    cudaEventCreate(&stop);
-    cudaEventRecord(stop, 0);
-    cudaEventSynchronize(stop);
-
-    cudaEventElapsedTime(&elapsedTime, start, stop);
-    printf("Elapsed time : %f ms\n", elapsedTime);
+    for(unsigned i = 0; i < nrepeat; ++i)
+    {
+        cudaEvent_t start, stop;
+        float elapsedTime;
+
+        cudaEventCreate(&start);
+        cudaEventRecord(start, 0);
+
+        gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw<GridSize,
+                                                            BlockSize,
+                                                            T,
+                                                            decltype(in_nchw_desc),
+                                                            decltype(wei_srck_desc),
+                                                            decltype(out_nkhw_desc),
+                                                            NPerBlock,
+                                                            KPerBlock,
+                                                            CPerBlock,
+                                                            HoPerBlock,
+                                                            WoPerBlock,
+                                                            KPerThread,
+                                                            CPerThread,
+                                                            HoPerThread,
+                                                            WoPerThread>
+            <<<grid_dim, block_dim>>>(in_nchw_desc,
+                                      static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
+                                      wei_srck_desc,
+                                      static_cast<T*>(wei_srck_device_buf.GetDeviceBuffer()),
+                                      out_nkhw_desc,
+                                      static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
+
+        cudaEventCreate(&stop);
+        cudaEventRecord(stop, 0);
+        cudaEventSynchronize(stop);
+
+        cudaEventElapsedTime(&elapsedTime, start, stop);
+        printf("Elapsed time : %f ms\n", elapsedTime);
+
+        usleep(10);
+    }

     checkCudaErrors(cudaGetLastError());

     out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
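For reference, here is a minimal, self-contained sketch of the cudaEvent timing pattern used by the driver loop above. The kernel dummy_kernel, the repeat count, and the buffer size are hypothetical placeholders, not part of this commit; only the event-record / synchronize / elapsed-time structure mirrors the driver.

// Sketch: time each of nrepeat kernel launches with CUDA events (assumed example).
#include <cstdio>
#include <unistd.h>
#include <cuda_runtime.h>

__global__ void dummy_kernel(float* p, int n) // hypothetical stand-in kernel
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        p[i] *= 2.0f;
}

int main()
{
    const int n            = 1 << 20;
    const unsigned nrepeat = 10;

    float* p_dev = nullptr;
    cudaMalloc(&p_dev, n * sizeof(float));

    for(unsigned i = 0; i < nrepeat; ++i)
    {
        cudaEvent_t start, stop;
        float elapsedTime;

        cudaEventCreate(&start);
        cudaEventCreate(&stop);

        cudaEventRecord(start, 0);
        dummy_kernel<<<(n + 255) / 256, 256>>>(p_dev, n);
        cudaEventRecord(stop, 0);
        cudaEventSynchronize(stop); // wait until the launch has finished

        cudaEventElapsedTime(&elapsedTime, start, stop);
        printf("Elapsed time : %f ms\n", elapsedTime);

        cudaEventDestroy(start);
        cudaEventDestroy(stop);
        usleep(10); // brief pause between repeats, as in the driver
    }

    cudaFree(p_dev);
    return 0;
}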
driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh

@@ -90,6 +90,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
     constexpr unsigned KPerBlock = 64;
     constexpr unsigned CPerBlock = 2;

+    constexpr unsigned BPerBatch = 32;
+
     constexpr unsigned BPerThread = 4;
     constexpr unsigned KPerThread = 16;
     constexpr unsigned CPerThread = 1;

@@ -134,7 +136,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
                                                        CPerBlock,
                                                        BPerThread,
                                                        KPerThread,
-                                                       CPerThread>
+                                                       CPerThread,
+                                                       BPerBatch>
         <<<grid_dim, block_dim>>>(in_cnhw_desc,
                                   static_cast<T*>(in_cnhw_device_buf.GetDeviceBuffer()),
                                   wei_srck_desc,
src/include/gridwise_implicit_gemm_convolution_1_nchw_srck.cuh → src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh

@@ -22,7 +22,7 @@ template <unsigned GridSize,
           unsigned HoPerThread,
           unsigned WoPerThread>
 __global__ void
-gridwise_implicit_gemm_convolution_1_nchw_srck(InGlobalDesc,
+gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
                                                Float* const __restrict__ p_in_global,
                                                WeiGlobalDesc,
                                                Float* const __restrict__ p_wei_global,
src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh

@@ -19,7 +19,8 @@ template <unsigned GridSize,
           unsigned CPerBlock,
           unsigned BPerThread,
           unsigned KPerThread,
-          unsigned CPerThread>
+          unsigned CPerThread,
+          unsigned BPerBatch>
 __global__ void
 gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
                                                     Float* const __restrict__ p_in_global,
@@ -111,15 +112,17 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
     const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
         Number<CPerBlock>{}, Number<KPerBlock>{}); // constexpr doesn't compile

+    static_assert(BPerBlock % BPerBatch == 0 && BPerBatch % BPerThread == 0,
+                  "B cannot be evenly divided\n");
+
     const auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor(
-        Number<CPerBlock>{}, Number<BPerBlock>{}, Number<in_cb_block_desc.GetStride(I0)>{}); // constexpr doesn't compile
+        Number<CPerBlock>{}, Number<BPerBatch>{}, Number<in_cb_block_desc.GetStride(I0)>{}); // constexpr doesn't compile

     const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor(
         Number<KPerThread>{}, Number<BPerThread>{}); // constexpr doesn't compile

-    const auto blockwise_gemm =
+    const auto blockwise_batched_gemm =
         blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<BlockSize,
                                                                    decltype(a_cxk_block_mtx_desc),
                                                                    decltype(b_cxb_block_mtx_desc),
@@ -128,9 +131,9 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
                                                                    false,
                                                                    false,
-                                                                   0,
+                                                                   BPerBatch,
                                                                    0,
                                                                    0,
-                                                                   1,
+                                                                   BPerBlock / BPerBatch,
                                                                    1,
                                                                    CPerThread,
                                                                    true>{};
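To make the new batching scheme concrete, here is a small sketch of how a block's B range decomposes under the constraint asserted above. BPerBatch = 32 and BPerThread = 4 match the driver defaults added in this commit; the value of BPerBlock is an assumed placeholder, since the commit does not show it.

// Sketch of the B-dimension decomposition checked by the new static_assert.
#include <cstdio>

constexpr unsigned BPerBlock  = 128; // assumed for illustration only
constexpr unsigned BPerBatch  = 32;  // driver default added in this commit
constexpr unsigned BPerThread = 4;   // driver default

// Same divisibility requirement as the static_assert added in the kernel.
static_assert(BPerBlock % BPerBatch == 0 && BPerBatch % BPerThread == 0,
              "B cannot be evenly divided");

int main()
{
    constexpr unsigned batches_per_block = BPerBlock / BPerBatch;  // batch count passed to the batched GEMM
    constexpr unsigned threads_per_batch = BPerBatch / BPerThread; // thread columns covering one batch

    printf("each block covers B = %u, split into %u batches of %u,\n"
           "each batch covered by %u thread columns of %u elements\n",
           BPerBlock, batches_per_block, BPerBatch, threads_per_batch, BPerThread);
    return 0;
}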
@@ -179,7 +182,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
         {
             auto f_accum = [](auto& c, const auto&& ab) { c += ab; };

-            blockwise_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
+            blockwise_batched_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
                                p_in_block + s * Wi + r,
                                p_out_thread,
                                f_accum);
@@ -189,10 +192,10 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
     // output: register to global mem,
     const auto matrix_c_index =
-        blockwise_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
+        blockwise_batched_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());

     const unsigned k_thread_data_begin = matrix_c_index.row_begin;
-    const unsigned b_thread_data_begin = matrix_c_index.col_begin;
+    const unsigned b_thread_data_begin =
+        matrix_c_index.batch_begin * BPerBatch + matrix_c_index.col_begin;

     const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin;
     const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin;
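As a final illustration, a tiny host-side sketch of the output-index arithmetic introduced above: each thread's B offset is now reconstructed from a batch index and a column index of the batched GEMM. The MatrixCIndex struct and the sample coordinate values are hypothetical; only the formula batch_begin * BPerBatch + col_begin comes from the diff.

// Sketch: decompose a thread's output position into (K, B) offsets (assumed types and values).
#include <cstdio>

struct MatrixCIndex // hypothetical stand-in for the index returned by the batched GEMM
{
    unsigned batch_begin;
    unsigned row_begin;
    unsigned col_begin;
};

int main()
{
    constexpr unsigned BPerBatch = 32; // driver default added in this commit

    const MatrixCIndex matrix_c_index{1, 4, 3}; // illustrative thread coordinates

    const unsigned k_thread_data_begin = matrix_c_index.row_begin;
    // New in this commit: the B offset combines the batch and column coordinates.
    const unsigned b_thread_data_begin =
        matrix_c_index.batch_begin * BPerBatch + matrix_c_index.col_begin;

    printf("k_thread_data_begin = %u, b_thread_data_begin = %u\n",
           k_thread_data_begin, b_thread_data_begin); // prints 4 and 35
    return 0;
}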