Unverified Commit a2e90b5d authored by bpickrel's avatar bpickrel Committed by GitHub
Browse files

Mode as enum for pooling and roi_align (#1091)

Changed the pooling values for two structures from strings to specialized enum classes. Many test and operator parsing changes to support this. Introduces one new source file, op_enums.cpp.
parent d9d17a11
......@@ -819,9 +819,9 @@ struct ref_apply
void apply_pooling(instruction_ref ins) const
{
auto&& op = any_cast<op::pooling>(ins->get_operator());
if(op.mode == "max")
if(op.mode == op::pooling_mode::max)
mod->replace_instruction(ins, ref_pooling<max_pool>{op}, ins->inputs());
else if(op.mode == "average")
else if(op.mode == op::pooling_mode::average)
mod->replace_instruction(ins, ref_pooling<avg_pool>{op}, ins->inputs());
}
};
......
......@@ -19,7 +19,12 @@ struct parse_pooling : op_parser<parse_pooling>
tf_parser::node_info info,
std::vector<instruction_ref> args) const
{
op::pooling op{starts_with(opd.tf_name, "Max") ? "max" : "average"};
if(!starts_with(opd.tf_name, "Max") && !starts_with(opd.tf_name, "Av"))
{
MIGRAPHX_THROW("tf pooling mode must be Max or Average");
}
op::pooling op{starts_with(opd.tf_name, "Max") ? op::pooling_mode::max
: op::pooling_mode::average};
if(contains(info.attributes, "strides"))
{
......
......@@ -55,7 +55,8 @@ TEST_CASE(rewrite_pad)
auto l0 = create_im2col(padded_img, channels, m);
auto l1 = create_conv(padded_img, channels, m);
auto l2 = m.add_instruction(migraphx::make_op("pooling", {{"mode", "max"}}), padded_img);
auto l2 = m.add_instruction(
migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::max}}), padded_img);
m.add_instruction(migraphx::make_op("identity"), l0, l1, l2);
auto s0 = l0->get_shape();
......
......@@ -55,7 +55,9 @@ TEST_CASE(rewrite_pad)
auto l0 = create_im2col(l_img, channels, m);
auto l1 = create_conv(l_img, channels, m);
auto l2 = m.add_instruction(
migraphx::make_op("pooling", {{"mode", "max"}, {"padding", {0, 0, 1, 1}}}), l_img);
migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max}, {"padding", {0, 0, 1, 1}}}),
l_img);
m.add_instruction(migraphx::make_op("identity"), l0, l1, l2);
run_pass(m);
......@@ -76,7 +78,9 @@ TEST_CASE(rewrite_pad_symmetric)
migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}};
auto l_img = m.add_literal(migraphx::literal{s_img, input});
m.add_instruction(migraphx::make_op("pooling", {{"mode", "max"}, {"padding", {1, 1, 1, 1}}}),
m.add_instruction(
migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max}, {"padding", {1, 1, 1, 1}}}),
l_img);
run_pass(m);
......
......@@ -191,10 +191,11 @@ TEST_CASE(averagepool_1d_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 5}});
mm->add_instruction(
migraphx::make_op(
"pooling",
{{"mode", "average"}, {"padding", {0, 0}}, {"stride", {1}}, {"lengths", {3}}}),
mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {0, 0}},
{"stride", {1}},
{"lengths", {3}}}),
l0);
auto prog = optimize_onnx("averagepool_1d_test.onnx");
......@@ -207,7 +208,7 @@ TEST_CASE(averagepool_3d_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 5, 5, 5}});
mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {0, 0, 0, 0, 0, 0}},
{"stride", {1, 1, 1}},
{"lengths", {3, 3, 3}}}),
......@@ -223,7 +224,7 @@ TEST_CASE(averagepool_notset_test)
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
auto ins = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {2, 2, 2, 2}},
{"stride", {2, 2}},
{"lengths", {6, 6}}}),
......@@ -244,7 +245,7 @@ TEST_CASE(averagepool_nt_cip_test)
std::vector<int64_t> pads = {0, 0, 0, 0, 0, 0, 1, 1};
auto ins_pad = mm->add_instruction(migraphx::make_op("pad", {{"pads", pads}}), input);
auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {0, 0, 0, 0}},
{"stride", {2, 2}},
{"lengths", {6, 6}}}),
......@@ -261,7 +262,7 @@ TEST_CASE(averagepool_same_lower_test)
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
auto ins = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {1, 1, 1, 1}},
{"stride", {1, 1}},
{"lengths", {2, 2}}}),
......@@ -282,7 +283,7 @@ TEST_CASE(averagepool_sl_cip_test)
std::vector<int64_t> pads = {0, 0, 1, 1, 0, 0, 0, 0};
auto ins_pad = mm->add_instruction(migraphx::make_op("pad", {{"pads", pads}}), input);
auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"lengths", {2, 2}}}),
......@@ -299,7 +300,7 @@ TEST_CASE(averagepool_same_upper_test)
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
auto ins = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {1, 1, 1, 1}},
{"stride", {1, 1}},
{"lengths", {2, 2}}}),
......@@ -669,10 +670,11 @@ TEST_CASE(conv_bn_relu_maxpool_test)
auto l6 = mm->add_instruction(
migraphx::make_op("batch_norm_inference", {{"epsilon", 1.0e-5f}}), l5, p3, p4, p5, p6);
auto l7 = mm->add_instruction(migraphx::make_op("relu"), l6);
mm->add_instruction(
migraphx::make_op(
"pooling",
{{"mode", "max"}, {"padding", {0, 0, 0, 0}}, {"stride", {2, 2}}, {"lengths", {2, 2}}}),
mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {0, 0, 0, 0}},
{"stride", {2, 2}},
{"lengths", {2, 2}}}),
l7);
auto prog = optimize_onnx("conv_bn_relu_maxpool_test.onnx");
......@@ -693,10 +695,11 @@ TEST_CASE(conv_relu_maxpool_test)
migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l3->get_shape().lens()}}), l2);
auto l5 = mm->add_instruction(migraphx::make_op("add"), l3, l4);
auto l6 = mm->add_instruction(migraphx::make_op("relu"), l5);
mm->add_instruction(
migraphx::make_op(
"pooling",
{{"mode", "max"}, {"padding", {0, 0, 0, 0}}, {"stride", {2, 2}}, {"lengths", {2, 2}}}),
mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {0, 0, 0, 0}},
{"stride", {2, 2}},
{"lengths", {2, 2}}}),
l6);
auto prog = optimize_onnx("conv_relu_maxpool_test.onnx");
......@@ -717,10 +720,11 @@ TEST_CASE(conv_relu_maxpool_x2_test)
migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l3->get_shape().lens()}}), l2);
auto l5 = mm->add_instruction(migraphx::make_op("add"), l3, l4);
auto l6 = mm->add_instruction(migraphx::make_op("relu"), l5);
auto l7 = mm->add_instruction(
migraphx::make_op(
"pooling",
{{"mode", "max"}, {"padding", {0, 0, 0, 0}}, {"stride", {2, 2}}, {"lengths", {2, 2}}}),
auto l7 = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {0, 0, 0, 0}},
{"stride", {2, 2}},
{"lengths", {2, 2}}}),
l6);
auto l8 = mm->add_parameter("3", {migraphx::shape::float_type, {1, 5, 5, 5}});
......@@ -732,10 +736,11 @@ TEST_CASE(conv_relu_maxpool_x2_test)
l9);
auto l12 = mm->add_instruction(migraphx::make_op("add"), l10, l11);
auto l13 = mm->add_instruction(migraphx::make_op("relu"), l12);
mm->add_instruction(
migraphx::make_op(
"pooling",
{{"mode", "max"}, {"padding", {0, 0, 0, 0}}, {"stride", {2, 2}}, {"lengths", {2, 2}}}),
mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {0, 0, 0, 0}},
{"stride", {2, 2}},
{"lengths", {2, 2}}}),
l13);
auto prog = optimize_onnx("conv_relu_maxpool_x2_test.onnx");
......@@ -1481,7 +1486,7 @@ TEST_CASE(globalavgpool_test)
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto op = migraphx::op::pooling{"average"};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
auto lens = input->get_shape().lens();
op.lengths = {lens[2], lens[3]};
op.padding = {0, 0, 0, 0};
......@@ -1498,7 +1503,7 @@ TEST_CASE(globalmaxpool_test)
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto op = migraphx::op::pooling{"max"};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
auto lens = input->get_shape().lens();
op.lengths = {lens[2], lens[3]};
op.padding = {0, 0, 0, 0};
......@@ -2491,10 +2496,11 @@ TEST_CASE(maxpool_notset_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
mm->add_instruction(
migraphx::make_op(
"pooling",
{{"mode", "max"}, {"padding", {0, 0, 1, 1}}, {"stride", {2, 2}}, {"lengths", {6, 6}}}),
mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {0, 0, 1, 1}},
{"stride", {2, 2}},
{"lengths", {6, 6}}}),
input);
auto prog = optimize_onnx("maxpool_notset_test.onnx");
......@@ -2507,10 +2513,11 @@ TEST_CASE(maxpool_same_upper_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
mm->add_instruction(
migraphx::make_op(
"pooling",
{{"mode", "max"}, {"padding", {0, 0, 1, 1}}, {"stride", {1, 1}}, {"lengths", {2, 2}}}),
mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {0, 0, 1, 1}},
{"stride", {1, 1}},
{"lengths", {2, 2}}}),
input);
auto prog = optimize_onnx("maxpool_same_upper_test.onnx");
......
......@@ -570,9 +570,11 @@ TEST_CASE(inconsistent_attr_shape)
{{"padding", {1, 1}}, {"stride", {2}}, {"dilation", {3, 3, 3}}}),
input,
weights);
throws_shape(
migraphx::make_op(
"pooling", {{"mode", "max"}, {"padding", {1}}, {"stride", {0}}, {"lengths", {1, 1}}}),
throws_shape(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {1}},
{"stride", {0}},
{"lengths", {1, 1}}}),
input);
}
......@@ -983,21 +985,24 @@ TEST_CASE(pooling_shape)
{
migraphx::shape output{migraphx::shape::float_type, {4, 3, 1, 1}};
migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
throws_shape(
migraphx::make_op("pooling",
{{"mode", "max"}, {"padding", {1}}, {"stride", {0}}, {"lengths", {1}}}),
throws_shape(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {1}},
{"stride", {0}},
{"lengths", {1}}}),
input);
expect_shape(
output,
migraphx::make_op(
"pooling",
{{"mode", "max"}, {"padding", {0, 0}}, {"stride", {3, 3}}, {"lengths", {1, 1}}}),
expect_shape(output,
migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {0, 0}},
{"stride", {3, 3}},
{"lengths", {1, 1}}}),
input);
migraphx::shape output1{migraphx::shape::float_type, {4, 3, 2, 2}};
expect_shape(output1,
migraphx::make_op("pooling",
{{"mode", "max"},
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {0, 0}},
{"stride", {3, 3}},
{"lengths", {1, 1}},
......
......@@ -370,7 +370,7 @@ TEST_CASE(avgpool_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}};
auto op = migraphx::op::pooling{"average"};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
op.lengths = {2};
op.padding = {0};
op.stride = {1};
......@@ -392,7 +392,7 @@ TEST_CASE(avgpool_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 4}};
auto op = migraphx::op::pooling{"average"};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
op.lengths = {2};
op.padding = {1};
op.stride = {2};
......@@ -439,7 +439,7 @@ TEST_CASE(avgpool_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 3, 3}};
auto op = migraphx::op::pooling{"average"};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
op.lengths = {2, 2, 2};
op.padding = {0, 0, 0};
op.stride = {1, 1, 1};
......@@ -1658,7 +1658,7 @@ TEST_CASE(globalavgpool_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}};
auto op = migraphx::op::pooling{"average"};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
auto lens = s.lens();
op.lengths = {lens[2], lens[3]};
......@@ -1679,7 +1679,7 @@ TEST_CASE(globalmaxpool_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}};
auto op = migraphx::op::pooling{"max"};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
auto lens = s.lens();
op.lengths = {lens[2], lens[3]};
......@@ -2582,10 +2582,11 @@ TEST_CASE(maxpool_test)
0.52119428, 2.07681108, 0.88494766, 1.51522756, 0.54275119, 0.6629802};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 6, 6}};
auto al = mm->add_literal(migraphx::literal{a_shape, a});
mm->add_instruction(
migraphx::make_op(
"pooling",
{{"mode", "max"}, {"padding", {0, 0}}, {"stride", {2, 2}}, {"lengths", {3, 2}}}),
mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {0, 0}},
{"stride", {2, 2}},
{"lengths", {3, 2}}}),
al);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
......@@ -2601,7 +2602,7 @@ TEST_CASE(maxpool_test_1D_3D)
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}};
auto op = migraphx::op::pooling{"max"};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
op.lengths = {2};
op.padding = {0};
op.stride = {1};
......@@ -2623,7 +2624,7 @@ TEST_CASE(maxpool_test_1D_3D)
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 5}};
auto op = migraphx::op::pooling{"max"};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
op.lengths = {2};
op.padding = {0};
op.stride = {2};
......@@ -2647,7 +2648,7 @@ TEST_CASE(maxpool_test_1D_3D)
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 5}};
auto op = migraphx::op::pooling{"max"};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
op.lengths = {2};
op.padding = {0};
op.stride = {2};
......@@ -2683,7 +2684,7 @@ TEST_CASE(maxpool_test_1D_3D)
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 3, 3}};
auto op = migraphx::op::pooling{"max"};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
op.lengths = {2, 2, 2};
op.padding = {0, 0, 0};
op.stride = {2, 2, 2};
......@@ -4038,7 +4039,8 @@ TEST_CASE(roialign_out_of_bound_test)
TEST_CASE(roialign_test)
{
auto create_program = [](const std::string& trans_mode = "half_pixel",
const std::string& pooling_mode = "avg",
const migraphx::op::pooling_mode pooling_mode =
migraphx::op::pooling_mode::average,
int64_t sampling_ratio = 2) {
migraphx::program p;
auto* mm = p.get_main_module();
......@@ -4125,7 +4127,7 @@ TEST_CASE(roialign_test)
}
{
auto p = create_program("output_half_pixel", "max", 0);
auto p = create_program("output_half_pixel", migraphx::op::pooling_mode::max, 0);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
......
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/ref/target.hpp>
......@@ -22,7 +23,7 @@ static void opt_pooling(migraphx::module& m)
TEST_CASE(rewrite_pooling_test)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}};
auto pooling_program = [&](const std::string& mode) {
auto pooling_program = [&](const migraphx::op::pooling_mode mode) {
migraphx::module m;
auto input = m.add_parameter("x", s);
auto ret = m.add_instruction(migraphx::make_op("pooling",
......@@ -46,15 +47,16 @@ TEST_CASE(rewrite_pooling_test)
return m;
};
auto test_rewrite = [&](const std::string& mode, const migraphx::operation& op) {
auto test_rewrite = [&](const migraphx::op::pooling_mode mode, const migraphx::operation& op) {
migraphx::module m1 = pooling_program(mode);
migraphx::module m2 = opt_program(op);
opt_pooling(m1);
EXPECT(m1 == m2);
};
test_rewrite("average", migraphx::make_op("reduce_mean", {{"axes", {1}}}));
test_rewrite("max", migraphx::make_op("reduce_max", {{"axes", {1}}}));
test_rewrite(migraphx::op::pooling_mode::average,
migraphx::make_op("reduce_mean", {{"axes", {1}}}));
test_rewrite(migraphx::op::pooling_mode::max, migraphx::make_op("reduce_max", {{"axes", {1}}}));
}
TEST_CASE(rewrite_avepooling_na1_test)
......@@ -64,8 +66,9 @@ TEST_CASE(rewrite_avepooling_na1_test)
migraphx::module m;
auto input = m.add_parameter("x", s);
auto ret = m.add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
auto ret =
m.add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {0, 1, 0}},
{"stride", {1, 1, 1}},
{"lengths", {3, 4, 5}}}),
......@@ -88,8 +91,9 @@ TEST_CASE(rewrite_avepooling_na2_test)
migraphx::module m;
auto input = m.add_parameter("x", s);
auto ret = m.add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
auto ret =
m.add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {0, 0, 0}},
{"stride", {1, 2, 1}},
{"lengths", {3, 4, 5}}}),
......@@ -113,7 +117,7 @@ TEST_CASE(rewrite_avepooling_na3_test)
auto input = m.add_parameter("x", s);
auto ret = m.add_instruction(migraphx::make_op("pooling",
{{"mode", "max"},
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {0, 0, 0}},
{"stride", {1, 1, 1}},
{"lengths", {3, 3, 5}}}),
......@@ -135,7 +139,7 @@ TEST_CASE(literal_rewrite_pooling_test)
std::vector<float> data(s.elements());
std::iota(data.begin(), data.end(), 1.0f);
auto pooling_program = [&](const std::string& mode) {
auto pooling_program = [&](const migraphx::op::pooling_mode mode) {
migraphx::program p;
auto* mm = p.get_main_module();
......@@ -163,7 +167,8 @@ TEST_CASE(literal_rewrite_pooling_test)
return p;
};
auto test_rewrite_pooling = [&](const std::string& mode, const migraphx::operation& op) {
auto test_rewrite_pooling = [&](const migraphx::op::pooling_mode mode,
const migraphx::operation& op) {
migraphx::program p1 = pooling_program(mode);
migraphx::program p2 = opt_program(op);
p1.compile(migraphx::ref::target{});
......@@ -174,8 +179,10 @@ TEST_CASE(literal_rewrite_pooling_test)
result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); });
};
test_rewrite_pooling("max", migraphx::make_op("reduce_max", {{"axes", {1}}}));
test_rewrite_pooling("average", migraphx::make_op("reduce_mean", {{"axes", {1}}}));
test_rewrite_pooling(migraphx::op::pooling_mode::max,
migraphx::make_op("reduce_max", {{"axes", {1}}}));
test_rewrite_pooling(migraphx::op::pooling_mode::average,
migraphx::make_op("reduce_mean", {{"axes", {1}}}));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -4,6 +4,7 @@
#include <migraphx/instruction.hpp>
#include <test.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/matcher.hpp>
......@@ -463,8 +464,9 @@ TEST_CASE(conv_pooling_dot)
auto bc1 = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m1.add_instruction(migraphx::make_op("add"), c1, bc1);
auto ap = m1.add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
auto ap =
m1.add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"lengths", {7, 7}},
......@@ -509,8 +511,9 @@ TEST_CASE(conv_pooling_dot)
auto bc1 = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m2.add_instruction(migraphx::make_op("add"), d5, bc1);
auto ap = m2.add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
auto ap =
m2.add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"lengths", {7, 7}},
......@@ -567,8 +570,9 @@ TEST_CASE(mobilenet_snippet)
auto a1 = mm.add_instruction(migraphx::make_op("add"), c1, bc1);
auto q2 = add_quantize_op(mm, "quantizelinear", a1, scale, zero);
auto d6 = add_quantize_op(mm, "dequantizelinear", q2, scale, zero);
auto ap = mm.add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
auto ap =
mm.add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"lengths", {7, 7}},
......
......@@ -649,8 +649,8 @@ TEST_CASE(pooling_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
migraphx::op::pooling avg_pool_op{"average"};
migraphx::op::pooling max_pool_op{"max"};
migraphx::op::pooling avg_pool_op{migraphx::op::pooling_mode::average};
migraphx::op::pooling max_pool_op{migraphx::op::pooling_mode::max};
avg_pool_op.stride = {2, 2};
max_pool_op.stride = {2, 2};
avg_pool_op.lengths = {2, 2};
......
......@@ -12,7 +12,7 @@ struct test_avg_pooling_1d : verify_program<test_avg_pooling_1d>
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 5}});
auto op = migraphx::op::pooling{"average", {0}, {1}, {3}};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average, {0}, {1}, {3}};
mm->add_instruction(op, input);
return p;
}
......
......@@ -12,7 +12,8 @@ struct test_avg_pooling_3d : verify_program<test_avg_pooling_3d>
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 5, 5, 5}});
auto op = migraphx::op::pooling{"average", {1, 1, 1}, {3, 3, 3}, {3, 3, 3}};
auto op = migraphx::op::pooling{
migraphx::op::pooling_mode::average, {1, 1, 1}, {3, 3, 3}, {3, 3, 3}};
mm->add_instruction(op, input);
return p;
}
......
......@@ -12,7 +12,8 @@ struct test_avg_pooling_3d_opt : verify_program<test_avg_pooling_3d_opt>
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 2, 3, 3, 3}});
auto op = migraphx::op::pooling{"average", {0, 0, 0}, {1, 1, 1}, {3, 3, 3}};
auto op = migraphx::op::pooling{
migraphx::op::pooling_mode::average, {0, 0, 0}, {1, 1, 1}, {3, 3, 3}};
mm->add_instruction(op, input);
return p;
}
......
......@@ -10,9 +10,11 @@ struct test_avg_pooling_ceil_3d : verify_program<test_avg_pooling_ceil_3d>
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 5, 5, 5}});
auto op = migraphx::op::pooling{"average", {1, 1, 1}, {3, 3, 3}, {3, 3, 3}, true};
auto op = migraphx::op::pooling{
migraphx::op::pooling_mode::average, {1, 1, 1}, {3, 3, 3}, {3, 3, 3}, true};
mm->add_instruction(op, input);
return p;
}
......
......@@ -3,6 +3,7 @@
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/pooling.hpp>
struct test_concat_pooling : verify_program<test_concat_pooling>
{
......@@ -18,8 +19,9 @@ struct test_concat_pooling : verify_program<test_concat_pooling>
auto concat_t = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), concat);
auto pooling = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
auto pooling =
mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {0, 0}},
{"stride", {1, 1}},
{"lengths", {8, 8}}}),
......
......@@ -3,6 +3,7 @@
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling>
{
......@@ -29,7 +30,7 @@ struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling>
migraphx::make_op("batch_norm_inference"), conv, scale, bias, mean, variance);
auto relu = mm->add_instruction(migraphx::make_op("relu"), bn);
mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {1, 1}},
{"stride", {2, 2}},
{"lengths", {3, 3}}}),
......
......@@ -3,6 +3,7 @@
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2>
{
......@@ -47,7 +48,7 @@ struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2>
auto add = mm->add_instruction(migraphx::make_op("add"), bn1, bn2);
auto relu = mm->add_instruction(migraphx::make_op("relu"), add);
mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {1, 1}},
{"stride", {2, 2}},
{"lengths", {3, 3}}}),
......
......@@ -3,6 +3,7 @@
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
struct test_conv_pooling : verify_program<test_conv_pooling>
{
......@@ -15,7 +16,8 @@ struct test_conv_pooling : verify_program<test_conv_pooling>
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
auto pooling = mm->add_instruction(migraphx::make_op("pooling", {{"mode", "max"}}), conv);
auto pooling = mm->add_instruction(
migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::max}}), conv);
mm->add_instruction(migraphx::make_op("relu"), pooling);
return p;
}
......
......@@ -13,7 +13,7 @@ struct test_global_avg_pooling : verify_program<test_global_avg_pooling>
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto op = migraphx::op::pooling{"average"};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
auto lens = input->get_shape().lens();
op.lengths = {lens[2], lens[3]};
mm->add_instruction(op, input);
......
......@@ -13,7 +13,7 @@ struct test_global_max_pooling : verify_program<test_global_max_pooling>
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto op = migraphx::op::pooling{"max"};
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
auto lens = input->get_shape().lens();
op.lengths = {lens[2], lens[3]};
mm->add_instruction(op, input);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment