Commit 9818ea05 authored by Chao Liu's avatar Chao Liu
Browse files

debugged bwd data v2r1

parent 6904d163
...@@ -90,8 +90,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw ...@@ -90,8 +90,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda); constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda); constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
constexpr index_t Htilda = Ho + (ConvDilationH / hcf_stride_dilation_h) * (Y - Ytilda); constexpr index_t Htilda =
constexpr index_t Wtilda = Wo + (ConvDilationW / hcf_stride_dilation_w) * (X - Xtilda); Ho + math::integer_divide_ceil(ConvDilationH * (Y - (Y % Ytilda)), ConvStrideH);
constexpr index_t Wtilda =
Wo + math::integer_divide_ceil(ConvDilationW * (X - (X % Xtilda)), ConvStrideW);
// weight tensor // weight tensor
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor( constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
...@@ -113,26 +115,15 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw ...@@ -113,26 +115,15 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
#if 0 // debug
constexpr index_t HtildaLeft = 0;
constexpr index_t WtildaLeft = 0;
constexpr index_t HtildaRight = Htilda;
constexpr index_t WtildaRight = Wtilda;
#else // doesn't produce correct result for stride=2 dilation=3
constexpr index_t HtildaLeft = constexpr index_t HtildaLeft =
math::integer_divide_floor(InLeftPads{}[0], ConvStrides{}[0]); math::integer_divide_floor(InLeftPads{}[0], ConvStrides{}[0]);
constexpr index_t WtildaLeft = constexpr index_t WtildaLeft =
math::integer_divide_floor(InLeftPads{}[1], ConvStrides{}[1]); math::integer_divide_floor(InLeftPads{}[1], ConvStrides{}[1]);
constexpr index_t HtildaRight = constexpr index_t HtildaRight =
math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1 - ConvDilations{}[0] * (Ytilda - 1), math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1;
ConvStrides{}[0]) +
1;
constexpr index_t WtildaRight = constexpr index_t WtildaRight =
math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1 - ConvDilations{}[1] * (Xtilda - 1), math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1;
ConvStrides{}[1]) +
1;
#endif
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft; constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft; constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
...@@ -173,7 +164,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw ...@@ -173,7 +164,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
#if 0 // debug #if 1 // debug
constexpr bool in_skip_all_out_of_bound_check = false; constexpr bool in_skip_all_out_of_bound_check = false;
#else // doesn't produce correct result for stride=2 dilation=1 #else // doesn't produce correct result for stride=2 dilation=1
constexpr bool in_skip_all_out_of_bound_check = true; constexpr bool in_skip_all_out_of_bound_check = true;
......
...@@ -55,7 +55,7 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i ...@@ -55,7 +55,7 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 0 #if 1
// BlockSize = 256, each thread hold 64 data // BlockSize = 256, each thread hold 64 data
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
...@@ -157,27 +157,19 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i ...@@ -157,27 +157,19 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda); constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda); constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
constexpr index_t Htilda = Ho + (ConvDilationH / hcf_stride_dilation_h) * (Y - Ytilda); constexpr index_t Htilda =
constexpr index_t Wtilda = Wo + (ConvDilationW / hcf_stride_dilation_w) * (X - Xtilda); Ho + math::integer_divide_ceil(ConvDilationH * (Y - (Y % Ytilda)), ConvStrideH);
constexpr index_t Wtilda =
Wo + math::integer_divide_ceil(ConvDilationW * (X - (X % Xtilda)), ConvStrideW);
#if 0 // debug
constexpr index_t HtildaLeft = 0;
constexpr index_t WtildaLeft = 0;
constexpr index_t HtildaRight = Htilda;
constexpr index_t WtildaRight = Wtilda;
#else // doesn't produce correct result for stride=2 dilation=3
constexpr index_t HtildaLeft = math::integer_divide_floor(InLeftPads{}[0], ConvStrides{}[0]); constexpr index_t HtildaLeft = math::integer_divide_floor(InLeftPads{}[0], ConvStrides{}[0]);
constexpr index_t WtildaLeft = math::integer_divide_floor(InLeftPads{}[1], ConvStrides{}[1]); constexpr index_t WtildaLeft = math::integer_divide_floor(InLeftPads{}[1], ConvStrides{}[1]);
constexpr index_t HtildaRight = constexpr index_t HtildaRight =
math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1 - ConvDilations{}[0] * (Ytilda - 1), math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1;
ConvStrides{}[0]) +
1;
constexpr index_t WtildaRight = constexpr index_t WtildaRight =
math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1 - ConvDilations{}[1] * (Xtilda - 1), math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1;
ConvStrides{}[1]) +
1;
#endif
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft; constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft; constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
......
...@@ -31,7 +31,7 @@ int main(int argc, char* argv[]) ...@@ -31,7 +31,7 @@ int main(int argc, char* argv[])
constexpr index_t Y = 3; constexpr index_t Y = 3;
constexpr index_t X = 3; constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>; using ConvStrides = Sequence<3, 3>;
using ConvDilations = Sequence<2, 2>; using ConvDilations = Sequence<2, 2>;
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
...@@ -171,7 +171,7 @@ int main(int argc, char* argv[]) ...@@ -171,7 +171,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 3>; using LeftPads = Sequence<0, 3>;
using RightPads = Sequence<0, 3>; using RightPads = Sequence<0, 3>;
#elif 1 #elif 0
// 7x1 filter, 3x0 pad, 17x17 input // 7x1 filter, 3x0 pad, 17x17 input
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 128; constexpr index_t C = 128;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment