Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
6904d163
Commit
6904d163
authored
Dec 25, 2019
by
Chao Liu
Browse files
debugging bwd data v2r1
parent
3f81301f
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
63 additions
and
95 deletions
+63
-95
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
+17
-60
driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
+6
-11
driver/src/conv_bwd_data_driver.cpp
driver/src/conv_bwd_data_driver.cpp
+40
-24
No files found.
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
View file @
6904d163
...
@@ -114,55 +114,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
...
@@ -114,55 +114,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
#if 0 // debug
#if 0 // debug
// output tensor
constexpr index_t HtildaLeft = 0;
constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
constexpr index_t WtildaLeft = 0;
out_n_k_ho_wo_global_desc,
constexpr index_t HtildaRight = Htilda;
make_tuple(PassThrough<N>{},
constexpr index_t WtildaRight = Wtilda;
PassThrough<K>{},
#else
// doesn't produce correct result for stride=2 dilation=3
Embed<Ho,
Sequence<Ydot, Htilda>,
Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>>{},
Embed<Wo,
Sequence<Xdot, Wtilda>,
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
make_tuple(Merge<Sequence<K, Ydot, Xdot>>{}, Merge<Sequence<N, Htilda, Wtilda>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// input tensor
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(
PassThrough<N>{},
PassThrough<C>{},
Embed<Hip, Sequence<Ytilda, Htilda>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Wip, Sequence<Xtilda, Wtilda>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
make_tuple(Merge<Sequence<C, Ytilda, Xtilda>>{}, Merge<Sequence<N, Htilda, Wtilda>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
#else
#if 1
constexpr
index_t
HtildaLeft
=
constexpr
index_t
HtildaLeft
=
math
::
integer_divide_floor
(
InLeftPads
{}[
0
],
ConvStrides
{}[
0
]);
math
::
integer_divide_floor
(
InLeftPads
{}[
0
],
ConvStrides
{}[
0
]);
constexpr
index_t
WtildaLeft
=
constexpr
index_t
WtildaLeft
=
...
@@ -176,11 +132,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
...
@@ -176,11 +132,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
math
::
integer_divide_ceil
(
InLeftPads
{}[
1
]
+
Wi
-
1
-
ConvDilations
{}[
1
]
*
(
Xtilda
-
1
),
math
::
integer_divide_ceil
(
InLeftPads
{}[
1
]
+
Wi
-
1
-
ConvDilations
{}[
1
]
*
(
Xtilda
-
1
),
ConvStrides
{}[
1
])
+
ConvStrides
{}[
1
])
+
1
;
1
;
#else
constexpr
index_t
HtildaLeft
=
0
;
constexpr
index_t
WtildaLeft
=
0
;
constexpr
index_t
HtildaRight
=
Htilda
;
constexpr
index_t
WtildaRight
=
Wtilda
;
#endif
#endif
constexpr
index_t
HtildaTrim
=
HtildaRight
-
HtildaLeft
;
constexpr
index_t
HtildaTrim
=
HtildaRight
-
HtildaLeft
;
...
@@ -222,12 +173,19 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
...
@@ -222,12 +173,19 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
#if 0 // debug
constexpr bool in_skip_all_out_of_bound_check = false;
#else
// doesn't produce correct result for stride=2 dilation=1
constexpr
bool
in_skip_all_out_of_bound_check
=
true
;
#endif
// input tensor
// input tensor
constexpr
auto
in_n_c_hip_wip_global_desc
=
transform_tensor_descriptor
(
constexpr
auto
in_n_c_hip_wip_global_desc
=
transform_tensor_descriptor
(
in_n_c_hi_wi_global_desc
,
in_n_c_hi_wi_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
PassThrough
<
C
>
{},
Pad
<
Sequence
<
Hi
,
Wi
>
,
InLeftPads
,
InRightPads
,
true
>
{}),
Pad
<
Sequence
<
Hi
,
Wi
>
,
InLeftPads
,
InRightPads
,
in_skip_all_out_of_bound_check
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}));
...
@@ -241,11 +199,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
...
@@ -241,11 +199,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
Embed
<
Hip
,
Embed
<
Hip
,
Sequence
<
Ytilda
,
Htilda
>
,
Sequence
<
Ytilda
,
Htilda
>
,
Sequence
<
ConvDilationH
,
ConvStrideH
,
0
>
,
Sequence
<
ConvDilationH
,
ConvStrideH
,
0
>
,
true
>
{},
in_skip_all_out_of_bound_check
>
{},
Embed
<
Wip
,
Embed
<
Wip
,
Sequence
<
Xtilda
,
Wtilda
>
,
Sequence
<
Xtilda
,
Wtilda
>
,
Sequence
<
ConvDilationW
,
ConvStrideW
,
0
>
,
Sequence
<
ConvDilationW
,
ConvStrideW
,
0
>
,
true
>
{}),
in_skip_all_out_of_bound_check
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
...
@@ -270,7 +228,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
...
@@ -270,7 +228,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
Merge
<
Sequence
<
N
,
HtildaTrim
,
WtildaTrim
>>
{}),
Merge
<
Sequence
<
N
,
HtildaTrim
,
WtildaTrim
>>
{}),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
#endif
// GEMM
// GEMM
constexpr
auto
gridwise_gemm
=
constexpr
auto
gridwise_gemm
=
...
...
driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
View file @
6904d163
...
@@ -115,7 +115,7 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
...
@@ -115,7 +115,7 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
1
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
1
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
1
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
1
;
#elif
0
#elif
1
// BlockSize = 256, each thread hold 64 data
// BlockSize = 256, each thread hold 64 data
// for 1x1 weight, 8x8 input
// for 1x1 weight, 8x8 input
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
...
@@ -161,10 +161,11 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
...
@@ -161,10 +161,11 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr
index_t
Wtilda
=
Wo
+
(
ConvDilationW
/
hcf_stride_dilation_w
)
*
(
X
-
Xtilda
);
constexpr
index_t
Wtilda
=
Wo
+
(
ConvDilationW
/
hcf_stride_dilation_w
)
*
(
X
-
Xtilda
);
#if 0 // debug
#if 0 // debug
constexpr index_t GemmM = C * Ytilda * Xtilda;
constexpr index_t HtildaLeft = 0;
constexpr index_t GemmN = N * Htilda * Wtilda;
constexpr index_t WtildaLeft = 0;
#else
constexpr index_t HtildaRight = Htilda;
#if 1
constexpr index_t WtildaRight = Wtilda;
#else
// doesn't produce correct result for stride=2 dilation=3
constexpr
index_t
HtildaLeft
=
math
::
integer_divide_floor
(
InLeftPads
{}[
0
],
ConvStrides
{}[
0
]);
constexpr
index_t
HtildaLeft
=
math
::
integer_divide_floor
(
InLeftPads
{}[
0
],
ConvStrides
{}[
0
]);
constexpr
index_t
WtildaLeft
=
math
::
integer_divide_floor
(
InLeftPads
{}[
1
],
ConvStrides
{}[
1
]);
constexpr
index_t
WtildaLeft
=
math
::
integer_divide_floor
(
InLeftPads
{}[
1
],
ConvStrides
{}[
1
]);
...
@@ -176,18 +177,12 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
...
@@ -176,18 +177,12 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
math
::
integer_divide_ceil
(
InLeftPads
{}[
1
]
+
Wi
-
1
-
ConvDilations
{}[
1
]
*
(
Xtilda
-
1
),
math
::
integer_divide_ceil
(
InLeftPads
{}[
1
]
+
Wi
-
1
-
ConvDilations
{}[
1
]
*
(
Xtilda
-
1
),
ConvStrides
{}[
1
])
+
ConvStrides
{}[
1
])
+
1
;
1
;
#else
constexpr
index_t
HtildaLeft
=
0
;
constexpr
index_t
WtildaLeft
=
0
;
constexpr
index_t
HtildaRight
=
Htilda
;
constexpr
index_t
WtildaRight
=
Wtilda
;
#endif
#endif
constexpr
index_t
HtildaTrim
=
HtildaRight
-
HtildaLeft
;
constexpr
index_t
HtildaTrim
=
HtildaRight
-
HtildaLeft
;
constexpr
index_t
WtildaTrim
=
WtildaRight
-
WtildaLeft
;
constexpr
index_t
WtildaTrim
=
WtildaRight
-
WtildaLeft
;
constexpr
index_t
GemmM
=
C
*
Ytilda
*
Xtilda
;
constexpr
index_t
GemmM
=
C
*
Ytilda
*
Xtilda
;
constexpr
index_t
GemmN
=
N
*
HtildaTrim
*
WtildaTrim
;
constexpr
index_t
GemmN
=
N
*
HtildaTrim
*
WtildaTrim
;
#endif
constexpr
index_t
GridSize
=
math
::
integer_divide_ceil
(
GemmM
,
GemmMPerBlock
)
*
constexpr
index_t
GridSize
=
math
::
integer_divide_ceil
(
GemmM
,
GemmMPerBlock
)
*
math
::
integer_divide_ceil
(
GemmN
,
GemmNPerBlock
);
math
::
integer_divide_ceil
(
GemmN
,
GemmNPerBlock
);
...
...
driver/src/conv_bwd_data_driver.cpp
View file @
6904d163
...
@@ -21,13 +21,28 @@ int main(int argc, char* argv[])
...
@@ -21,13 +21,28 @@ int main(int argc, char* argv[])
{
{
using
namespace
ck
;
using
namespace
ck
;
#if 0
#if 1
// 3x3 filter, 2x2 stride, 35x35 input
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
128
;
constexpr
index_t
HI
=
35
;
constexpr
index_t
WI
=
35
;
constexpr
index_t
K
=
128
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
2
,
2
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 3x3, 34x34
// 3x3, 34x34
constexpr index_t N =
64
;
constexpr
index_t
N
=
128
;
constexpr index_t C =
256
;
constexpr
index_t
C
=
128
;
constexpr
index_t
HI
=
34
;
constexpr
index_t
HI
=
34
;
constexpr
index_t
WI
=
34
;
constexpr
index_t
WI
=
34
;
constexpr index_t K =
256
;
constexpr
index_t
K
=
128
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
constexpr
index_t
X
=
3
;
...
@@ -38,25 +53,26 @@ int main(int argc, char* argv[])
...
@@ -38,25 +53,26 @@ int main(int argc, char* argv[])
using
RightPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
#elif 0
// 3x3, 28x28
// 3x3, 28x28
constexpr
index_t
N
=
64
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
256
;
constexpr
index_t
C
=
128
;
constexpr
index_t
HI
=
28
;
constexpr
index_t
HI
=
28
;
constexpr
index_t
WI
=
28
;
constexpr
index_t
WI
=
28
;
constexpr
index_t
K
=
256
;
constexpr
index_t
K
=
128
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
using
RightPads
=
Sequence
<
0
,
0
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
#elif 0
// 1x1 filter, 8x8 image
// 1x1 filter, 8x8 image
constexpr
index_t
N
=
256
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1
024
;
constexpr
index_t
C
=
1
28
;
constexpr
index_t
HI
=
8
;
constexpr
index_t
HI
=
8
;
constexpr
index_t
WI
=
8
;
constexpr
index_t
WI
=
8
;
constexpr
index_t
K
=
1
024
;
constexpr
index_t
K
=
1
28
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
constexpr
index_t
X
=
1
;
...
@@ -68,10 +84,10 @@ int main(int argc, char* argv[])
...
@@ -68,10 +84,10 @@ int main(int argc, char* argv[])
#elif 0
#elif 0
// 1x1 filter, 7x7 image
// 1x1 filter, 7x7 image
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1
024
;
constexpr
index_t
C
=
1
28
;
constexpr
index_t
HI
=
7
;
constexpr
index_t
HI
=
7
;
constexpr
index_t
WI
=
7
;
constexpr
index_t
WI
=
7
;
constexpr
index_t
K
=
1
024
;
constexpr
index_t
K
=
1
28
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
constexpr
index_t
X
=
1
;
...
@@ -98,7 +114,7 @@ int main(int argc, char* argv[])
...
@@ -98,7 +114,7 @@ int main(int argc, char* argv[])
#elif 0
#elif 0
// 1x1 filter, 28x28 image
// 1x1 filter, 28x28 image
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
256
;
constexpr
index_t
C
=
128
;
constexpr
index_t
HI
=
28
;
constexpr
index_t
HI
=
28
;
constexpr
index_t
WI
=
28
;
constexpr
index_t
WI
=
28
;
constexpr
index_t
K
=
128
;
constexpr
index_t
K
=
128
;
...
@@ -113,10 +129,10 @@ int main(int argc, char* argv[])
...
@@ -113,10 +129,10 @@ int main(int argc, char* argv[])
#elif 0
#elif 0
// 1x1 filter, 17x17 input
// 1x1 filter, 17x17 input
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1
024
;
constexpr
index_t
C
=
1
28
;
constexpr
index_t
HI
=
17
;
constexpr
index_t
HI
=
17
;
constexpr
index_t
WI
=
17
;
constexpr
index_t
WI
=
17
;
constexpr
index_t
K
=
1
024
;
constexpr
index_t
K
=
1
28
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
constexpr
index_t
X
=
1
;
...
@@ -128,7 +144,7 @@ int main(int argc, char* argv[])
...
@@ -128,7 +144,7 @@ int main(int argc, char* argv[])
#elif 0
#elif 0
// 5x5 filter, 2x2 pad, 7x7 input
// 5x5 filter, 2x2 pad, 7x7 input
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
4
8
;
constexpr
index_t
C
=
12
8
;
constexpr
index_t
HI
=
7
;
constexpr
index_t
HI
=
7
;
constexpr
index_t
WI
=
7
;
constexpr
index_t
WI
=
7
;
constexpr
index_t
K
=
128
;
constexpr
index_t
K
=
128
;
...
@@ -143,10 +159,10 @@ int main(int argc, char* argv[])
...
@@ -143,10 +159,10 @@ int main(int argc, char* argv[])
#elif 0
#elif 0
// 1x7 filter, 0x3 pad, 17x17 input
// 1x7 filter, 0x3 pad, 17x17 input
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1
024
;
constexpr
index_t
C
=
1
28
;
constexpr
index_t
HI
=
17
;
constexpr
index_t
HI
=
17
;
constexpr
index_t
WI
=
17
;
constexpr
index_t
WI
=
17
;
constexpr
index_t
K
=
1
024
;
constexpr
index_t
K
=
1
28
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
7
;
constexpr
index_t
X
=
7
;
...
@@ -155,13 +171,13 @@ int main(int argc, char* argv[])
...
@@ -155,13 +171,13 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
0
,
3
>
;
using
LeftPads
=
Sequence
<
0
,
3
>
;
using
RightPads
=
Sequence
<
0
,
3
>
;
using
RightPads
=
Sequence
<
0
,
3
>
;
#elif
0
#elif
1
// 7x1 filter, 3x0 pad, 17x17 input
// 7x1 filter, 3x0 pad, 17x17 input
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1
024
;
constexpr
index_t
C
=
1
28
;
constexpr
index_t
HI
=
17
;
constexpr
index_t
HI
=
17
;
constexpr
index_t
WI
=
17
;
constexpr
index_t
WI
=
17
;
constexpr
index_t
K
=
1
024
;
constexpr
index_t
K
=
1
28
;
constexpr
index_t
Y
=
7
;
constexpr
index_t
Y
=
7
;
constexpr
index_t
X
=
1
;
constexpr
index_t
X
=
1
;
...
@@ -173,10 +189,10 @@ int main(int argc, char* argv[])
...
@@ -173,10 +189,10 @@ int main(int argc, char* argv[])
#elif 1
#elif 1
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1
024
;
constexpr
index_t
C
=
1
28
;
constexpr
index_t
HI
=
35
;
constexpr
index_t
HI
=
35
;
constexpr
index_t
WI
=
35
;
constexpr
index_t
WI
=
35
;
constexpr
index_t
K
=
1
024
;
constexpr
index_t
K
=
1
28
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
constexpr
index_t
X
=
3
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment