Unverified Commit a2e90b5d authored by bpickrel's avatar bpickrel Committed by GitHub
Browse files

Mode as enum for pooling and roi_align (#1091)

Changed the pooling values for two structures from strings to specialized enum classes. Many test and operator parsing changes to support this. Introduces one new source file, op_enums.cpp.
parent d9d17a11
...@@ -12,7 +12,8 @@ struct test_max_pooling_ceil_3d : verify_program<test_max_pooling_ceil_3d> ...@@ -12,7 +12,8 @@ struct test_max_pooling_ceil_3d : verify_program<test_max_pooling_ceil_3d>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 5, 5, 5}}); mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 5, 5, 5}});
auto op = migraphx::op::pooling{"max", {1, 1, 1}, {3, 3, 3}, {3, 3, 3}, true}; auto op = migraphx::op::pooling{
migraphx::op::pooling_mode::max, {1, 1, 1}, {3, 3, 3}, {3, 3, 3}, true};
mm->add_instruction(op, input); mm->add_instruction(op, input);
return p; return p;
} }
......
...@@ -12,7 +12,7 @@ struct test_pooling_autopad : verify_program<test_pooling_autopad> ...@@ -12,7 +12,7 @@ struct test_pooling_autopad : verify_program<test_pooling_autopad>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::float_type, {1, 3, 63, 63}}; migraphx::shape s0{migraphx::shape::float_type, {1, 3, 63, 63}};
auto l0 = mm->add_parameter("x", s0); auto l0 = mm->add_parameter("x", s0);
migraphx::op::pooling op{"max"}; migraphx::op::pooling op{migraphx::op::pooling_mode::max};
op.lengths = {2, 2}; op.lengths = {2, 2};
op.stride = {2, 2}; op.stride = {2, 2};
mm->add_instruction(op, l0); mm->add_instruction(op, l0);
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
struct test_roialign_nondefault : verify_program<test_roialign_nondefault> struct test_roialign_nondefault : verify_program<test_roialign_nondefault>
{ {
...@@ -23,7 +24,7 @@ struct test_roialign_nondefault : verify_program<test_roialign_nondefault> ...@@ -23,7 +24,7 @@ struct test_roialign_nondefault : verify_program<test_roialign_nondefault>
auto r = mm->add_instruction( auto r = mm->add_instruction(
migraphx::make_op("roialign", migraphx::make_op("roialign",
{{"coordinate_transformation_mode", "output_half_pixel"}, {{"coordinate_transformation_mode", "output_half_pixel"},
{"mode", "max"}, {"mode", migraphx::op::pooling_mode::max},
{"spatial_scale", 1.0}, {"spatial_scale", 1.0},
{"output_height", 5}, {"output_height", 5},
{"output_width", 5}, {"output_width", 5},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment