Commit bf779629 authored by rocking's avatar rocking
Browse files

calculate Do Ho Wo for the dilation

parent 81d0135c
...@@ -61,8 +61,10 @@ bool pool_test(bool do_verification, ...@@ -61,8 +61,10 @@ bool pool_test(bool do_verification,
1, // InSrcOutDstVectorSize 1, // InSrcOutDstVectorSize
false>; // IsFastestDimReduced false>; // IsFastestDimReduced
const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1; const ck::index_t Ys = (Y - 1) * window_dilation_h + 1;
const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1; const ck::index_t Xs = (X - 1) * window_dilation_w + 1;
const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Ys) / window_stride_h + 1;
const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - Xs) / window_stride_w + 1;
const std::vector<ck::index_t> window_spatial_lengths{Y, X}; const std::vector<ck::index_t> window_spatial_lengths{Y, X};
const std::vector<ck::index_t> window_strides{window_stride_h, window_stride_w}; const std::vector<ck::index_t> window_strides{window_stride_h, window_stride_w};
......
...@@ -88,9 +88,12 @@ bool pool3d_test(bool do_verification, ...@@ -88,9 +88,12 @@ bool pool3d_test(bool do_verification,
ck::index_t in_right_pad_h, ck::index_t in_right_pad_h,
ck::index_t in_right_pad_w) ck::index_t in_right_pad_w)
{ {
const ck::index_t Do = (Di + in_left_pad_d + in_right_pad_d - Z) / window_stride_d + 1; const ck::index_t Zs = (Z - 1) * window_dilation_d + 1;
const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1; const ck::index_t Ys = (Y - 1) * window_dilation_h + 1;
const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1; const ck::index_t Xs = (X - 1) * window_dilation_w + 1;
const ck::index_t Do = (Di + in_left_pad_d + in_right_pad_d - Zs) / window_stride_d + 1;
const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Ys) / window_stride_h + 1;
const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - Xs) / window_stride_w + 1;
const std::vector<ck::index_t> window_spatial_lengths{Z, Y, X}; const std::vector<ck::index_t> window_spatial_lengths{Z, Y, X};
const std::vector<ck::index_t> window_strides{ const std::vector<ck::index_t> window_strides{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment