"...targets/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "7ecb2de4c827538ae038c3f10773299e7a640d3e"
Commit 1bacc1df authored by kahmed10's avatar kahmed10 Committed by mvermeulen
Browse files

Pad calculation fix for nasnet (#393)

* fix pad calc

* simplify ceil calc and remove extra vars

* change dilation calculation, add tests

* formatting

* formatting
parent 23124b09
......@@ -15,11 +15,12 @@ inline void calculate_padding(int64_t idx,
int64_t dilation,
int64_t weight_dim)
{
int64_t output_dim = input_dim / stride;
int64_t pad = std::max(static_cast<int64_t>(0),
(output_dim - 1) * stride + dilation * weight_dim - input_dim);
pads[idx] = pad / 2;
pads[idx + 2] = pad - pad / 2;
int64_t output_dim = (input_dim + stride - 1) / stride; // round up result
int64_t new_weight_dim = weight_dim + (weight_dim - 1) * (dilation - 1);
int64_t pad =
std::max(static_cast<int64_t>(0), (output_dim - 1) * stride + new_weight_dim - input_dim);
pads[idx] = pad / 2;
pads[idx + 2] = pad - pad / 2;
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -398,11 +398,9 @@ struct tf_parser
size_t weight_w = weight_dims[3];
auto input_dims = l0->get_shape().lens();
size_t input_h = input_dims[2];
size_t input_w = input_dims[3];
std::vector<int64_t> pads(input_dims.size());
calculate_padding(0, pads, input_h, op.stride[0], op.dilation[0], weight_h);
calculate_padding(1, pads, input_w, op.stride[1], op.dilation[1], weight_w);
calculate_padding(0, pads, input_dims[2], op.stride[0], op.dilation[0], weight_h);
calculate_padding(1, pads, input_dims[3], op.stride[1], op.dilation[1], weight_w);
if(pads[0] != pads[2] || pads[1] != pads[3])
{
......@@ -486,11 +484,9 @@ struct tf_parser
size_t weight_w = weight_dims[3];
auto input_dims = l0->get_shape().lens();
size_t input_h = input_dims[2];
size_t input_w = input_dims[3];
std::vector<int64_t> pads(input_dims.size());
calculate_padding(0, pads, input_h, op.stride[0], op.dilation[0], weight_h);
calculate_padding(1, pads, input_w, op.stride[1], op.dilation[1], weight_w);
calculate_padding(0, pads, input_dims[2], op.stride[0], op.dilation[0], weight_h);
calculate_padding(1, pads, input_dims[3], op.stride[1], op.dilation[1], weight_w);
if(pads[0] != pads[2] || pads[1] != pads[3])
{
......@@ -722,11 +718,9 @@ struct tf_parser
{
op.padding_mode = op::padding_mode_t::same;
auto input_dims = l0->get_shape().lens();
size_t input_h = input_dims[2];
size_t input_w = input_dims[3];
std::vector<int64_t> pads(input_dims.size());
calculate_padding(0, pads, input_h, op.stride[0], 1, op.lengths[0]);
calculate_padding(1, pads, input_w, op.stride[1], 1, op.lengths[1]);
calculate_padding(0, pads, input_dims[2], op.stride[0], 1, op.lengths[0]);
calculate_padding(1, pads, input_dims[3], op.stride[1], 1, op.lengths[1]);
if(pads[0] != pads[2] || pads[1] != pads[3])
{
......
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/pad_calc.hpp>
#include "test.hpp"
TEST_CASE(pad_calc_test_no_pad)
{
std::vector<int64_t> golden_pads{0, 0, 0, 0};
std::vector<int64_t> pads{0, 0, 0, 0};
// 1x1 filter size
migraphx::calculate_padding(0, pads, 16, 1, 1, 1);
migraphx::calculate_padding(1, pads, 16, 1, 1, 1);
EXPECT(pads == golden_pads);
}
TEST_CASE(pad_calc_test_pad_by_1)
{
std::vector<int64_t> golden_pads{1, 1, 1, 1};
std::vector<int64_t> pads{0, 0, 0, 0};
// 3x3 filter size
migraphx::calculate_padding(0, pads, 16, 1, 1, 3);
migraphx::calculate_padding(1, pads, 16, 1, 1, 3);
EXPECT(pads == golden_pads);
}
TEST_CASE(pad_calc_test_pad_by_1_asym_2x2_filter)
{
std::vector<int64_t> golden_pads{0, 0, 1, 1};
std::vector<int64_t> pads{0, 0, 0, 0};
// 2x2 filter size
migraphx::calculate_padding(0, pads, 16, 1, 1, 2);
migraphx::calculate_padding(1, pads, 16, 1, 1, 2);
EXPECT(pads == golden_pads);
}
TEST_CASE(pad_calc_test_pad_by_2)
{
std::vector<int64_t> golden_pads{2, 2, 2, 2};
std::vector<int64_t> pads{0, 0, 0, 0};
// 5x5 filter size
migraphx::calculate_padding(0, pads, 16, 1, 1, 5);
migraphx::calculate_padding(1, pads, 16, 1, 1, 5);
EXPECT(pads == golden_pads);
}
TEST_CASE(pad_calc_test_pad_by_1_asym_stride_2)
{
std::vector<int64_t> golden_pads{0, 0, 1, 1};
std::vector<int64_t> pads{0, 0, 0, 0};
// 3x3 filter size
migraphx::calculate_padding(0, pads, 16, 2, 1, 3);
migraphx::calculate_padding(1, pads, 16, 2, 1, 3);
EXPECT(pads == golden_pads);
}
TEST_CASE(pad_calc_test_dilation_2)
{
std::vector<int64_t> golden_pads{2, 2, 2, 2};
std::vector<int64_t> pads{0, 0, 0, 0};
// 3x3 filter size
migraphx::calculate_padding(0, pads, 16, 1, 2, 3);
migraphx::calculate_padding(1, pads, 16, 1, 2, 3);
EXPECT(pads == golden_pads);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment