Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c6e3d607
Commit
c6e3d607
authored
Dec 22, 2019
by
Chao Liu
Browse files
tweaking bwd data
parent
bfba60cf
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
84 additions
and
6 deletions
+84
-6
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
+21
-4
composable_kernel/include/tensor_description/multi_index_transform.hpp
...rnel/include/tensor_description/multi_index_transform.hpp
+46
-0
driver/src/conv_bwd_data_driver.cpp
driver/src/conv_bwd_data_driver.cpp
+17
-2
No files found.
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
View file @
c6e3d607
...
@@ -142,18 +142,35 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
...
@@ -142,18 +142,35 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}));
constexpr
index_t
Hip
=
in_n_c_hip_wip_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
Wip
=
in_n_c_hip_wip_global_desc
.
GetLengths
()[
3
];
constexpr
auto
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc
=
transform_tensor_descriptor
(
constexpr
auto
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc
=
transform_tensor_descriptor
(
in_n_c_hip_wip_global_desc
,
in_n_c_hip_wip_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
Embed
<
Hip
,
Sequence
<
Ytilda
,
Htilda
>
,
Sequence
<
ConvDilationH
,
ConvStrideH
,
0
>>
{},
Embed
<
Wip
,
Sequence
<
Xtilda
,
Wtilda
>
,
Sequence
<
ConvDilationW
,
ConvStrideW
,
0
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
#if 0
constexpr index_t HtildaLeft = LeftPads{}[0] / ConvStrides{}[0];
constexpr idext_t HtildaRight = math::integer_divide_ceil
constexpr index_t WtidaTrimLeft = LeftPads{}[0] / ConvStrides{}[0];
constexpr auto in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc = transform_tensor_descriptor(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
make_tuple(PassThrough<N>{},
make_tuple(PassThrough<N>{},
PassThrough<C>{},
PassThrough<C>{},
Embed
<
Hi
+
InputLeftPads
::
At
(
0
)
+
InputRightPads
::
At
(
0
)
,
Trim<Sequence<Htilda, Wtilda>
,
Sequence<Ytilda, Htilda>,
Sequence<Ytilda, Htilda>,
Sequence<ConvDilationH, ConvStrideH, 0>>{},
Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed
<
Wi
+
InputLeftPads
::
At
(
1
)
+
InputRightPads
::
At
(
1
),
Sequence
<
Xtilda
,
Wtilda
>
,
Sequence
<
ConvDilationW
,
ConvStrideW
,
0
>>
{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
#endif
constexpr
auto
in_gemmm_gemmn_global_desc
=
transform_tensor_descriptor
(
constexpr
auto
in_gemmm_gemmn_global_desc
=
transform_tensor_descriptor
(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc
,
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc
,
...
...
composable_kernel/include/tensor_description/multi_index_transform.hpp
View file @
c6e3d607
...
@@ -100,6 +100,52 @@ struct Pad
...
@@ -100,6 +100,52 @@ struct Pad
}
}
};
};
// Trim: index transform that drops elements from both ends of every dimension.
//   LowerLengths: Sequence<...>, lengths of the lower (untrimmed) dimensions
//   LeftTrims:    Sequence<...>, number of elements dropped at the start of each dimension
//   RightTrims:   Sequence<...>, number of elements dropped at the end of each dimension
// All three sequences must have the same number of entries (checked in the constructor).
template <typename LowerLengths, typename LeftTrims, typename RightTrims>
struct Trim
{
    static constexpr index_t nDim = LowerLengths::Size();

    using LowerIndex = MultiIndex<nDim>;
    using UpperIndex = MultiIndex<nDim>;

    __host__ __device__ explicit constexpr Trim()
    {
        static_assert(LowerLengths::GetSize() == nDim && LeftTrims::GetSize() == nDim &&
                          RightTrims::GetSize() == nDim,
                      "wrong! # of dimensions not consistent");
    }

    // Trimming does not change the number of dimensions.
    __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDim>{}; }

    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDim>{}; }

    // Upper length = lower length minus what is removed on BOTH sides.
    // (The right trim must be subtracted, not added: adding it would make the
    // upper tensor longer than the lower one and let CalculateLowerIndex step
    // past the end of the lower dimension.)
    __host__ __device__ static constexpr auto GetUpperLengths()
    {
        return LowerLengths{} - LeftTrims{} - RightTrims{};
    }

    // Upper index 0 corresponds to the first element kept, i.e. lower index LeftTrims.
    __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
    {
        return idx_up + LeftTrims{};
    }

    // The mapping is a pure offset, so an index difference is passed through unchanged;
    // the old upper/lower indices are not needed.
    __host__ __device__ static constexpr auto
    CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
                            const UpperIndex& /* idx_up_old */,
                            const LowerIndex& /* idx_low_old */)
    {
        return idx_up_diff;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    // With upper lengths of Lower - Left - Right, every in-range upper index lands in
    // [LeftTrims, LowerLengths - RightTrims), which lies inside the lower range.
    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }
};
// LowerLengths: Sequence<...>
// LowerLengths: Sequence<...>
template
<
typename
LowerLengths
>
template
<
typename
LowerLengths
>
struct
Merge
struct
Merge
...
...
driver/src/conv_bwd_data_driver.cpp
View file @
c6e3d607
...
@@ -36,7 +36,7 @@ int main(int argc, char* argv[])
...
@@ -36,7 +36,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif
1
#elif
0
// 3x3, 28x28
// 3x3, 28x28
constexpr
index_t
N
=
64
;
constexpr
index_t
N
=
64
;
constexpr
index_t
C
=
256
;
constexpr
index_t
C
=
256
;
...
@@ -141,6 +141,21 @@ int main(int argc, char* argv[])
...
@@ -141,6 +141,21 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
2
,
2
>
;
using
LeftPads
=
Sequence
<
2
,
2
>
;
using
RightPads
=
Sequence
<
2
,
2
>
;
using
RightPads
=
Sequence
<
2
,
2
>
;
#elif 1
// 1x7 filter, 23x23 input
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1024
;
constexpr
index_t
HI
=
23
;
constexpr
index_t
WI
=
23
;
constexpr
index_t
K
=
1024
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
7
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
#elif 0
// 7x1 filter, 3x0 pad, 17x17 input
// 7x1 filter, 3x0 pad, 17x17 input
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
...
@@ -156,7 +171,7 @@ int main(int argc, char* argv[])
...
@@ -156,7 +171,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
3
,
0
>
;
using
LeftPads
=
Sequence
<
3
,
0
>
;
using
RightPads
=
Sequence
<
3
,
0
>
;
using
RightPads
=
Sequence
<
3
,
0
>
;
#elif
0
#elif
1
// 1x7 filter, 0x3 pad, 17x17 input
// 1x7 filter, 0x3 pad, 17x17 input
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1024
;
constexpr
index_t
C
=
1024
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment