Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b8385cca
Commit
b8385cca
authored
Jan 15, 2020
by
Chao Liu
Browse files
change Trim to Slice
parent
0368045e
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
77 additions
and
80 deletions
+77
-80
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
+7
-7
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp
+21
-21
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
+18
-18
composable_kernel/include/tensor_description/multi_index_transform.hpp
...rnel/include/tensor_description/multi_index_transform.hpp
+19
-7
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
.../device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
+2
-2
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated.hpp
...volution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated.hpp
+2
-2
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
.../device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
+1
-1
driver/src/conv_bwd_data_driver.cpp
driver/src/conv_bwd_data_driver.cpp
+3
-3
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+4
-19
No files found.
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
View file @
b8385cca
...
@@ -9,7 +9,7 @@
...
@@ -9,7 +9,7 @@
namespace
ck
{
namespace
ck
{
// GemmM = C * Ytilda * Xtilda;
// GemmM = C * Ytilda * Xtilda;
// GemmN = N * Htilda * Wtilda;
// GemmN = N * Htilda
NonZero
* Wtilda
NonZero
;
// GemmK = K * Ydot * Xdot;
// GemmK = K * Ydot * Xdot;
template
<
index_t
GridSize
,
template
<
index_t
GridSize
,
index_t
BlockSize
,
index_t
BlockSize
,
...
@@ -149,9 +149,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
...
@@ -149,9 +149,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
PassThrough
<
K
>
{},
PassThrough
<
K
>
{},
PassThrough
<
Ytilda
>
{},
PassThrough
<
Ytilda
>
{},
PassThrough
<
Xtilda
>
{},
PassThrough
<
Xtilda
>
{},
Trim
<
Sequence
<
Htilda
,
Wtilda
>
,
Slice
<
Sequence
<
Htilda
,
Wtilda
>
,
Sequence
<
HtildaLeft
,
WtildaLeft
>
,
Sequence
<
HtildaLeft
,
WtildaLeft
>
,
Sequence
<
Htilda
-
HtildaRight
,
Wtilda
-
WtildaRight
>>
{}),
Sequence
<
HtildaRight
,
WtildaRight
>>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}),
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}),
make_tuple
(
make_tuple
(
...
@@ -205,9 +205,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
...
@@ -205,9 +205,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
PassThrough
<
C
>
{},
PassThrough
<
C
>
{},
PassThrough
<
Ytilda
>
{},
PassThrough
<
Ytilda
>
{},
PassThrough
<
Xtilda
>
{},
PassThrough
<
Xtilda
>
{},
Trim
<
Sequence
<
Htilda
,
Wtilda
>
,
Slice
<
Sequence
<
Htilda
,
Wtilda
>
,
Sequence
<
HtildaLeft
,
WtildaLeft
>
,
Sequence
<
HtildaLeft
,
WtildaLeft
>
,
Sequence
<
Htilda
-
HtildaRight
,
Wtilda
-
WtildaRight
>>
{}),
Sequence
<
HtildaRight
,
WtildaRight
>>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}),
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}),
make_tuple
(
make_tuple
(
...
...
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp
View file @
b8385cca
...
@@ -9,9 +9,9 @@
...
@@ -9,9 +9,9 @@
namespace
ck
{
namespace
ck
{
// Ytilda*Xtilda number of GEMMs
// Ytilda*Xtilda number of GEMMs
// GemmM = C
// GemmM = C
;
// GemmN = N * Htilda * Wtilda;
// GemmN = N * Htilda
NonZero
* Wtilda
NonZero
;
// GemmK = K *
slice(Ydot) * slice(Xdot)
;
// GemmK = K *
YdotNonZero * XdotNonZero
;
template
<
index_t
GridSize
,
template
<
index_t
GridSize
,
index_t
BlockSize
,
index_t
BlockSize
,
typename
Float
,
typename
Float
,
...
@@ -184,9 +184,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
...
@@ -184,9 +184,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
PassThrough
<
K
>
{},
PassThrough
<
K
>
{},
PassThrough
<
Ytilda
>
{},
PassThrough
<
Ytilda
>
{},
PassThrough
<
Xtilda
>
{},
PassThrough
<
Xtilda
>
{},
Trim
<
Sequence
<
Htilda
,
Wtilda
>
,
Slice
<
Sequence
<
Htilda
,
Wtilda
>
,
Sequence
<
HtildaLeft
,
WtildaLeft
>
,
Sequence
<
HtildaLeft
,
WtildaLeft
>
,
Sequence
<
Htilda
-
HtildaRight
,
Wtilda
-
WtildaRight
>>
{}),
Sequence
<
HtildaRight
,
WtildaRight
>>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}),
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}),
make_tuple
(
make_tuple
(
...
@@ -233,9 +233,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
...
@@ -233,9 +233,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
PassThrough
<
C
>
{},
PassThrough
<
C
>
{},
PassThrough
<
Ytilda
>
{},
PassThrough
<
Ytilda
>
{},
PassThrough
<
Xtilda
>
{},
PassThrough
<
Xtilda
>
{},
Trim
<
Sequence
<
Htilda
,
Wtilda
>
,
Slice
<
Sequence
<
Htilda
,
Wtilda
>
,
Sequence
<
HtildaLeft
,
WtildaLeft
>
,
Sequence
<
HtildaLeft
,
WtildaLeft
>
,
Sequence
<
Htilda
-
HtildaRight
,
Wtilda
-
WtildaRight
>>
{}),
Sequence
<
HtildaRight
,
WtildaRight
>>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}),
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}),
make_tuple
(
make_tuple
(
...
@@ -265,12 +265,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
...
@@ -265,12 +265,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc
,
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc
,
make_tuple
(
PassThrough
<
K
>
{},
make_tuple
(
PassThrough
<
K
>
{},
PassThrough
<
C
>
{},
PassThrough
<
C
>
{},
Trim
<
Sequence
<
Ydot
,
Xdot
>
,
Slice
<
Sequence
<
Ydot
,
Xdot
>
,
Sequence
<
0
,
0
>
,
Sequence
<
0
,
0
>
,
Sequence
<
Ydot
-
YdotNonZero
,
Xdot
-
XdotNonZero
>>
{},
Sequence
<
YdotNonZero
,
XdotNonZero
>>
{},
Trim
<
Sequence
<
Ytilda
,
Xtilda
>
,
Slice
<
Sequence
<
Ytilda
,
Xtilda
>
,
Sequence
<
ytilda
,
xtilda
>
,
Sequence
<
ytilda
,
xtilda
>
,
Sequence
<
Ytilda
-
ytilda
-
1
,
Xtilda
-
xtilda
-
1
>>
{}),
Sequence
<
ytilda
+
1
,
xtilda
+
1
>>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
3
,
5
>
{}),
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
3
,
5
>
{}),
make_tuple
(
make_tuple
(
...
@@ -291,9 +291,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
...
@@ -291,9 +291,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
PassThrough
<
K
>
{},
PassThrough
<
K
>
{},
PassThrough
<
HtildaTrim
>
{},
PassThrough
<
HtildaTrim
>
{},
PassThrough
<
WtildaTrim
>
{},
PassThrough
<
WtildaTrim
>
{},
Trim
<
Sequence
<
Ydot
,
Xdot
>
,
Slice
<
Sequence
<
Ydot
,
Xdot
>
,
Sequence
<
0
,
0
>
,
Sequence
<
0
,
0
>
,
Sequence
<
Ydot
-
YdotNonZero
,
Xdot
-
XdotNonZero
>>
{}),
Sequence
<
YdotNonZero
,
XdotNonZero
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
3
>
{},
...
@@ -320,9 +320,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
...
@@ -320,9 +320,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
PassThrough
<
C
>
{},
PassThrough
<
C
>
{},
PassThrough
<
HtildaTrim
>
{},
PassThrough
<
HtildaTrim
>
{},
PassThrough
<
WtildaTrim
>
{},
PassThrough
<
WtildaTrim
>
{},
Trim
<
Sequence
<
Ytilda
,
Xtilda
>
,
Slice
<
Sequence
<
Ytilda
,
Xtilda
>
,
Sequence
<
ytilda
,
xtilda
>
,
Sequence
<
ytilda
,
xtilda
>
,
Sequence
<
Ytilda
-
ytilda
-
1
,
Xtilda
-
xtilda
-
1
>>
{}),
Sequence
<
ytilda
+
1
,
xtilda
+
1
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
3
>
{},
...
...
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
View file @
b8385cca
...
@@ -157,9 +157,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
...
@@ -157,9 +157,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
PassThrough
<
K
>
{},
PassThrough
<
K
>
{},
PassThrough
<
Ytilda
>
{},
PassThrough
<
Ytilda
>
{},
PassThrough
<
Xtilda
>
{},
PassThrough
<
Xtilda
>
{},
Trim
<
Sequence
<
Htilda
,
Wtilda
>
,
Slice
<
Sequence
<
Htilda
,
Wtilda
>
,
Sequence
<
HtildaLeft
,
WtildaLeft
>
,
Sequence
<
HtildaLeft
,
WtildaLeft
>
,
Sequence
<
Htilda
-
HtildaRight
,
Wtilda
-
WtildaRight
>>
{}),
Sequence
<
HtildaRight
,
WtildaRight
>>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}),
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}),
make_tuple
(
make_tuple
(
...
@@ -206,9 +206,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
...
@@ -206,9 +206,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
PassThrough
<
C
>
{},
PassThrough
<
C
>
{},
PassThrough
<
Ytilda
>
{},
PassThrough
<
Ytilda
>
{},
PassThrough
<
Xtilda
>
{},
PassThrough
<
Xtilda
>
{},
Trim
<
Sequence
<
Htilda
,
Wtilda
>
,
Slice
<
Sequence
<
Htilda
,
Wtilda
>
,
Sequence
<
HtildaLeft
,
WtildaLeft
>
,
Sequence
<
HtildaLeft
,
WtildaLeft
>
,
Sequence
<
Htilda
-
HtildaRight
,
Wtilda
-
WtildaRight
>>
{}),
Sequence
<
HtildaRight
,
WtildaRight
>>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}),
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
5
>
{}),
make_tuple
(
make_tuple
(
...
@@ -227,12 +227,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
...
@@ -227,12 +227,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc
,
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc
,
make_tuple
(
PassThrough
<
K
>
{},
make_tuple
(
PassThrough
<
K
>
{},
PassThrough
<
C
>
{},
PassThrough
<
C
>
{},
Trim
<
Sequence
<
Ydot
,
Xdot
>
,
Slice
<
Sequence
<
Ydot
,
Xdot
>
,
Sequence
<
0
,
0
>
,
Sequence
<
0
,
0
>
,
Sequence
<
Ydot
-
YdotNonZero
,
Xdot
-
XdotNonZero
>>
{},
Sequence
<
YdotNonZero
,
XdotNonZero
>>
{},
Trim
<
Sequence
<
Ytilda
,
Xtilda
>
,
Slice
<
Sequence
<
Ytilda
,
Xtilda
>
,
Sequence
<
ytilda
,
xtilda
>
,
Sequence
<
ytilda
,
xtilda
>
,
Sequence
<
Ytilda
-
ytilda
-
1
,
Xtilda
-
xtilda
-
1
>>
{}),
Sequence
<
ytilda
+
1
,
xtilda
+
1
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
3
,
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
4
>
{},
Sequence
<
3
,
5
>
{}));
...
@@ -250,9 +250,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
...
@@ -250,9 +250,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
PassThrough
<
K
>
{},
PassThrough
<
K
>
{},
PassThrough
<
HtildaTrim
>
{},
PassThrough
<
HtildaTrim
>
{},
PassThrough
<
WtildaTrim
>
{},
PassThrough
<
WtildaTrim
>
{},
Trim
<
Sequence
<
Ydot
,
Xdot
>
,
Slice
<
Sequence
<
Ydot
,
Xdot
>
,
Sequence
<
0
,
0
>
,
Sequence
<
0
,
0
>
,
Sequence
<
Ydot
-
YdotNonZero
,
Xdot
-
XdotNonZero
>>
{}),
Sequence
<
YdotNonZero
,
XdotNonZero
>>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}),
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}),
make_tuple
(
make_tuple
(
...
@@ -272,9 +272,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
...
@@ -272,9 +272,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
PassThrough
<
C
>
{},
PassThrough
<
C
>
{},
PassThrough
<
HtildaTrim
>
{},
PassThrough
<
HtildaTrim
>
{},
PassThrough
<
WtildaTrim
>
{},
PassThrough
<
WtildaTrim
>
{},
Trim
<
Sequence
<
Ytilda
,
Xtilda
>
,
Slice
<
Sequence
<
Ytilda
,
Xtilda
>
,
Sequence
<
ytilda
,
xtilda
>
,
Sequence
<
ytilda
,
xtilda
>
,
Sequence
<
Ytilda
-
ytilda
-
1
,
Xtilda
-
xtilda
-
1
>>
{}),
Sequence
<
ytilda
+
1
,
xtilda
+
1
>>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}),
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
2
,
4
>
{}),
make_tuple
(
make_tuple
(
...
...
composable_kernel/include/tensor_description/multi_index_transform.hpp
View file @
b8385cca
...
@@ -110,19 +110,31 @@ struct Pad
...
@@ -110,19 +110,31 @@ struct Pad
};
};
// LowerLengths: Sequence<...>
// LowerLengths: Sequence<...>
template
<
typename
LowerLengths
,
typename
LeftTrims
,
typename
RightTrims
>
// SliceBegins: Sequence<...>
struct
Trim
// SliceEnds: Sequence<...>
template
<
typename
LowerLengths
,
typename
SliceBegins
,
typename
SliceEnds
>
struct
Slice
{
{
static
constexpr
index_t
nDim
=
LowerLengths
::
Size
();
static
constexpr
index_t
nDim
=
LowerLengths
::
Size
();
using
LowerIndex
=
MultiIndex
<
nDim
>
;
using
LowerIndex
=
MultiIndex
<
nDim
>
;
using
UpperIndex
=
MultiIndex
<
nDim
>
;
using
UpperIndex
=
MultiIndex
<
nDim
>
;
__host__
__device__
explicit
constexpr
Trim
()
__host__
__device__
explicit
constexpr
Slice
()
{
{
static_assert
(
LowerLengths
::
GetSize
()
==
nDim
&&
LeftTrim
s
::
GetSize
()
==
nDim
&&
static_assert
(
LowerLengths
::
GetSize
()
==
nDim
&&
SliceBegin
s
::
GetSize
()
==
nDim
&&
RightTrim
s
::
GetSize
()
==
nDim
,
SliceEnd
s
::
GetSize
()
==
nDim
,
"wrong! # of dimensions not consistent"
);
"wrong! # of dimensions not consistent"
);
#if 0
// TODO: would not compile, error on constexpr
static_for<0, nDim, 1>{}([&](auto idim) {
static_assert(SliceBegins::At(idim) <= SliceEnds::At(idim) &&
SliceBegins::At(idim) >= 0 &&
SliceEnds::At(idim) <= LowerLengths::At(idim),
"wrong! Slice config is wrong");
});
#endif
}
}
__host__
__device__
static
constexpr
auto
GetNumOfLowerDimension
()
{
return
Number
<
nDim
>
{};
}
__host__
__device__
static
constexpr
auto
GetNumOfLowerDimension
()
{
return
Number
<
nDim
>
{};
}
...
@@ -131,12 +143,12 @@ struct Trim
...
@@ -131,12 +143,12 @@ struct Trim
__host__
__device__
static
constexpr
auto
GetUpperLengths
()
__host__
__device__
static
constexpr
auto
GetUpperLengths
()
{
{
return
LowerLengths
{}
-
LeftTrims
{}
-
RightTrim
s
{};
return
SliceEnds
{}
-
SliceBegin
s
{};
}
}
__host__
__device__
static
constexpr
auto
CalculateLowerIndex
(
const
UpperIndex
&
idx_up
)
__host__
__device__
static
constexpr
auto
CalculateLowerIndex
(
const
UpperIndex
&
idx_up
)
{
{
return
idx_up
+
LeftTrim
s
{};
return
idx_up
+
SliceBegin
s
{};
}
}
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
...
...
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
View file @
b8385cca
...
@@ -54,7 +54,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
...
@@ -54,7 +54,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
wei_kcyx_device_buf
.
ToDevice
(
wei_kcyx
.
mData
.
data
());
wei_kcyx_device_buf
.
ToDevice
(
wei_kcyx
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
#if
0
#if
1
// BlockSize = 256, EperBlock = 8, each thread hold 64 data
// BlockSize = 256, EperBlock = 8, each thread hold 64 data
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
...
@@ -128,7 +128,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
...
@@ -128,7 +128,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
4
;
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
4
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
#elif
0
#elif
1
// BlockSize = 64, each thread hold 64 data
// BlockSize = 64, each thread hold 64 data
constexpr
index_t
BlockSize
=
64
;
constexpr
index_t
BlockSize
=
64
;
...
...
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated.hpp
View file @
b8385cca
...
@@ -46,7 +46,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(InDesc,
...
@@ -46,7 +46,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(InDesc,
wei_kcyx_device_buf
.
ToDevice
(
wei_kcyx
.
mData
.
data
());
wei_kcyx_device_buf
.
ToDevice
(
wei_kcyx
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
#if
1
#if
0
// BlockSize = 256, blockwise-GEMM 128x128, each thread hold 64 data
// BlockSize = 256, blockwise-GEMM 128x128, each thread hold 64 data
constexpr index_t BlockSize = 256;
constexpr index_t BlockSize = 256;
...
@@ -83,7 +83,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(InDesc,
...
@@ -83,7 +83,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(InDesc,
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif
1
#elif
0
// BlockSize = 256, EPerBlock = 16, each thread hold 64 data
// BlockSize = 256, EPerBlock = 16, each thread hold 64 data
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
...
...
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @
b8385cca
...
@@ -53,7 +53,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
...
@@ -53,7 +53,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
wei_kcyx_device_buf
.
ToDevice
(
wei_kcyx
.
mData
.
data
());
wei_kcyx_device_buf
.
ToDevice
(
wei_kcyx
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
#if
1
#if
0
// BlockSize = 256, GemmKPerBlock = 8
// BlockSize = 256, GemmKPerBlock = 8
constexpr index_t BlockSize = 256;
constexpr index_t BlockSize = 256;
...
...
driver/src/conv_bwd_data_driver.cpp
View file @
b8385cca
...
@@ -158,10 +158,10 @@ int main(int argc, char* argv[])
...
@@ -158,10 +158,10 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
2
,
2
>
;
using
LeftPads
=
Sequence
<
2
,
2
>
;
using
RightPads
=
Sequence
<
2
,
2
>
;
using
RightPads
=
Sequence
<
2
,
2
>
;
#elif
0
#elif
1
// 1x7 filter, 0x3 pad, 17x17 input
// 1x7 filter, 0x3 pad, 17x17 input
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1
024
;
constexpr
index_t
C
=
1
28
;
constexpr
index_t
HI
=
17
;
constexpr
index_t
HI
=
17
;
constexpr
index_t
WI
=
17
;
constexpr
index_t
WI
=
17
;
constexpr
index_t
K
=
128
;
constexpr
index_t
K
=
128
;
...
@@ -188,7 +188,7 @@ int main(int argc, char* argv[])
...
@@ -188,7 +188,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
3
,
0
>
;
using
LeftPads
=
Sequence
<
3
,
0
>
;
using
RightPads
=
Sequence
<
3
,
0
>
;
using
RightPads
=
Sequence
<
3
,
0
>
;
#elif
1
#elif
0
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1024
;
constexpr
index_t
C
=
1024
;
...
...
driver/src/conv_driver.cpp
View file @
b8385cca
...
@@ -87,21 +87,6 @@ int main(int argc, char* argv[])
...
@@ -87,21 +87,6 @@ int main(int argc, char* argv[])
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 3x3, 34x34
constexpr
index_t
N
=
64
;
constexpr
index_t
C
=
256
;
constexpr
index_t
HI
=
34
;
constexpr
index_t
WI
=
34
;
constexpr
index_t
K
=
128
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
#elif 0
...
@@ -296,7 +281,7 @@ int main(int argc, char* argv[])
...
@@ -296,7 +281,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif
1
#elif
0
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
// cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
// cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
...
@@ -327,7 +312,7 @@ int main(int argc, char* argv[])
...
@@ -327,7 +312,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
2
,
2
>
;
using
LeftPads
=
Sequence
<
2
,
2
>
;
using
RightPads
=
Sequence
<
2
,
2
>
;
using
RightPads
=
Sequence
<
2
,
2
>
;
#elif
1
#elif
0
// 1x7 filter, 0x3 pad, 17x17 input
// 1x7 filter, 0x3 pad, 17x17 input
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
128
;
constexpr
index_t
C
=
128
;
...
@@ -439,7 +424,7 @@ int main(int argc, char* argv[])
...
@@ -439,7 +424,7 @@ int main(int argc, char* argv[])
#elif 0
#elif 0
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw
(
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw
(
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_device
,
nrepeat
);
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_device
,
nrepeat
);
#elif
1
#elif
0
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated
(
in_nchw_desc
,
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated
(
in_nchw_desc
,
in_nchw
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx_desc
,
...
@@ -449,7 +434,7 @@ int main(int argc, char* argv[])
...
@@ -449,7 +434,7 @@ int main(int argc, char* argv[])
ConvStrides
{},
ConvStrides
{},
ConvDilations
{},
ConvDilations
{},
nrepeat
);
nrepeat
);
#elif
1
#elif
0
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw
(
in_nchw_desc
,
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw
(
in_nchw_desc
,
in_nchw
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx_desc
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment