Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
3f81301f
Commit
3f81301f
authored
Dec 24, 2019
by
Chao Liu
Browse files
tweaking bwd data
parent
f67adee3
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
56 additions
and
49 deletions
+56
-49
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
+11
-6
composable_kernel/include/tensor_description/multi_index_transform.hpp
...rnel/include/tensor_description/multi_index_transform.hpp
+20
-2
driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
+2
-2
driver/src/conv_bwd_data_driver.cpp
driver/src/conv_bwd_data_driver.cpp
+12
-28
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+11
-11
No files found.
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
View file @
3f81301f
...
...
@@ -227,7 +227,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
in_n_c_hi_wi_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
Pad
<
Sequence
<
Hi
,
Wi
>
,
InLeftPads
,
InRightPads
>
{}),
Pad
<
Sequence
<
Hi
,
Wi
>
,
InLeftPads
,
InRightPads
,
true
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}));
...
...
@@ -236,11 +236,16 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
constexpr
auto
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc
=
transform_tensor_descriptor
(
in_n_c_hip_wip_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
Embed
<
Hip
,
Sequence
<
Ytilda
,
Htilda
>
,
Sequence
<
ConvDilationH
,
ConvStrideH
,
0
>>
{},
Embed
<
Wip
,
Sequence
<
Xtilda
,
Wtilda
>
,
Sequence
<
ConvDilationW
,
ConvStrideW
,
0
>>
{}),
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
Embed
<
Hip
,
Sequence
<
Ytilda
,
Htilda
>
,
Sequence
<
ConvDilationH
,
ConvStrideH
,
0
>
,
true
>
{},
Embed
<
Wip
,
Sequence
<
Xtilda
,
Wtilda
>
,
Sequence
<
ConvDilationW
,
ConvStrideW
,
0
>
,
true
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
...
...
composable_kernel/include/tensor_description/multi_index_transform.hpp
View file @
3f81301f
...
...
@@ -48,7 +48,10 @@ struct PassThrough
};
// LowerLengths: Sequence<...>
template
<
typename
LowerLengths
,
typename
LeftPads
,
typename
RightPads
>
template
<
typename
LowerLengths
,
typename
LeftPads
,
typename
RightPads
,
bool
SkipIsValidCheck
=
false
>
struct
Pad
{
static
constexpr
index_t
nDim
=
LowerLengths
::
Size
();
...
...
@@ -89,6 +92,12 @@ struct Pad
__host__
__device__
static
constexpr
bool
IsValidUpperIndexAlwaysMappedToValidLowerIndex
()
{
#if 1 // debug
if
(
SkipIsValidCheck
)
{
return
true
;
}
#endif
bool
flag
=
true
;
for
(
index_t
i
=
0
;
i
<
nDim
;
++
i
)
...
...
@@ -366,7 +375,10 @@ struct UnMerge
// UpperLengths: Sequence<...>
// Coefficients: Sequence<...>
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp]
template
<
index_t
LowerLength
,
typename
UpperLengths
,
typename
Coefficients
>
template
<
index_t
LowerLength
,
typename
UpperLengths
,
typename
Coefficients
,
bool
SkipIsValidCheck
=
false
>
struct
Embed
{
static
constexpr
index_t
nDimLow
=
1
;
...
...
@@ -418,6 +430,12 @@ struct Embed
__host__
__device__
static
constexpr
bool
IsValidUpperIndexAlwaysMappedToValidLowerIndex
()
{
#if 1 // debug
if
(
SkipIsValidCheck
)
{
return
true
;
}
#endif
bool
flag
=
true
;
index_t
ncorner
=
1
;
...
...
driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
View file @
3f81301f
...
...
@@ -55,7 +55,7 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
wei_kcyx_device_buf
.
ToDevice
(
wei_kcyx
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
#if
1
#if
0
// BlockSize = 256, each thread hold 64 data
constexpr index_t BlockSize = 256;
...
...
@@ -115,7 +115,7 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
1
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
1
;
#elif
1
#elif
0
// BlockSize = 256, each thread hold 64 data
// for 1x1 weight, 8x8 input
constexpr
index_t
BlockSize
=
256
;
...
...
driver/src/conv_bwd_data_driver.cpp
View file @
3f81301f
...
...
@@ -49,8 +49,7 @@ int main(int argc, char* argv[])
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 1x1 filter, 8x8 image
constexpr
index_t
N
=
256
;
...
...
@@ -142,27 +141,27 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
2
,
2
>
;
using
RightPads
=
Sequence
<
2
,
2
>
;
#elif 0
// 1x7 filter,
23x23
input
// 1x7 filter,
0x3 pad, 17x17
input
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1
28
;
constexpr
index_t
HI
=
23
;
constexpr
index_t
WI
=
23
;
constexpr
index_t
K
=
1
28
;
constexpr
index_t
C
=
1
024
;
constexpr
index_t
HI
=
17
;
constexpr
index_t
WI
=
17
;
constexpr
index_t
K
=
1
024
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
7
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
using
LeftPads
=
Sequence
<
0
,
3
>
;
using
RightPads
=
Sequence
<
0
,
3
>
;
#elif 0
// 7x1 filter, 3x0 pad, 17x17 input
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1
28
;
constexpr
index_t
C
=
1
024
;
constexpr
index_t
HI
=
17
;
constexpr
index_t
WI
=
17
;
constexpr
index_t
K
=
1
28
;
constexpr
index_t
K
=
1
024
;
constexpr
index_t
Y
=
7
;
constexpr
index_t
X
=
1
;
...
...
@@ -172,27 +171,12 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
3
,
0
>
;
using
RightPads
=
Sequence
<
3
,
0
>
;
#elif 1
// 1x7 filter, 0x3 pad, 17x17 input
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
128
;
constexpr
index_t
HI
=
17
;
constexpr
index_t
WI
=
17
;
constexpr
index_t
K
=
128
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
7
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
3
>
;
using
RightPads
=
Sequence
<
0
,
3
>
;
#elif 0
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1
28
;
constexpr
index_t
C
=
1
024
;
constexpr
index_t
HI
=
35
;
constexpr
index_t
WI
=
35
;
constexpr
index_t
K
=
1
28
;
constexpr
index_t
K
=
1
024
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
...
...
driver/src/conv_driver.cpp
View file @
3f81301f
...
...
@@ -74,7 +74,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif
1
#elif
0
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
128
;
...
...
@@ -328,35 +328,35 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
2
,
2
>
;
using
RightPads
=
Sequence
<
2
,
2
>
;
#elif 0
//
7x1
filter,
3x0
pad, 17x17 input
//
1x7
filter,
0x3
pad, 17x17 input
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
128
;
constexpr
index_t
HI
=
17
;
constexpr
index_t
WI
=
17
;
constexpr
index_t
K
=
128
;
constexpr
index_t
Y
=
7
;
constexpr
index_t
X
=
1
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
7
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
3
,
0
>
;
using
RightPads
=
Sequence
<
3
,
0
>
;
using
LeftPads
=
Sequence
<
0
,
3
>
;
using
RightPads
=
Sequence
<
0
,
3
>
;
#elif 1
//
1x7
filter,
0x3
pad, 17x17 input
//
7x1
filter,
3x0
pad, 17x17 input
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
128
;
constexpr
index_t
HI
=
17
;
constexpr
index_t
WI
=
17
;
constexpr
index_t
K
=
128
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
7
;
constexpr
index_t
Y
=
7
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
3
>
;
using
RightPads
=
Sequence
<
0
,
3
>
;
using
LeftPads
=
Sequence
<
3
,
0
>
;
using
RightPads
=
Sequence
<
3
,
0
>
;
#endif
auto
in_nchw_desc
=
make_ConstantTensorDescriptor_packed
(
Sequence
<
N
,
C
,
HI
,
WI
>
{});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment