Unverified Commit 4637621a authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Rectify flipped coordinate_transformation_mode logic in ROIAlign #2159 (#2214)



* Rectify flipped coordinate_transformation_mode logic in ROIAlign
* Handle both opset 10 and 16 versions
* Fix version check and clang tidy warning
Co-authored-by: default avatarDino Musić <dino.music@htecgroup.com>
parent a4957ab2
...@@ -124,7 +124,7 @@ struct roialign ...@@ -124,7 +124,7 @@ struct roialign
{ {
xy[ii] = roi_start[ii] + p[ii] * bin_size[ii] + xy[ii] = roi_start[ii] + p[ii] * bin_size[ii] +
(i[ii] + .5f) * bin_size[ii] / bin_grid_size[ii]; (i[ii] + .5f) * bin_size[ii] / bin_grid_size[ii];
xy[ii] = (coord_trans_mode == "output_half_pixel") ? (xy[ii] - 0.5f) : xy[ii]; xy[ii] = (coord_trans_mode == "half_pixel") ? (xy[ii] - 0.5f) : xy[ii];
if(xy[ii] < -1.0 or xy[ii] > dims[ii]) if(xy[ii] < -1.0 or xy[ii] > dims[ii])
{ {
results[index] = pos_weight{}; results[index] = pos_weight{};
......
...@@ -37,15 +37,18 @@ struct parse_roialign : op_parser<parse_roialign> ...@@ -37,15 +37,18 @@ struct parse_roialign : op_parser<parse_roialign>
std::vector<op_desc> operators() const { return {{"RoiAlign"}}; } std::vector<op_desc> operators() const { return {{"RoiAlign"}}; }
instruction_ref parse(const op_desc& /*opd*/, instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/, const onnx_parser& parser,
onnx_parser::node_info info, onnx_parser::node_info info,
const std::vector<instruction_ref>& args) const const std::vector<instruction_ref>& args) const
{ {
std::string coord_trans_mode = "half_pixel"; std::string coord_trans_mode =
if(contains(info.attributes, "coordinate_transformation_mode")) parser.opset_version >= 16 ? "half_pixel" : "output_half_pixel";
if(const auto* a = "coordinate_transformation_mode"; contains(info.attributes, a))
{ {
coord_trans_mode = info.attributes.at("coordinate_transformation_mode").s(); coord_trans_mode = info.attributes.at(a).s();
} }
if(not contains({"half_pixel", "output_half_pixel"}, coord_trans_mode)) if(not contains({"half_pixel", "output_half_pixel"}, coord_trans_mode))
{ {
MIGRAPHX_THROW("coordinate_transformation_mode \"" + coord_trans_mode + MIGRAPHX_THROW("coordinate_transformation_mode \"" + coord_trans_mode +
......
...@@ -81,7 +81,7 @@ struct roialign_compiler : compiler<roialign_compiler> ...@@ -81,7 +81,7 @@ struct roialign_compiler : compiler<roialign_compiler>
// coord_trans_mode // coord_trans_mode
auto ctm = v.at("coordinate_transformation_mode").to<std::string>(); auto ctm = v.at("coordinate_transformation_mode").to<std::string>();
float rois_offset = (ctm == "output_half_pixel") ? -0.5f : 0.0f; float rois_offset = (ctm == "half_pixel") ? -0.5f : 0.0f;
options.params += " -DROIS_OFFSET=" + std::to_string(rois_offset); options.params += " -DROIS_OFFSET=" + std::to_string(rois_offset);
// spatial_scale // spatial_scale
......
...@@ -5951,7 +5951,13 @@ TEST_CASE(roialign_default_test) ...@@ -5951,7 +5951,13 @@ TEST_CASE(roialign_default_test)
auto rois = mm->add_parameter("rois", srois); auto rois = mm->add_parameter("rois", srois);
auto bi = mm->add_parameter("batch_ind", sbi); auto bi = mm->add_parameter("batch_ind", sbi);
auto r = mm->add_instruction(migraphx::make_op("roialign"), x, rois, bi); // Due to the onnx model using opset 12, the coordinate_transformation_mode should be set to
// output_half_pixel
auto r = mm->add_instruction(
migraphx::make_op("roialign", {{"coordinate_transformation_mode", "output_half_pixel"}}),
x,
rois,
bi);
mm->add_return({r}); mm->add_return({r});
auto prog = migraphx::parse_onnx("roialign_default_test.onnx"); auto prog = migraphx::parse_onnx("roialign_default_test.onnx");
......
...@@ -73,7 +73,7 @@ TEST_CASE(roialign_out_of_bound_test) ...@@ -73,7 +73,7 @@ TEST_CASE(roialign_out_of_bound_test)
}; };
{ {
auto p = create_program("output_half_pixel"); auto p = create_program("half_pixel");
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector; std::vector<float> results_vector;
...@@ -130,7 +130,7 @@ TEST_CASE(roialign_test) ...@@ -130,7 +130,7 @@ TEST_CASE(roialign_test)
}; };
{ {
auto p = create_program(); auto p = create_program("output_half_pixel");
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector; std::vector<float> results_vector;
...@@ -154,7 +154,7 @@ TEST_CASE(roialign_test) ...@@ -154,7 +154,7 @@ TEST_CASE(roialign_test)
} }
{ {
auto p = create_program("output_half_pixel"); auto p = create_program("half_pixel");
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector; std::vector<float> results_vector;
...@@ -175,7 +175,7 @@ TEST_CASE(roialign_test) ...@@ -175,7 +175,7 @@ TEST_CASE(roialign_test)
} }
{ {
auto p = create_program("output_half_pixel", migraphx::op::pooling_mode::max, 0); auto p = create_program("half_pixel", migraphx::op::pooling_mode::max, 0);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector; std::vector<float> results_vector;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment