Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
2f0f26d3
Commit
2f0f26d3
authored
Nov 20, 2019
by
Chao Liu
Browse files
adding bwd data
parent
d2490b49
Changes
5
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
606 additions
and
13 deletions
+606
-13
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp
...a_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp
+1
-1
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r5_nchw_kcyx_nkhw_lds_double_buffer.hpp
...a_implicit_gemm_v4r5_nchw_kcyx_nkhw_lds_double_buffer.hpp
+438
-0
driver/include/device_convolution_backward_data_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
+1
-1
driver/include/device_convolution_backward_data_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp
+149
-0
driver/src/conv_bwd_data_driver.cpp
driver/src/conv_bwd_data_driver.cpp
+17
-11
No files found.
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp
View file @
2f0f26d3
...
@@ -22,8 +22,8 @@ template <index_t GridSize,
...
@@ -22,8 +22,8 @@ template <index_t GridSize,
typename
ConvDilations
,
typename
ConvDilations
,
typename
LeftPads
,
typename
LeftPads
,
typename
RightPads
,
typename
RightPads
,
index_t
EPerBlock
,
index_t
BPerBlock
,
index_t
BPerBlock
,
index_t
EPerBlock
,
index_t
KPerBlock
,
index_t
KPerBlock
,
index_t
GemmMPerThreadSubC
,
index_t
GemmMPerThreadSubC
,
index_t
GemmNPerThreadSubC
,
index_t
GemmNPerThreadSubC
,
...
...
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r5_nchw_kcyx_nkhw_lds_double_buffer.hpp
0 → 100644
View file @
2f0f26d3
This diff is collapsed.
Click to expand it.
driver/include/device_convolution_backward_data_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @
2f0f26d3
...
@@ -97,8 +97,8 @@ void device_convolution_backward_data_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc i
...
@@ -97,8 +97,8 @@ void device_convolution_backward_data_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc i
ConvDilations
,
ConvDilations
,
LeftPads
,
LeftPads
,
RightPads
,
RightPads
,
EPerBlock
,
BPerBlock
,
BPerBlock
,
EPerBlock
,
KPerBlock
,
KPerBlock
,
GemmMPerThreadSubC
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
,
GemmNPerThreadSubC
,
...
...
driver/include/device_convolution_backward_data_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp
0 → 100644
View file @
2f0f26d3
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v4r5_nchw_kcyx_nkhw_lds_double_buffer.hpp"
template
<
typename
T
,
typename
InDesc
,
typename
WeiDesc
,
typename
OutDesc
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
LeftPads
,
typename
RightPads
>
void
device_convolution_backward_data_implicit_gemm_v4r5_nchw_kcyx_nkhw
(
InDesc
in_nchw_desc
,
Tensor
<
T
>&
in_nchw
,
WeiDesc
wei_kcyx_desc
,
const
Tensor
<
T
>&
wei_kcyx
,
OutDesc
out_nkhw_desc
,
const
Tensor
<
T
>&
out_nkhw
,
ConvStrides
,
ConvDilations
,
LeftPads
,
RightPads
,
std
::
size_t
nrepeat
)
{
using
namespace
ck
;
constexpr
index_t
N
=
out_nkhw_desc
.
GetLengths
()[
0
];
constexpr
index_t
K
=
out_nkhw_desc
.
GetLengths
()[
1
];
constexpr
index_t
Ho
=
out_nkhw_desc
.
GetLengths
()[
2
];
constexpr
index_t
Wo
=
out_nkhw_desc
.
GetLengths
()[
3
];
constexpr
index_t
C
=
wei_kcyx_desc
.
GetLengths
()[
1
];
constexpr
index_t
Y
=
wei_kcyx_desc
.
GetLengths
()[
2
];
constexpr
index_t
X
=
wei_kcyx_desc
.
GetLengths
()[
3
];
std
::
size_t
data_sz
=
sizeof
(
T
);
DeviceMem
in_nchw_device_buf
(
data_sz
*
in_nchw
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_kcyx_device_buf
(
data_sz
*
wei_kcyx
.
mDesc
.
GetElementSpace
());
DeviceMem
out_nkhw_device_buf
(
data_sz
*
out_nkhw
.
mDesc
.
GetElementSpace
());
in_nchw_device_buf
.
ToDevice
(
in_nchw
.
mData
.
data
());
wei_kcyx_device_buf
.
ToDevice
(
wei_kcyx
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
#if 1
// BlockSize = 256, each thread hold 64 data
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BPerBlock
=
32
;
constexpr
index_t
EPerBlock
=
32
;
constexpr
index_t
KPerBlock
=
8
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
using
OutBlockCopySubLengths_K_B_N0
=
Sequence
<
1
,
1
,
4
>
;
using
OutBlockCopyClusterLengths_K_B_N0
=
Sequence
<
8
,
32
,
1
>
;
constexpr
index_t
OutBlockCopySrcDataPerRead_B
=
1
;
constexpr
index_t
OutBlockCopyDstDataPerWrite_N0
=
4
;
using
WeiBlockCopySubLengths_K_E_C0
=
Sequence
<
1
,
4
,
1
>
;
using
WeiBlockCopyClusterLengths_K_E_C0
=
Sequence
<
8
,
8
,
4
>
;
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
4
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_C0
=
1
;
constexpr
index_t
InThreadCopyDstDataPerWrite_B
=
1
;
#endif
constexpr
index_t
E
=
C
*
Y
*
X
;
constexpr
index_t
B
=
(
N
*
Ho
*
Wo
);
constexpr
index_t
GridSize
=
((
E
+
EPerBlock
-
1
)
/
EPerBlock
)
*
((
B
+
BPerBlock
-
1
)
/
BPerBlock
);
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
constexpr
auto
gridwise_conv
=
GridwiseConvolutionBackwardDataImplicitGemm_v4r5_nchw_kcyx_nkhw_lds_double_buffer
<
GridSize
,
BlockSize
,
T
,
T
,
decltype
(
in_nchw_desc
),
decltype
(
wei_kcyx_desc
),
decltype
(
out_nkhw_desc
),
ConvStrides
,
ConvDilations
,
LeftPads
,
RightPads
,
BPerBlock
,
EPerBlock
,
KPerBlock
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmKPerThreadLoop
,
GemmDataPerReadA
,
GemmDataPerReadB
,
OutBlockCopySubLengths_K_B_N0
,
OutBlockCopyClusterLengths_K_B_N0
,
OutBlockCopySrcDataPerRead_B
,
OutBlockCopyDstDataPerWrite_N0
,
WeiBlockCopySubLengths_K_E_C0
,
WeiBlockCopyClusterLengths_K_E_C0
,
WeiBlockCopySrcDataPerRead_E
,
WeiBlockCopyDstDataPerWrite_C0
,
InThreadCopyDstDataPerWrite_B
>
{};
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
{
float
time
=
launch_kernel
(
run_gridwise_operation
<
decltype
(
gridwise_conv
),
T
*
const
__restrict__
,
const
T
*
const
__restrict__
,
const
T
*
const
__restrict__
>
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
gridwise_conv
,
const_cast
<
T
*
const
__restrict__
>
(
static_cast
<
T
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
())),
const_cast
<
const
T
*
const
__restrict__
>
(
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
())),
const_cast
<
const
T
*
const
__restrict__
>
(
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
())));
printf
(
"Elapsed time : %f ms, %f TFlop/s
\n
"
,
time
,
(
float
)
calculate_convolution_flops
(
InDesc
{},
WeiDesc
{},
OutDesc
{})
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
time
);
usleep
(
std
::
min
(
time
*
1000
,
float
(
10000
)));
}
in_nchw_device_buf
.
FromDevice
(
in_nchw
.
mData
.
data
());
}
driver/src/conv_bwd_data_driver.cpp
View file @
2f0f26d3
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#include "conv_common.hpp"
#include "conv_common.hpp"
#include "host_conv_bwd_data.hpp"
#include "host_conv_bwd_data.hpp"
#include "device_convolution_backward_data_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp"
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
...
@@ -344,7 +345,12 @@ int main(int argc, char* argv[])
...
@@ -344,7 +345,12 @@ int main(int argc, char* argv[])
#endif
#endif
}
}
device_convolution_backward_data_implicit_gemm_v4r4_nchw_kcyx_nkhw
(
in_nchw_desc
,
#if 0
device_convolution_backward_data_implicit_gemm_v4r4_nchw_kcyx_nkhw
#else
device_convolution_backward_data_implicit_gemm_v4r5_nchw_kcyx_nkhw
#endif
(
in_nchw_desc
,
in_nchw_device
,
in_nchw_device
,
wei_kcyx_desc
,
wei_kcyx_desc
,
wei_kcyx
,
wei_kcyx
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment