Commit ea484457 authored by Chao Liu's avatar Chao Liu
Browse files

tweaking bwd data v3r1

parent 98716c83
...@@ -137,6 +137,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw ...@@ -137,6 +137,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft; constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft; constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
constexpr bool wei_skip_all_out_of_bound_check = true;
// weight tensor // weight tensor
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor( constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
wei_k_c_y_x_global_desc, wei_k_c_y_x_global_desc,
...@@ -145,14 +147,20 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw ...@@ -145,14 +147,20 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
Embed<Y, Embed<Y,
Sequence<Ydot, Ytilda>, Sequence<Ydot, Ytilda>,
Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>, Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>,
true>{}, wei_skip_all_out_of_bound_check>{},
Embed<X, Embed<X,
Sequence<Xdot, Xtilda>, Sequence<Xdot, Xtilda>,
Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>, Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>,
true>{}), wei_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
#if 1 // debug
constexpr bool out_skip_all_out_of_bound_check = false;
#else
constexpr bool out_skip_all_out_of_bound_check = true;
#endif
// output tensor // output tensor
constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor( constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
out_n_k_ho_wo_global_desc, out_n_k_ho_wo_global_desc,
...@@ -161,11 +169,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw ...@@ -161,11 +169,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
Embed<Ho, Embed<Ho,
Sequence<Ydot, Htilda>, Sequence<Ydot, Htilda>,
Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>, Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>,
false>{}, out_skip_all_out_of_bound_check>{},
Embed<Wo, Embed<Wo,
Sequence<Xdot, Wtilda>, Sequence<Xdot, Wtilda>,
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>, Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>,
false>{}), out_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
......
...@@ -22,7 +22,7 @@ int main(int argc, char* argv[]) ...@@ -22,7 +22,7 @@ int main(int argc, char* argv[])
{ {
using namespace ck; using namespace ck;
#if 1 #if 0
// 3x3 filter, 2x2 stride, 35x35 input // 3x3 filter, 2x2 stride, 35x35 input
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 1024; constexpr index_t C = 1024;
...@@ -39,11 +39,11 @@ int main(int argc, char* argv[]) ...@@ -39,11 +39,11 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
// 3x3, 34x34 // 3x3, 34x34
constexpr index_t N = 128; constexpr index_t N = 64;
constexpr index_t C = 128; constexpr index_t C = 256;
constexpr index_t HI = 34; constexpr index_t HI = 34;
constexpr index_t WI = 34; constexpr index_t WI = 34;
constexpr index_t K = 128; constexpr index_t K = 256;
constexpr index_t Y = 3; constexpr index_t Y = 3;
constexpr index_t X = 3; constexpr index_t X = 3;
...@@ -55,7 +55,7 @@ int main(int argc, char* argv[]) ...@@ -55,7 +55,7 @@ int main(int argc, char* argv[])
#elif 0 #elif 0
// 3x3, 28x28 // 3x3, 28x28
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 128; constexpr index_t C = 1024;
constexpr index_t HI = 28; constexpr index_t HI = 28;
constexpr index_t WI = 28; constexpr index_t WI = 28;
constexpr index_t K = 128; constexpr index_t K = 128;
...@@ -65,15 +65,15 @@ int main(int argc, char* argv[]) ...@@ -65,15 +65,15 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>; using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<1, 1>;
#elif 0 #elif 0
// 1x1 filter, 8x8 image // 1x1 filter, 8x8 image
constexpr index_t N = 128; constexpr index_t N = 256;
constexpr index_t C = 128; constexpr index_t C = 1024;
constexpr index_t HI = 8; constexpr index_t HI = 8;
constexpr index_t WI = 8; constexpr index_t WI = 8;
constexpr index_t K = 128; constexpr index_t K = 1024;
constexpr index_t Y = 1; constexpr index_t Y = 1;
constexpr index_t X = 1; constexpr index_t X = 1;
...@@ -85,7 +85,7 @@ int main(int argc, char* argv[]) ...@@ -85,7 +85,7 @@ int main(int argc, char* argv[])
#elif 0 #elif 0
// 1x1 filter, 7x7 image // 1x1 filter, 7x7 image
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 128; constexpr index_t C = 1024;
constexpr index_t HI = 7; constexpr index_t HI = 7;
constexpr index_t WI = 7; constexpr index_t WI = 7;
constexpr index_t K = 128; constexpr index_t K = 128;
...@@ -130,10 +130,10 @@ int main(int argc, char* argv[]) ...@@ -130,10 +130,10 @@ int main(int argc, char* argv[])
#elif 0 #elif 0
// 1x1 filter, 17x17 input // 1x1 filter, 17x17 input
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 128; constexpr index_t C = 1024;
constexpr index_t HI = 17; constexpr index_t HI = 17;
constexpr index_t WI = 17; constexpr index_t WI = 17;
constexpr index_t K = 128; constexpr index_t K = 1024;
constexpr index_t Y = 1; constexpr index_t Y = 1;
constexpr index_t X = 1; constexpr index_t X = 1;
...@@ -145,10 +145,10 @@ int main(int argc, char* argv[]) ...@@ -145,10 +145,10 @@ int main(int argc, char* argv[])
#elif 0 #elif 0
// 5x5 filter, 2x2 pad, 7x7 input // 5x5 filter, 2x2 pad, 7x7 input
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 128; constexpr index_t C = 1024;
constexpr index_t HI = 7; constexpr index_t HI = 7;
constexpr index_t WI = 7; constexpr index_t WI = 7;
constexpr index_t K = 128; constexpr index_t K = 1024;
constexpr index_t Y = 5; constexpr index_t Y = 5;
constexpr index_t X = 5; constexpr index_t X = 5;
...@@ -157,10 +157,10 @@ int main(int argc, char* argv[]) ...@@ -157,10 +157,10 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<2, 2>; using LeftPads = Sequence<2, 2>;
using RightPads = Sequence<2, 2>; using RightPads = Sequence<2, 2>;
#elif 1 #elif 0
// 1x7 filter, 0x3 pad, 17x17 input // 1x7 filter, 0x3 pad, 17x17 input
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 128; constexpr index_t C = 1024;
constexpr index_t HI = 17; constexpr index_t HI = 17;
constexpr index_t WI = 17; constexpr index_t WI = 17;
constexpr index_t K = 128; constexpr index_t K = 128;
...@@ -175,10 +175,10 @@ int main(int argc, char* argv[]) ...@@ -175,10 +175,10 @@ int main(int argc, char* argv[])
#elif 0 #elif 0
// 7x1 filter, 3x0 pad, 17x17 input // 7x1 filter, 3x0 pad, 17x17 input
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 128; constexpr index_t C = 1024;
constexpr index_t HI = 17; constexpr index_t HI = 17;
constexpr index_t WI = 17; constexpr index_t WI = 17;
constexpr index_t K = 128; constexpr index_t K = 1024;
constexpr index_t Y = 7; constexpr index_t Y = 7;
constexpr index_t X = 1; constexpr index_t X = 1;
...@@ -187,10 +187,10 @@ int main(int argc, char* argv[]) ...@@ -187,10 +187,10 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<3, 0>; using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>; using RightPads = Sequence<3, 0>;
#elif 0 #elif 1
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 128; constexpr index_t C = 1024;
constexpr index_t HI = 35; constexpr index_t HI = 35;
constexpr index_t WI = 35; constexpr index_t WI = 35;
constexpr index_t K = 128; constexpr index_t K = 128;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment