Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
6b165b9b
Commit
6b165b9b
authored
Jul 31, 2020
by
Chao Liu
Browse files
refactor
parent
7f9e59a1
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
54 additions
and
61 deletions
+54
-61
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
+50
-57
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
.../device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
+2
-2
driver/src/conv_bwd_data_driver.cpp
driver/src/conv_bwd_data_driver.cpp
+2
-2
No files found.
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
View file @
6b165b9b
...
@@ -167,9 +167,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
...
@@ -167,9 +167,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
constexpr
index_t
ConvDilationH
=
ConvDilations
{}[
0
];
constexpr
index_t
ConvDilationH
=
ConvDilations
{}[
0
];
constexpr
index_t
ConvDilationW
=
ConvDilations
{}[
1
];
constexpr
index_t
ConvDilationW
=
ConvDilations
{}[
1
];
//\todo static_assert for global vector load/store
// statc_assert();
constexpr
index_t
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
constexpr
index_t
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
constexpr
index_t
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
constexpr
index_t
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
...
@@ -179,6 +176,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
...
@@ -179,6 +176,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
constexpr
index_t
YDot
=
math
::
integer_divide_ceil
(
Y
,
YTilda
);
constexpr
index_t
YDot
=
math
::
integer_divide_ceil
(
Y
,
YTilda
);
constexpr
index_t
XDot
=
math
::
integer_divide_ceil
(
X
,
XTilda
);
constexpr
index_t
XDot
=
math
::
integer_divide_ceil
(
X
,
XTilda
);
constexpr
index_t
YDotSlice
=
(
iYTilda
+
1
)
*
YDot
<=
Y
?
YDot
:
Y
%
YDot
;
constexpr
index_t
XDotSlice
=
(
iXTilda
+
1
)
*
XDot
<=
X
?
XDot
:
X
%
XDot
;
constexpr
index_t
HTilda
=
constexpr
index_t
HTilda
=
Ho
+
math
::
integer_divide_ceil
(
ConvDilationH
*
(
Y
-
1
),
ConvStrideH
);
Ho
+
math
::
integer_divide_ceil
(
ConvDilationH
*
(
Y
-
1
),
ConvStrideH
);
constexpr
index_t
WTilda
=
constexpr
index_t
WTilda
=
...
@@ -198,10 +198,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
...
@@ -198,10 +198,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
constexpr
index_t
HTildaSlice
=
iHTildaRight
-
iHTildaLeft
;
constexpr
index_t
HTildaSlice
=
iHTildaRight
-
iHTildaLeft
;
constexpr
index_t
WTildaSlice
=
iWTildaRight
-
iWTildaLeft
;
constexpr
index_t
WTildaSlice
=
iWTildaRight
-
iWTildaLeft
;
// A matrix: weight
// weight out-of-bound check can be skipped
// weight out-of-bound check can be skipped
constexpr
bool
wei_skip_out_of_bound_check
=
true
;
constexpr
bool
wei_skip_out_of_bound_check
=
true
;
// weight tensor
constexpr
auto
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc
=
transform_tensor_descriptor
(
constexpr
auto
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc
=
transform_tensor_descriptor
(
wei_k_c_y_x_global_desc
,
wei_k_c_y_x_global_desc
,
make_tuple
(
PassThrough
<
K
>
{},
make_tuple
(
PassThrough
<
K
>
{},
...
@@ -217,15 +217,34 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
...
@@ -217,15 +217,34 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
constexpr
auto
wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc
=
transform_tensor_descriptor
(
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc
,
make_tuple
(
PassThrough
<
K
>
{},
PassThrough
<
C
>
{},
Slice
<
Sequence
<
YDot
,
XDot
>
,
Sequence
<
0
,
0
>
,
Sequence
<
YDotSlice
,
XDotSlice
>>
{},
Slice
<
Sequence
<
YTilda
,
XTilda
>
,
Sequence
<
iYTilda
,
iXTilda
>
,
Sequence
<
iYTilda
+
1
,
iXTilda
+
1
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
3
,
5
>
{}));
constexpr
auto
wei_gemmk_gemmm_global_desc
=
transform_tensor_descriptor
(
wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc
,
make_tuple
(
Merge
<
Sequence
<
K
,
YDotSlice
,
XDotSlice
>>
{},
Merge
<
Sequence
<
C
,
1
,
1
>>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// B matrix: output tensor
// TODO sometimes output tensor out-of-bound check can be skipped, find out all such
// situations
#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK
#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK
constexpr
bool
out_skip_out_of_bound_check
=
false
;
constexpr
bool
out_skip_out_of_bound_check
=
false
;
#else
#else
//\todo sometimes output tensor out-of-bound check can be skipped, find out all such
// situations
constexpr
bool
out_skip_out_of_bound_check
=
true
;
constexpr
bool
out_skip_out_of_bound_check
=
true
;
#endif
#endif
// output tensor
constexpr
auto
out_n_k_ydot_htilda_xdot_wtilda_global_desc
=
transform_tensor_descriptor
(
constexpr
auto
out_n_k_ydot_htilda_xdot_wtilda_global_desc
=
transform_tensor_descriptor
(
out_n_k_ho_wo_global_desc
,
out_n_k_ho_wo_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
make_tuple
(
PassThrough
<
N
>
{},
...
@@ -256,14 +275,35 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
...
@@ -256,14 +275,35 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}));
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}));
constexpr
auto
out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc
=
transform_tensor_descriptor
(
out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
K
>
{},
PassThrough
<
HTildaSlice
>
{},
PassThrough
<
WTildaSlice
>
{},
Slice
<
Sequence
<
YDot
,
XDot
>
,
Sequence
<
0
,
0
>
,
Sequence
<
YDotSlice
,
XDotSlice
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}));
constexpr
auto
out_gemmk_gemmn_global_desc
=
transform_tensor_descriptor
(
out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc
,
make_tuple
(
Merge
<
Sequence
<
K
,
YDotSlice
,
XDotSlice
>>
{},
Merge
<
Sequence
<
N
,
HTildaSlice
,
WTildaSlice
>>
{}),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// C matrix: input tensor
// TODO sometimes input out-of-bound check can be skipped, find out all such situations
#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK
#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK
constexpr
bool
in_skip_out_of_bound_check
=
false
;
constexpr
bool
in_skip_out_of_bound_check
=
false
;
#else
#else
//\todo sometimes input out-of-bound check can be skipped, find out all such situations
constexpr
bool
in_skip_out_of_bound_check
=
true
;
constexpr
bool
in_skip_out_of_bound_check
=
true
;
#endif
#endif
// input tensor
constexpr
auto
in_n_c_hip_wip_global_desc
=
transform_tensor_descriptor
(
constexpr
auto
in_n_c_hip_wip_global_desc
=
transform_tensor_descriptor
(
in_n_c_hi_wi_global_desc
,
in_n_c_hi_wi_global_desc
,
make_tuple
(
make_tuple
(
...
@@ -306,53 +346,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
...
@@ -306,53 +346,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}));
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}));
// GEMM
constexpr
index_t
YDotSlice
=
(
iYTilda
+
1
)
*
YDot
<=
Y
?
YDot
:
Y
%
YDot
;
constexpr
index_t
XDotSlice
=
(
iXTilda
+
1
)
*
XDot
<=
X
?
XDot
:
X
%
XDot
;
// A matrix
constexpr
auto
wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc
=
transform_tensor_descriptor
(
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc
,
make_tuple
(
PassThrough
<
K
>
{},
PassThrough
<
C
>
{},
Slice
<
Sequence
<
YDot
,
XDot
>
,
Sequence
<
0
,
0
>
,
Sequence
<
YDotSlice
,
XDotSlice
>>
{},
Slice
<
Sequence
<
YTilda
,
XTilda
>
,
Sequence
<
iYTilda
,
iXTilda
>
,
Sequence
<
iYTilda
+
1
,
iXTilda
+
1
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
3
,
5
>
{}));
constexpr
auto
wei_gemmk_gemmm_global_desc
=
transform_tensor_descriptor
(
wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc
,
make_tuple
(
Merge
<
Sequence
<
K
,
YDotSlice
,
XDotSlice
>>
{},
Merge
<
Sequence
<
C
,
1
,
1
>>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// B matrix
constexpr
auto
out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc
=
transform_tensor_descriptor
(
out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
K
>
{},
PassThrough
<
HTildaSlice
>
{},
PassThrough
<
WTildaSlice
>
{},
Slice
<
Sequence
<
YDot
,
XDot
>
,
Sequence
<
0
,
0
>
,
Sequence
<
YDotSlice
,
XDotSlice
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}));
constexpr
auto
out_gemmk_gemmn_global_desc
=
transform_tensor_descriptor
(
out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc
,
make_tuple
(
Merge
<
Sequence
<
K
,
YDotSlice
,
XDotSlice
>>
{},
Merge
<
Sequence
<
N
,
HTildaSlice
,
WTildaSlice
>>
{}),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// C matrix
constexpr
auto
in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc
=
constexpr
auto
in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc
,
in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc
,
...
...
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
View file @
6b165b9b
...
@@ -133,7 +133,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
...
@@ -133,7 +133,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
2
;
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
2
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
#elif
1
#elif
0
// cdata = 64, BlockSize = 256, 128x128x8
// cdata = 64, BlockSize = 256, 128x128x8
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
...
@@ -172,7 +172,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
...
@@ -172,7 +172,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
4
;
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
4
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
#elif
0
#elif
1
// cdata = 64, BlockSize = 256, 128x128x16
// cdata = 64, BlockSize = 256, 128x128x16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
...
...
driver/src/conv_bwd_data_driver.cpp
View file @
6b165b9b
...
@@ -190,7 +190,7 @@ int main(int argc, char* argv[])
...
@@ -190,7 +190,7 @@ int main(int argc, char* argv[])
#elif 1
#elif 1
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1
28
;
constexpr
index_t
C
=
1
024
;
constexpr
index_t
HI
=
35
;
constexpr
index_t
HI
=
35
;
constexpr
index_t
WI
=
35
;
constexpr
index_t
WI
=
35
;
constexpr
index_t
K
=
1024
;
constexpr
index_t
K
=
1024
;
...
@@ -247,7 +247,7 @@ int main(int argc, char* argv[])
...
@@ -247,7 +247,7 @@ int main(int argc, char* argv[])
#if 0
#if 0
device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
#elif
1
#elif
0
device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
#elif 0
#elif 0
device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw
device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment