Commit 7f679044 authored by Chao Liu's avatar Chao Liu
Browse files

debugged bwd data v2r1

parent 7e808fe1
......@@ -115,14 +115,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
constexpr index_t HtildaLeft =
math::max(0,
math::integer_divide_floor(InLeftPads{}[0] - ConvDilationH * (Ytilda - 1),
ConvStrides{}[0]));
constexpr index_t WtildaLeft =
math::max(0,
math::integer_divide_floor(InLeftPads{}[1] - ConvDilationW * (Xtilda - 1),
ConvStrides{}[1]));
constexpr index_t HtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
constexpr index_t WtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
constexpr index_t HtildaRight = math::min(
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
......
......@@ -162,14 +162,10 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t Wtilda =
Wo + math::integer_divide_ceil(ConvDilationW * (X - (X % Xtilda)), ConvStrideW);
constexpr index_t HtildaLeft =
math::max(0,
math::integer_divide_floor(InLeftPads{}[0] - ConvDilationH * (Ytilda - 1),
ConvStrides{}[0]));
constexpr index_t WtildaLeft =
math::max(0,
math::integer_divide_floor(InLeftPads{}[1] - ConvDilationW * (Xtilda - 1),
ConvStrides{}[1]));
constexpr index_t HtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
constexpr index_t WtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
constexpr index_t HtildaRight = math::min(
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
......
......@@ -21,7 +21,7 @@ int main(int argc, char* argv[])
{
using namespace ck;
#if 0
#if 1
// 3x3 filter, 2x2 stride, 35x35 input
constexpr index_t N = 128;
constexpr index_t C = 128;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment