Commit 9c4573be authored by Chao Liu's avatar Chao Liu
Browse files

debugged bwd data v2r1

parent 7f679044
...@@ -91,9 +91,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw ...@@ -91,9 +91,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda); constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
constexpr index_t Htilda = constexpr index_t Htilda =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - (Y % Ytilda)), ConvStrideH); Ho + math::integer_divide_ceil(ConvDilationH * (Y - (Y % Ytilda) - 1), ConvStrideH);
constexpr index_t Wtilda = constexpr index_t Wtilda =
Wo + math::integer_divide_ceil(ConvDilationW * (X - (X % Xtilda)), ConvStrideW); Wo + math::integer_divide_ceil(ConvDilationW * (X - (X % Xtilda) - 1), ConvStrideW);
// weight tensor // weight tensor
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor( constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
...@@ -166,7 +166,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw ...@@ -166,7 +166,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
#if 1 // debug #if 1 // debug
constexpr bool in_skip_all_out_of_bound_check = false; constexpr bool in_skip_all_out_of_bound_check = false;
#else // doesn't produce correct result for stride=2 dilation=1 #else
constexpr bool in_skip_all_out_of_bound_check = true; constexpr bool in_skip_all_out_of_bound_check = true;
#endif #endif
......
...@@ -158,9 +158,9 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i ...@@ -158,9 +158,9 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda); constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
constexpr index_t Htilda = constexpr index_t Htilda =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - (Y % Ytilda)), ConvStrideH); Ho + math::integer_divide_ceil(ConvDilationH * (Y - (Y % Ytilda) - 1), ConvStrideH);
constexpr index_t Wtilda = constexpr index_t Wtilda =
Wo + math::integer_divide_ceil(ConvDilationW * (X - (X % Xtilda)), ConvStrideW); Wo + math::integer_divide_ceil(ConvDilationW * (X - (X % Xtilda) - 1), ConvStrideW);
constexpr index_t HtildaLeft = math::integer_divide_floor( constexpr index_t HtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]); math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment