Commit 7e808fe1 authored by Chao Liu
Browse files

debugged bwd data v2r1

parent 9818ea05
......@@ -116,14 +116,18 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
make_tuple(Sequence<0>{}, Sequence<1>{}));
constexpr index_t HtildaLeft =
math::integer_divide_floor(InLeftPads{}[0], ConvStrides{}[0]);
math::max(0,
math::integer_divide_floor(InLeftPads{}[0] - ConvDilationH * (Ytilda - 1),
ConvStrides{}[0]));
constexpr index_t WtildaLeft =
math::integer_divide_floor(InLeftPads{}[1], ConvStrides{}[1]);
math::max(0,
math::integer_divide_floor(InLeftPads{}[1] - ConvDilationW * (Xtilda - 1),
ConvStrides{}[1]));
constexpr index_t HtildaRight =
math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1;
constexpr index_t WtildaRight =
math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1;
constexpr index_t HtildaRight = math::min(
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WtildaRight = math::min(
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
......
......@@ -162,13 +162,19 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t Wtilda =
Wo + math::integer_divide_ceil(ConvDilationW * (X - (X % Xtilda)), ConvStrideW);
constexpr index_t HtildaLeft = math::integer_divide_floor(InLeftPads{}[0], ConvStrides{}[0]);
constexpr index_t WtildaLeft = math::integer_divide_floor(InLeftPads{}[1], ConvStrides{}[1]);
constexpr index_t HtildaRight =
math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1;
constexpr index_t WtildaRight =
math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1;
constexpr index_t HtildaLeft =
math::max(0,
math::integer_divide_floor(InLeftPads{}[0] - ConvDilationH * (Ytilda - 1),
ConvStrides{}[0]));
constexpr index_t WtildaLeft =
math::max(0,
math::integer_divide_floor(InLeftPads{}[1] - ConvDilationW * (Xtilda - 1),
ConvStrides{}[1]));
constexpr index_t HtildaRight = math::min(
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WtildaRight = math::min(
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
......
......@@ -21,7 +21,7 @@ int main(int argc, char* argv[])
{
using namespace ck;
#if 1
#if 0
// 3x3 filter, 2x2 stride, 35x35 input
constexpr index_t N = 128;
constexpr index_t C = 128;
......@@ -31,8 +31,8 @@ int main(int argc, char* argv[])
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<3, 3>;
using ConvDilations = Sequence<2, 2>;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<3, 3>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
......@@ -171,7 +171,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 3>;
using RightPads = Sequence<0, 3>;
#elif 0
#elif 1
// 7x1 filter, 3x0 pad, 17x17 input
constexpr index_t N = 128;
constexpr index_t C = 128;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment