Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
039b5e7e
Commit
039b5e7e
authored
Dec 19, 2019
by
Chao Liu
Browse files
tweaking
parent
e402e30b
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
107 additions
and
67 deletions
+107
-67
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
+20
-48
composable_kernel/include/tensor_description/multi_index_transform.hpp
...rnel/include/tensor_description/multi_index_transform.hpp
+40
-6
composable_kernel/include/tensor_description/tensor_descriptor.hpp
...e_kernel/include/tensor_description/tensor_descriptor.hpp
+4
-1
driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
+31
-0
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
.../device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
+2
-2
driver/src/conv_bwd_data_driver.cpp
driver/src/conv_bwd_data_driver.cpp
+9
-9
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+1
-1
No files found.
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
View file @
039b5e7e
...
...
@@ -71,6 +71,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
constexpr
index_t
ConvDilationH
=
ConvDilations
{}[
0
];
constexpr
index_t
ConvDilationW
=
ConvDilations
{}[
1
];
#if 0 // debug
// sanity-check for vectorized memory load
// TODO: this logic may not be correct for bwd-data
static_assert(
...
...
@@ -78,6 +79,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
(X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
#endif
constexpr
index_t
hcf_stride_dilation_h
=
math
::
hcf
(
ConvStrideH
,
ConvDilationH
);
constexpr
index_t
hcf_stride_dilation_w
=
math
::
hcf
(
ConvStrideW
,
ConvDilationW
);
...
...
@@ -88,30 +90,19 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
constexpr
index_t
Ydot
=
math
::
integer_divide_ceil
(
Y
,
Ytilda
);
constexpr
index_t
Xdot
=
math
::
integer_divide_ceil
(
X
,
Xtilda
);
constexpr
index_t
right_pad_ho
=
(
ConvDilationH
/
hcf_stride_dilation_h
)
*
(
Y
-
Ytilda
);
constexpr
index_t
right_pad_wo
=
(
ConvDilationW
/
hcf_stride_dilation_w
)
*
(
X
-
Xtilda
);
constexpr
index_t
Htilda
=
Ho
+
right_pad_ho
;
constexpr
index_t
Wtilda
=
Wo
+
right_pad_wo
;
constexpr
index_t
Htilda
=
Ho
+
(
ConvDilationH
/
hcf_stride_dilation_h
)
*
(
Y
-
Ytilda
);
constexpr
index_t
Wtilda
=
Wo
+
(
ConvDilationW
/
hcf_stride_dilation_w
)
*
(
X
-
Xtilda
);
// weight tensor
constexpr
auto
wei_k_c_yp_xp_global_desc
=
transform_tensor_descriptor
(
wei_k_c_y_x_global_desc
,
make_tuple
(
PassThrough
<
K
>
{},
PassThrough
<
C
>
{},
Pad
<
Sequence
<
Y
,
X
>
,
Sequence
<
0
,
0
>
,
Sequence
<
Ydot
*
Ytilda
-
Y
,
Xdot
*
Xtilda
-
X
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}));
constexpr
auto
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc
=
transform_tensor_descriptor
(
wei_k_c_y
p
_x
p
_global_desc
,
wei_k_c_y_x_global_desc
,
make_tuple
(
PassThrough
<
K
>
{},
PassThrough
<
C
>
{},
Embed
<
Sequence
<
Ydot
,
Ytilda
>
,
Embed
<
Y
,
Sequence
<
Ydot
,
Ytilda
>
,
Sequence
<
ConvStrideH
/
hcf_stride_dilation_h
,
1
,
0
>>
{},
Embed
<
Sequence
<
Xdot
,
Xtilda
>
,
Embed
<
X
,
Sequence
<
Xdot
,
Xtilda
>
,
Sequence
<
ConvStrideW
/
hcf_stride_dilation_w
,
1
,
0
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
...
...
@@ -122,42 +113,19 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
#if 0 // debug
// output tensor
constexpr auto out_n_k_hop_wop_global_desc = transform_tensor_descriptor(
out_n_k_ho_wo_global_desc,
make_tuple(
PassThrough<N>{},
PassThrough<K>{},
Pad<Sequence<Ho, Wo>, Sequence<0, 0>, Sequence<right_pad_ho, right_pad_wo>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr
auto
out_n_k_ydot_htilda_xdot_wtilda_global_desc
=
transform_tensor_descriptor
(
out_n_k_ho
p
_wo
p
_global_desc,
out_n_k_ho_wo_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
K
>
{},
Embed<Sequence<Ydot, Htilda>,
Embed
<
Ho
,
Sequence
<
Ydot
,
Htilda
>
,
Sequence
<-
ConvDilationH
/
hcf_stride_dilation_h
,
1
,
0
>>
{},
Embed<Sequence<Xdot, Wtilda>,
Embed
<
Wo
,
Sequence
<
Xdot
,
Wtilda
>
,
Sequence
<-
ConvDilationW
/
hcf_stride_dilation_w
,
1
,
0
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
#else
// output tensor
constexpr
auto
out_n_k_ydot_htilda_xdot_wtilda_global_desc
=
transform_tensor_descriptor
(
out_n_k_ho_wo_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
K
>
{},
Embed
<
Sequence
<
Ydot
,
Htilda
>
,
Sequence
<-
ConvDilationH
/
hcf_stride_dilation_h
,
1
,
0
>
,
false
>
{},
Embed
<
Sequence
<
Xdot
,
Wtilda
>
,
Sequence
<-
ConvDilationW
/
hcf_stride_dilation_w
,
1
,
0
>
,
false
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
#endif
constexpr
auto
out_gemmk_gemmn_global_desc
=
transform_tensor_descriptor
(
out_n_k_ydot_htilda_xdot_wtilda_global_desc
,
...
...
@@ -178,8 +146,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
in_n_c_hip_wip_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
Embed
<
Sequence
<
Ytilda
,
Htilda
>
,
Sequence
<
ConvDilationH
,
ConvStrideH
,
0
>>
{},
Embed
<
Sequence
<
Xtilda
,
Wtilda
>
,
Sequence
<
ConvDilationW
,
ConvStrideW
,
0
>>
{}),
Embed
<
Hi
+
InputLeftPads
::
At
(
0
)
+
InputRightPads
::
At
(
0
),
Sequence
<
Ytilda
,
Htilda
>
,
Sequence
<
ConvDilationH
,
ConvStrideH
,
0
>>
{},
Embed
<
Wi
+
InputLeftPads
::
At
(
1
)
+
InputRightPads
::
At
(
1
),
Sequence
<
Xtilda
,
Wtilda
>
,
Sequence
<
ConvDilationW
,
ConvStrideW
,
0
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
...
...
composable_kernel/include/tensor_description/multi_index_transform.hpp
View file @
039b5e7e
...
...
@@ -320,7 +320,7 @@ struct UnMerge
// UpperLengths: Sequence<...>
// Coefficients: Sequence<...>
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp]
template
<
typename
UpperLengths
,
typename
Coefficients
,
bool
IsAlwaysValidMapping
=
true
>
template
<
index_t
LowerLength
,
typename
UpperLengths
,
typename
Coefficients
>
struct
Embed
{
static
constexpr
index_t
nDimLow
=
1
;
...
...
@@ -345,8 +345,10 @@ struct Embed
{
LowerIndex
idx_low
(
Coefficients
{}[
nDimUp
]);
static_for
<
0
,
nDimUp
,
1
>
{}(
[
&
](
auto
idim
)
{
idx_low
(
0
)
+=
idx_up
[
idim
]
*
Coefficients
{}[
idim
];
});
for
(
index_t
i
=
0
;
i
<
nDimUp
;
++
i
)
{
idx_low
(
0
)
+=
idx_up
[
i
]
*
Coefficients
{}[
i
];
}
return
idx_low
;
}
...
...
@@ -358,8 +360,10 @@ struct Embed
{
LowerIndex
idx_low_diff
{
0
};
static_for
<
0
,
nDimUp
,
1
>
{}(
[
&
](
auto
idim
)
{
idx_low_diff
(
0
)
+=
idx_up_diff
[
idim
]
*
Coefficients
{}[
idim
];
});
for
(
index_t
i
=
0
;
i
<
nDimUp
;
++
i
)
{
idx_low_diff
(
0
)
+=
idx_up_diff
[
i
]
*
Coefficients
{}[
i
];
}
return
idx_low_diff
;
}
...
...
@@ -368,7 +372,37 @@ struct Embed
__host__
__device__
static
constexpr
bool
IsValidUpperIndexAlwaysMappedToValidLowerIndex
()
{
return
IsAlwaysValidMapping
;
bool
flag
=
true
;
index_t
ncorner
=
1
;
for
(
index_t
idim
=
0
;
idim
<
nDimUp
;
++
idim
)
{
ncorner
*=
2
;
}
// loop over each corner of the upper tensor
for
(
index_t
icorner
=
0
;
icorner
<
ncorner
;
++
icorner
)
{
// generate upper index for each corner
auto
idx_up
=
make_zero_array
<
index_t
,
nDimUp
>
();
index_t
itmp
=
icorner
;
for
(
index_t
idim
=
nDimUp
-
1
;
idim
>=
0
;
--
idim
)
{
idx_up
(
idim
)
=
itmp
%
2
==
0
?
0
:
UpperLengths
::
At
(
idim
)
-
1
;
itmp
/=
2
;
}
// calculate lower index
auto
idx_low
=
CalculateLowerIndex
(
idx_up
);
// judge if lower index is valid
flag
=
flag
&&
idx_low
[
0
]
>=
0
&&
idx_low
[
0
]
<
LowerLength
;
}
return
flag
;
}
};
...
...
composable_kernel/include/tensor_description/tensor_descriptor.hpp
View file @
039b5e7e
...
...
@@ -498,7 +498,10 @@ struct TransformedTensorDescriptor
constexpr
auto
tran
=
Transforms
{}.
At
(
itran
);
// check a transformation if it does not always have a valid mapping
if
(
!
tran
.
IsValidUpperIndexAlwaysMappedToValidLowerIndex
())
constexpr
bool
is_valid_up_always_mapped_to_valid_low
=
decltype
(
tran
)
::
IsValidUpperIndexAlwaysMappedToValidLowerIndex
();
if
(
!
is_valid_up_always_mapped_to_valid_low
)
{
constexpr
auto
low_dims_part
=
LowDimensionIds
{}.
At
(
itran
);
constexpr
auto
low_lengths_part
=
...
...
driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
View file @
039b5e7e
...
...
@@ -81,6 +81,37 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
1
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
1
;
#elif 1
// BlockSize = 256, each thread hold 64 data
// for 1x1 weight, 8x8 input
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmThreadGemmDataPerReadM
=
4
;
constexpr
index_t
GemmThreadGemmDataPerReadN
=
4
;
using
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
=
Sequence
<
1
,
4
>
;
using
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
=
Sequence
<
8
,
32
>
;
constexpr
index_t
GemmABlockCopySrcDataPerRead_GemmM
=
4
;
constexpr
index_t
GemmABlockCopyDstDataPerWrite_GemmM
=
4
;
using
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
=
Sequence
<
1
,
4
>
;
using
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
=
Sequence
<
8
,
32
>
;
constexpr
index_t
GemmBBlockCopySrcDataPerRead_GemmN
=
4
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
4
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
4
;
#endif
constexpr
index_t
hcf_stride_dilation_h
=
math
::
hcf
(
ConvStrideH
,
ConvDilationH
);
...
...
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @
039b5e7e
...
...
@@ -53,7 +53,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
wei_kcyx_device_buf
.
ToDevice
(
wei_kcyx
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
#if
1
#if
0
// BlockSize = 256, GemmKPerBlock = 8
constexpr index_t BlockSize = 256;
...
...
@@ -84,7 +84,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif
0
#elif
1
// BlockSize = 256, GemmKPerBlock = 8
// 1x1 filter, 8x8 image
constexpr
index_t
BlockSize
=
256
;
...
...
driver/src/conv_bwd_data_driver.cpp
View file @
039b5e7e
...
...
@@ -13,20 +13,20 @@
#include "device_tensor.hpp"
#include "conv_common.hpp"
#include "host_conv_bwd_data.hpp"
#include "device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp"
//
#include "device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp"
//
#include "device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp"
int
main
(
int
argc
,
char
*
argv
[])
{
using
namespace
ck
;
#if
1
#if
0
constexpr index_t N = 128;
constexpr
index_t
C
=
256
;
constexpr
index_t
HI
=
3
5
;
constexpr
index_t
WI
=
3
5
;
constexpr
index_t
K
=
384
;
constexpr index_t C =
128
;
constexpr index_t HI = 5;
constexpr index_t WI = 5;
constexpr index_t K =
8
;
constexpr index_t Y = 3;
constexpr index_t X = 3;
...
...
@@ -50,7 +50,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif
0
#elif
1
// 1x7
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1024
;
...
...
@@ -260,7 +260,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif
1
#elif
0
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
288
;
...
...
driver/src/conv_driver.cpp
View file @
039b5e7e
...
...
@@ -44,7 +44,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif
0
#elif
1
// 1x7
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1024
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment