Commit 6904d163 authored by Chao Liu's avatar Chao Liu
Browse files

debugging bwd data v2r1

parent 3f81301f
...@@ -114,55 +114,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw ...@@ -114,55 +114,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
#if 0 // debug #if 0 // debug
// output tensor constexpr index_t HtildaLeft = 0;
constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor( constexpr index_t WtildaLeft = 0;
out_n_k_ho_wo_global_desc, constexpr index_t HtildaRight = Htilda;
make_tuple(PassThrough<N>{}, constexpr index_t WtildaRight = Wtilda;
PassThrough<K>{}, #else // doesn't produce correct result for stride=2 dilation=3
Embed<Ho,
Sequence<Ydot, Htilda>,
Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>>{},
Embed<Wo,
Sequence<Xdot, Wtilda>,
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
make_tuple(Merge<Sequence<K, Ydot, Xdot>>{}, Merge<Sequence<N, Htilda, Wtilda>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// input tensor
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(
PassThrough<N>{},
PassThrough<C>{},
Embed<Hip, Sequence<Ytilda, Htilda>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Wip, Sequence<Xtilda, Wtilda>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
make_tuple(Merge<Sequence<C, Ytilda, Xtilda>>{}, Merge<Sequence<N, Htilda, Wtilda>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
#else
#if 1
constexpr index_t HtildaLeft = constexpr index_t HtildaLeft =
math::integer_divide_floor(InLeftPads{}[0], ConvStrides{}[0]); math::integer_divide_floor(InLeftPads{}[0], ConvStrides{}[0]);
constexpr index_t WtildaLeft = constexpr index_t WtildaLeft =
...@@ -176,11 +132,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw ...@@ -176,11 +132,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1 - ConvDilations{}[1] * (Xtilda - 1), math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1 - ConvDilations{}[1] * (Xtilda - 1),
ConvStrides{}[1]) + ConvStrides{}[1]) +
1; 1;
#else
constexpr index_t HtildaLeft = 0;
constexpr index_t WtildaLeft = 0;
constexpr index_t HtildaRight = Htilda;
constexpr index_t WtildaRight = Wtilda;
#endif #endif
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft; constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
...@@ -222,12 +173,19 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw ...@@ -222,12 +173,19 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
#if 0 // debug
constexpr bool in_skip_all_out_of_bound_check = false;
#else // doesn't produce correct result for stride=2 dilation=1
constexpr bool in_skip_all_out_of_bound_check = true;
#endif
// input tensor // input tensor
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc, in_n_c_hi_wi_global_desc,
make_tuple(PassThrough<N>{}, make_tuple(
PassThrough<N>{},
PassThrough<C>{}, PassThrough<C>{},
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, true>{}), Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, in_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
...@@ -241,11 +199,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw ...@@ -241,11 +199,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
Embed<Hip, Embed<Hip,
Sequence<Ytilda, Htilda>, Sequence<Ytilda, Htilda>,
Sequence<ConvDilationH, ConvStrideH, 0>, Sequence<ConvDilationH, ConvStrideH, 0>,
true>{}, in_skip_all_out_of_bound_check>{},
Embed<Wip, Embed<Wip,
Sequence<Xtilda, Wtilda>, Sequence<Xtilda, Wtilda>,
Sequence<ConvDilationW, ConvStrideW, 0>, Sequence<ConvDilationW, ConvStrideW, 0>,
true>{}), in_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
...@@ -270,7 +228,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw ...@@ -270,7 +228,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}), Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
#endif
// GEMM // GEMM
constexpr auto gridwise_gemm = constexpr auto gridwise_gemm =
......
...@@ -115,7 +115,7 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i ...@@ -115,7 +115,7 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 0 #elif 1
// BlockSize = 256, each thread hold 64 data // BlockSize = 256, each thread hold 64 data
// for 1x1 weight, 8x8 input // for 1x1 weight, 8x8 input
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
...@@ -161,10 +161,11 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i ...@@ -161,10 +161,11 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t Wtilda = Wo + (ConvDilationW / hcf_stride_dilation_w) * (X - Xtilda); constexpr index_t Wtilda = Wo + (ConvDilationW / hcf_stride_dilation_w) * (X - Xtilda);
#if 0 // debug #if 0 // debug
constexpr index_t GemmM = C * Ytilda * Xtilda; constexpr index_t HtildaLeft = 0;
constexpr index_t GemmN = N * Htilda * Wtilda; constexpr index_t WtildaLeft = 0;
#else constexpr index_t HtildaRight = Htilda;
#if 1 constexpr index_t WtildaRight = Wtilda;
#else // doesn't produce correct result for stride=2 dilation=3
constexpr index_t HtildaLeft = math::integer_divide_floor(InLeftPads{}[0], ConvStrides{}[0]); constexpr index_t HtildaLeft = math::integer_divide_floor(InLeftPads{}[0], ConvStrides{}[0]);
constexpr index_t WtildaLeft = math::integer_divide_floor(InLeftPads{}[1], ConvStrides{}[1]); constexpr index_t WtildaLeft = math::integer_divide_floor(InLeftPads{}[1], ConvStrides{}[1]);
...@@ -176,18 +177,12 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i ...@@ -176,18 +177,12 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1 - ConvDilations{}[1] * (Xtilda - 1), math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1 - ConvDilations{}[1] * (Xtilda - 1),
ConvStrides{}[1]) + ConvStrides{}[1]) +
1; 1;
#else
constexpr index_t HtildaLeft = 0;
constexpr index_t WtildaLeft = 0;
constexpr index_t HtildaRight = Htilda;
constexpr index_t WtildaRight = Wtilda;
#endif #endif
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft; constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft; constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
constexpr index_t GemmM = C * Ytilda * Xtilda; constexpr index_t GemmM = C * Ytilda * Xtilda;
constexpr index_t GemmN = N * HtildaTrim * WtildaTrim; constexpr index_t GemmN = N * HtildaTrim * WtildaTrim;
#endif
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) * constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
math::integer_divide_ceil(GemmN, GemmNPerBlock); math::integer_divide_ceil(GemmN, GemmNPerBlock);
......
...@@ -21,13 +21,28 @@ int main(int argc, char* argv[]) ...@@ -21,13 +21,28 @@ int main(int argc, char* argv[])
{ {
using namespace ck; using namespace ck;
#if 0 #if 1
// 3x3 filter, 2x2 stride, 35x35 input
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<2, 2>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 34x34 // 3x3, 34x34
constexpr index_t N = 64; constexpr index_t N = 128;
constexpr index_t C = 256; constexpr index_t C = 128;
constexpr index_t HI = 34; constexpr index_t HI = 34;
constexpr index_t WI = 34; constexpr index_t WI = 34;
constexpr index_t K = 256; constexpr index_t K = 128;
constexpr index_t Y = 3; constexpr index_t Y = 3;
constexpr index_t X = 3; constexpr index_t X = 3;
...@@ -38,25 +53,26 @@ int main(int argc, char* argv[]) ...@@ -38,25 +53,26 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
// 3x3, 28x28 // 3x3, 28x28
constexpr index_t N = 64; constexpr index_t N = 128;
constexpr index_t C = 256; constexpr index_t C = 128;
constexpr index_t HI = 28; constexpr index_t HI = 28;
constexpr index_t WI = 28; constexpr index_t WI = 28;
constexpr index_t K = 256; constexpr index_t K = 128;
constexpr index_t Y = 3; constexpr index_t Y = 3;
constexpr index_t X = 3; constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>; using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0> using RightPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
// 1x1 filter, 8x8 image // 1x1 filter, 8x8 image
constexpr index_t N = 256; constexpr index_t N = 128;
constexpr index_t C = 1024; constexpr index_t C = 128;
constexpr index_t HI = 8; constexpr index_t HI = 8;
constexpr index_t WI = 8; constexpr index_t WI = 8;
constexpr index_t K = 1024; constexpr index_t K = 128;
constexpr index_t Y = 1; constexpr index_t Y = 1;
constexpr index_t X = 1; constexpr index_t X = 1;
...@@ -68,10 +84,10 @@ int main(int argc, char* argv[]) ...@@ -68,10 +84,10 @@ int main(int argc, char* argv[])
#elif 0 #elif 0
// 1x1 filter, 7x7 image // 1x1 filter, 7x7 image
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 1024; constexpr index_t C = 128;
constexpr index_t HI = 7; constexpr index_t HI = 7;
constexpr index_t WI = 7; constexpr index_t WI = 7;
constexpr index_t K = 1024; constexpr index_t K = 128;
constexpr index_t Y = 1; constexpr index_t Y = 1;
constexpr index_t X = 1; constexpr index_t X = 1;
...@@ -98,7 +114,7 @@ int main(int argc, char* argv[]) ...@@ -98,7 +114,7 @@ int main(int argc, char* argv[])
#elif 0 #elif 0
// 1x1 filter, 28x28 image // 1x1 filter, 28x28 image
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 256; constexpr index_t C = 128;
constexpr index_t HI = 28; constexpr index_t HI = 28;
constexpr index_t WI = 28; constexpr index_t WI = 28;
constexpr index_t K = 128; constexpr index_t K = 128;
...@@ -113,10 +129,10 @@ int main(int argc, char* argv[]) ...@@ -113,10 +129,10 @@ int main(int argc, char* argv[])
#elif 0 #elif 0
// 1x1 filter, 17x17 input // 1x1 filter, 17x17 input
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 1024; constexpr index_t C = 128;
constexpr index_t HI = 17; constexpr index_t HI = 17;
constexpr index_t WI = 17; constexpr index_t WI = 17;
constexpr index_t K = 1024; constexpr index_t K = 128;
constexpr index_t Y = 1; constexpr index_t Y = 1;
constexpr index_t X = 1; constexpr index_t X = 1;
...@@ -128,7 +144,7 @@ int main(int argc, char* argv[]) ...@@ -128,7 +144,7 @@ int main(int argc, char* argv[])
#elif 0 #elif 0
// 5x5 filter, 2x2 pad, 7x7 input // 5x5 filter, 2x2 pad, 7x7 input
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 48; constexpr index_t C = 128;
constexpr index_t HI = 7; constexpr index_t HI = 7;
constexpr index_t WI = 7; constexpr index_t WI = 7;
constexpr index_t K = 128; constexpr index_t K = 128;
...@@ -143,10 +159,10 @@ int main(int argc, char* argv[]) ...@@ -143,10 +159,10 @@ int main(int argc, char* argv[])
#elif 0 #elif 0
// 1x7 filter, 0x3 pad, 17x17 input // 1x7 filter, 0x3 pad, 17x17 input
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 1024; constexpr index_t C = 128;
constexpr index_t HI = 17; constexpr index_t HI = 17;
constexpr index_t WI = 17; constexpr index_t WI = 17;
constexpr index_t K = 1024; constexpr index_t K = 128;
constexpr index_t Y = 1; constexpr index_t Y = 1;
constexpr index_t X = 7; constexpr index_t X = 7;
...@@ -155,13 +171,13 @@ int main(int argc, char* argv[]) ...@@ -155,13 +171,13 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 3>; using LeftPads = Sequence<0, 3>;
using RightPads = Sequence<0, 3>; using RightPads = Sequence<0, 3>;
#elif 0 #elif 1
// 7x1 filter, 3x0 pad, 17x17 input // 7x1 filter, 3x0 pad, 17x17 input
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 1024; constexpr index_t C = 128;
constexpr index_t HI = 17; constexpr index_t HI = 17;
constexpr index_t WI = 17; constexpr index_t WI = 17;
constexpr index_t K = 1024; constexpr index_t K = 128;
constexpr index_t Y = 7; constexpr index_t Y = 7;
constexpr index_t X = 1; constexpr index_t X = 1;
...@@ -173,10 +189,10 @@ int main(int argc, char* argv[]) ...@@ -173,10 +189,10 @@ int main(int argc, char* argv[])
#elif 1 #elif 1
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 1024; constexpr index_t C = 128;
constexpr index_t HI = 35; constexpr index_t HI = 35;
constexpr index_t WI = 35; constexpr index_t WI = 35;
constexpr index_t K = 1024; constexpr index_t K = 128;
constexpr index_t Y = 3; constexpr index_t Y = 3;
constexpr index_t X = 3; constexpr index_t X = 3;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment