"src/include/blockwise_batched_gemm.hpp" did not exist on "c075d3f7d91079d28340cda89d51e15117493968"
Unverified Commit a2e90b5d authored by bpickrel's avatar bpickrel Committed by GitHub
Browse files

Mode as enum for pooling and roi_align (#1091)

Changed the pooling values for two structures from strings to specialized enum classes. Many test and operator parsing changes to support this. Introduces one new source file, op_enums.cpp.
parent d9d17a11
...@@ -819,9 +819,9 @@ struct ref_apply ...@@ -819,9 +819,9 @@ struct ref_apply
void apply_pooling(instruction_ref ins) const void apply_pooling(instruction_ref ins) const
{ {
auto&& op = any_cast<op::pooling>(ins->get_operator()); auto&& op = any_cast<op::pooling>(ins->get_operator());
if(op.mode == "max") if(op.mode == op::pooling_mode::max)
mod->replace_instruction(ins, ref_pooling<max_pool>{op}, ins->inputs()); mod->replace_instruction(ins, ref_pooling<max_pool>{op}, ins->inputs());
else if(op.mode == "average") else if(op.mode == op::pooling_mode::average)
mod->replace_instruction(ins, ref_pooling<avg_pool>{op}, ins->inputs()); mod->replace_instruction(ins, ref_pooling<avg_pool>{op}, ins->inputs());
} }
}; };
......
...@@ -19,7 +19,12 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -19,7 +19,12 @@ struct parse_pooling : op_parser<parse_pooling>
tf_parser::node_info info, tf_parser::node_info info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
op::pooling op{starts_with(opd.tf_name, "Max") ? "max" : "average"}; if(!starts_with(opd.tf_name, "Max") && !starts_with(opd.tf_name, "Av"))
{
MIGRAPHX_THROW("tf pooling mode must be Max or Average");
}
op::pooling op{starts_with(opd.tf_name, "Max") ? op::pooling_mode::max
: op::pooling_mode::average};
if(contains(info.attributes, "strides")) if(contains(info.attributes, "strides"))
{ {
......
...@@ -55,7 +55,8 @@ TEST_CASE(rewrite_pad) ...@@ -55,7 +55,8 @@ TEST_CASE(rewrite_pad)
auto l0 = create_im2col(padded_img, channels, m); auto l0 = create_im2col(padded_img, channels, m);
auto l1 = create_conv(padded_img, channels, m); auto l1 = create_conv(padded_img, channels, m);
auto l2 = m.add_instruction(migraphx::make_op("pooling", {{"mode", "max"}}), padded_img); auto l2 = m.add_instruction(
migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::max}}), padded_img);
m.add_instruction(migraphx::make_op("identity"), l0, l1, l2); m.add_instruction(migraphx::make_op("identity"), l0, l1, l2);
auto s0 = l0->get_shape(); auto s0 = l0->get_shape();
......
...@@ -55,7 +55,9 @@ TEST_CASE(rewrite_pad) ...@@ -55,7 +55,9 @@ TEST_CASE(rewrite_pad)
auto l0 = create_im2col(l_img, channels, m); auto l0 = create_im2col(l_img, channels, m);
auto l1 = create_conv(l_img, channels, m); auto l1 = create_conv(l_img, channels, m);
auto l2 = m.add_instruction( auto l2 = m.add_instruction(
migraphx::make_op("pooling", {{"mode", "max"}, {"padding", {0, 0, 1, 1}}}), l_img); migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max}, {"padding", {0, 0, 1, 1}}}),
l_img);
m.add_instruction(migraphx::make_op("identity"), l0, l1, l2); m.add_instruction(migraphx::make_op("identity"), l0, l1, l2);
run_pass(m); run_pass(m);
...@@ -76,8 +78,10 @@ TEST_CASE(rewrite_pad_symmetric) ...@@ -76,8 +78,10 @@ TEST_CASE(rewrite_pad_symmetric)
migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}}; migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}};
auto l_img = m.add_literal(migraphx::literal{s_img, input}); auto l_img = m.add_literal(migraphx::literal{s_img, input});
m.add_instruction(migraphx::make_op("pooling", {{"mode", "max"}, {"padding", {1, 1, 1, 1}}}), m.add_instruction(
l_img); migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max}, {"padding", {1, 1, 1, 1}}}),
l_img);
run_pass(m); run_pass(m);
EXPECT(std::none_of( EXPECT(std::none_of(
......
...@@ -191,11 +191,12 @@ TEST_CASE(averagepool_1d_test) ...@@ -191,11 +191,12 @@ TEST_CASE(averagepool_1d_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 5}}); auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 5}});
mm->add_instruction( mm->add_instruction(migraphx::make_op("pooling",
migraphx::make_op( {{"mode", migraphx::op::pooling_mode::average},
"pooling", {"padding", {0, 0}},
{{"mode", "average"}, {"padding", {0, 0}}, {"stride", {1}}, {"lengths", {3}}}), {"stride", {1}},
l0); {"lengths", {3}}}),
l0);
auto prog = optimize_onnx("averagepool_1d_test.onnx"); auto prog = optimize_onnx("averagepool_1d_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -207,7 +208,7 @@ TEST_CASE(averagepool_3d_test) ...@@ -207,7 +208,7 @@ TEST_CASE(averagepool_3d_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 5, 5, 5}}); auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 5, 5, 5}});
mm->add_instruction(migraphx::make_op("pooling", mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "average"}, {{"mode", migraphx::op::pooling_mode::average},
{"padding", {0, 0, 0, 0, 0, 0}}, {"padding", {0, 0, 0, 0, 0, 0}},
{"stride", {1, 1, 1}}, {"stride", {1, 1, 1}},
{"lengths", {3, 3, 3}}}), {"lengths", {3, 3, 3}}}),
...@@ -223,7 +224,7 @@ TEST_CASE(averagepool_notset_test) ...@@ -223,7 +224,7 @@ TEST_CASE(averagepool_notset_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}}); auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
auto ins = mm->add_instruction(migraphx::make_op("pooling", auto ins = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "average"}, {{"mode", migraphx::op::pooling_mode::average},
{"padding", {2, 2, 2, 2}}, {"padding", {2, 2, 2, 2}},
{"stride", {2, 2}}, {"stride", {2, 2}},
{"lengths", {6, 6}}}), {"lengths", {6, 6}}}),
...@@ -244,7 +245,7 @@ TEST_CASE(averagepool_nt_cip_test) ...@@ -244,7 +245,7 @@ TEST_CASE(averagepool_nt_cip_test)
std::vector<int64_t> pads = {0, 0, 0, 0, 0, 0, 1, 1}; std::vector<int64_t> pads = {0, 0, 0, 0, 0, 0, 1, 1};
auto ins_pad = mm->add_instruction(migraphx::make_op("pad", {{"pads", pads}}), input); auto ins_pad = mm->add_instruction(migraphx::make_op("pad", {{"pads", pads}}), input);
auto ret = mm->add_instruction(migraphx::make_op("pooling", auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "average"}, {{"mode", migraphx::op::pooling_mode::average},
{"padding", {0, 0, 0, 0}}, {"padding", {0, 0, 0, 0}},
{"stride", {2, 2}}, {"stride", {2, 2}},
{"lengths", {6, 6}}}), {"lengths", {6, 6}}}),
...@@ -261,7 +262,7 @@ TEST_CASE(averagepool_same_lower_test) ...@@ -261,7 +262,7 @@ TEST_CASE(averagepool_same_lower_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}}); auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
auto ins = mm->add_instruction(migraphx::make_op("pooling", auto ins = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "average"}, {{"mode", migraphx::op::pooling_mode::average},
{"padding", {1, 1, 1, 1}}, {"padding", {1, 1, 1, 1}},
{"stride", {1, 1}}, {"stride", {1, 1}},
{"lengths", {2, 2}}}), {"lengths", {2, 2}}}),
...@@ -282,7 +283,7 @@ TEST_CASE(averagepool_sl_cip_test) ...@@ -282,7 +283,7 @@ TEST_CASE(averagepool_sl_cip_test)
std::vector<int64_t> pads = {0, 0, 1, 1, 0, 0, 0, 0}; std::vector<int64_t> pads = {0, 0, 1, 1, 0, 0, 0, 0};
auto ins_pad = mm->add_instruction(migraphx::make_op("pad", {{"pads", pads}}), input); auto ins_pad = mm->add_instruction(migraphx::make_op("pad", {{"pads", pads}}), input);
auto ret = mm->add_instruction(migraphx::make_op("pooling", auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "average"}, {{"mode", migraphx::op::pooling_mode::average},
{"padding", {0, 0, 0, 0}}, {"padding", {0, 0, 0, 0}},
{"stride", {1, 1}}, {"stride", {1, 1}},
{"lengths", {2, 2}}}), {"lengths", {2, 2}}}),
...@@ -299,7 +300,7 @@ TEST_CASE(averagepool_same_upper_test) ...@@ -299,7 +300,7 @@ TEST_CASE(averagepool_same_upper_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}}); auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
auto ins = mm->add_instruction(migraphx::make_op("pooling", auto ins = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "average"}, {{"mode", migraphx::op::pooling_mode::average},
{"padding", {1, 1, 1, 1}}, {"padding", {1, 1, 1, 1}},
{"stride", {1, 1}}, {"stride", {1, 1}},
{"lengths", {2, 2}}}), {"lengths", {2, 2}}}),
...@@ -669,11 +670,12 @@ TEST_CASE(conv_bn_relu_maxpool_test) ...@@ -669,11 +670,12 @@ TEST_CASE(conv_bn_relu_maxpool_test)
auto l6 = mm->add_instruction( auto l6 = mm->add_instruction(
migraphx::make_op("batch_norm_inference", {{"epsilon", 1.0e-5f}}), l5, p3, p4, p5, p6); migraphx::make_op("batch_norm_inference", {{"epsilon", 1.0e-5f}}), l5, p3, p4, p5, p6);
auto l7 = mm->add_instruction(migraphx::make_op("relu"), l6); auto l7 = mm->add_instruction(migraphx::make_op("relu"), l6);
mm->add_instruction( mm->add_instruction(migraphx::make_op("pooling",
migraphx::make_op( {{"mode", migraphx::op::pooling_mode::max},
"pooling", {"padding", {0, 0, 0, 0}},
{{"mode", "max"}, {"padding", {0, 0, 0, 0}}, {"stride", {2, 2}}, {"lengths", {2, 2}}}), {"stride", {2, 2}},
l7); {"lengths", {2, 2}}}),
l7);
auto prog = optimize_onnx("conv_bn_relu_maxpool_test.onnx"); auto prog = optimize_onnx("conv_bn_relu_maxpool_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -693,11 +695,12 @@ TEST_CASE(conv_relu_maxpool_test) ...@@ -693,11 +695,12 @@ TEST_CASE(conv_relu_maxpool_test)
migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l3->get_shape().lens()}}), l2); migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l3->get_shape().lens()}}), l2);
auto l5 = mm->add_instruction(migraphx::make_op("add"), l3, l4); auto l5 = mm->add_instruction(migraphx::make_op("add"), l3, l4);
auto l6 = mm->add_instruction(migraphx::make_op("relu"), l5); auto l6 = mm->add_instruction(migraphx::make_op("relu"), l5);
mm->add_instruction( mm->add_instruction(migraphx::make_op("pooling",
migraphx::make_op( {{"mode", migraphx::op::pooling_mode::max},
"pooling", {"padding", {0, 0, 0, 0}},
{{"mode", "max"}, {"padding", {0, 0, 0, 0}}, {"stride", {2, 2}}, {"lengths", {2, 2}}}), {"stride", {2, 2}},
l6); {"lengths", {2, 2}}}),
l6);
auto prog = optimize_onnx("conv_relu_maxpool_test.onnx"); auto prog = optimize_onnx("conv_relu_maxpool_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -717,11 +720,12 @@ TEST_CASE(conv_relu_maxpool_x2_test) ...@@ -717,11 +720,12 @@ TEST_CASE(conv_relu_maxpool_x2_test)
migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l3->get_shape().lens()}}), l2); migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l3->get_shape().lens()}}), l2);
auto l5 = mm->add_instruction(migraphx::make_op("add"), l3, l4); auto l5 = mm->add_instruction(migraphx::make_op("add"), l3, l4);
auto l6 = mm->add_instruction(migraphx::make_op("relu"), l5); auto l6 = mm->add_instruction(migraphx::make_op("relu"), l5);
auto l7 = mm->add_instruction( auto l7 = mm->add_instruction(migraphx::make_op("pooling",
migraphx::make_op( {{"mode", migraphx::op::pooling_mode::max},
"pooling", {"padding", {0, 0, 0, 0}},
{{"mode", "max"}, {"padding", {0, 0, 0, 0}}, {"stride", {2, 2}}, {"lengths", {2, 2}}}), {"stride", {2, 2}},
l6); {"lengths", {2, 2}}}),
l6);
auto l8 = mm->add_parameter("3", {migraphx::shape::float_type, {1, 5, 5, 5}}); auto l8 = mm->add_parameter("3", {migraphx::shape::float_type, {1, 5, 5, 5}});
auto l9 = mm->add_parameter("4", {migraphx::shape::float_type, {1}}); auto l9 = mm->add_parameter("4", {migraphx::shape::float_type, {1}});
...@@ -732,11 +736,12 @@ TEST_CASE(conv_relu_maxpool_x2_test) ...@@ -732,11 +736,12 @@ TEST_CASE(conv_relu_maxpool_x2_test)
l9); l9);
auto l12 = mm->add_instruction(migraphx::make_op("add"), l10, l11); auto l12 = mm->add_instruction(migraphx::make_op("add"), l10, l11);
auto l13 = mm->add_instruction(migraphx::make_op("relu"), l12); auto l13 = mm->add_instruction(migraphx::make_op("relu"), l12);
mm->add_instruction( mm->add_instruction(migraphx::make_op("pooling",
migraphx::make_op( {{"mode", migraphx::op::pooling_mode::max},
"pooling", {"padding", {0, 0, 0, 0}},
{{"mode", "max"}, {"padding", {0, 0, 0, 0}}, {"stride", {2, 2}}, {"lengths", {2, 2}}}), {"stride", {2, 2}},
l13); {"lengths", {2, 2}}}),
l13);
auto prog = optimize_onnx("conv_relu_maxpool_x2_test.onnx"); auto prog = optimize_onnx("conv_relu_maxpool_x2_test.onnx");
...@@ -1481,7 +1486,7 @@ TEST_CASE(globalavgpool_test) ...@@ -1481,7 +1486,7 @@ TEST_CASE(globalavgpool_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = auto input =
mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto op = migraphx::op::pooling{"average"}; auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
auto lens = input->get_shape().lens(); auto lens = input->get_shape().lens();
op.lengths = {lens[2], lens[3]}; op.lengths = {lens[2], lens[3]};
op.padding = {0, 0, 0, 0}; op.padding = {0, 0, 0, 0};
...@@ -1498,7 +1503,7 @@ TEST_CASE(globalmaxpool_test) ...@@ -1498,7 +1503,7 @@ TEST_CASE(globalmaxpool_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = auto input =
mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto op = migraphx::op::pooling{"max"}; auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
auto lens = input->get_shape().lens(); auto lens = input->get_shape().lens();
op.lengths = {lens[2], lens[3]}; op.lengths = {lens[2], lens[3]};
op.padding = {0, 0, 0, 0}; op.padding = {0, 0, 0, 0};
...@@ -2491,11 +2496,12 @@ TEST_CASE(maxpool_notset_test) ...@@ -2491,11 +2496,12 @@ TEST_CASE(maxpool_notset_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}}); auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
mm->add_instruction( mm->add_instruction(migraphx::make_op("pooling",
migraphx::make_op( {{"mode", migraphx::op::pooling_mode::max},
"pooling", {"padding", {0, 0, 1, 1}},
{{"mode", "max"}, {"padding", {0, 0, 1, 1}}, {"stride", {2, 2}}, {"lengths", {6, 6}}}), {"stride", {2, 2}},
input); {"lengths", {6, 6}}}),
input);
auto prog = optimize_onnx("maxpool_notset_test.onnx"); auto prog = optimize_onnx("maxpool_notset_test.onnx");
...@@ -2507,11 +2513,12 @@ TEST_CASE(maxpool_same_upper_test) ...@@ -2507,11 +2513,12 @@ TEST_CASE(maxpool_same_upper_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}}); auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
mm->add_instruction( mm->add_instruction(migraphx::make_op("pooling",
migraphx::make_op( {{"mode", migraphx::op::pooling_mode::max},
"pooling", {"padding", {0, 0, 1, 1}},
{{"mode", "max"}, {"padding", {0, 0, 1, 1}}, {"stride", {1, 1}}, {"lengths", {2, 2}}}), {"stride", {1, 1}},
input); {"lengths", {2, 2}}}),
input);
auto prog = optimize_onnx("maxpool_same_upper_test.onnx"); auto prog = optimize_onnx("maxpool_same_upper_test.onnx");
......
...@@ -570,10 +570,12 @@ TEST_CASE(inconsistent_attr_shape) ...@@ -570,10 +570,12 @@ TEST_CASE(inconsistent_attr_shape)
{{"padding", {1, 1}}, {"stride", {2}}, {"dilation", {3, 3, 3}}}), {{"padding", {1, 1}}, {"stride", {2}}, {"dilation", {3, 3, 3}}}),
input, input,
weights); weights);
throws_shape( throws_shape(migraphx::make_op("pooling",
migraphx::make_op( {{"mode", migraphx::op::pooling_mode::max},
"pooling", {{"mode", "max"}, {"padding", {1}}, {"stride", {0}}, {"lengths", {1, 1}}}), {"padding", {1}},
input); {"stride", {0}},
{"lengths", {1, 1}}}),
input);
} }
template <class T> template <class T>
...@@ -983,21 +985,24 @@ TEST_CASE(pooling_shape) ...@@ -983,21 +985,24 @@ TEST_CASE(pooling_shape)
{ {
migraphx::shape output{migraphx::shape::float_type, {4, 3, 1, 1}}; migraphx::shape output{migraphx::shape::float_type, {4, 3, 1, 1}};
migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}}; migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
throws_shape( throws_shape(migraphx::make_op("pooling",
migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::max},
{{"mode", "max"}, {"padding", {1}}, {"stride", {0}}, {"lengths", {1}}}), {"padding", {1}},
input); {"stride", {0}},
expect_shape( {"lengths", {1}}}),
output, input);
migraphx::make_op( expect_shape(output,
"pooling", migraphx::make_op("pooling",
{{"mode", "max"}, {"padding", {0, 0}}, {"stride", {3, 3}}, {"lengths", {1, 1}}}), {{"mode", migraphx::op::pooling_mode::max},
input); {"padding", {0, 0}},
{"stride", {3, 3}},
{"lengths", {1, 1}}}),
input);
migraphx::shape output1{migraphx::shape::float_type, {4, 3, 2, 2}}; migraphx::shape output1{migraphx::shape::float_type, {4, 3, 2, 2}};
expect_shape(output1, expect_shape(output1,
migraphx::make_op("pooling", migraphx::make_op("pooling",
{{"mode", "max"}, {{"mode", migraphx::op::pooling_mode::max},
{"padding", {0, 0}}, {"padding", {0, 0}},
{"stride", {3, 3}}, {"stride", {3, 3}},
{"lengths", {1, 1}}, {"lengths", {1, 1}},
......
...@@ -370,7 +370,7 @@ TEST_CASE(avgpool_test) ...@@ -370,7 +370,7 @@ TEST_CASE(avgpool_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}};
auto op = migraphx::op::pooling{"average"}; auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
op.lengths = {2}; op.lengths = {2};
op.padding = {0}; op.padding = {0};
op.stride = {1}; op.stride = {1};
...@@ -392,7 +392,7 @@ TEST_CASE(avgpool_test) ...@@ -392,7 +392,7 @@ TEST_CASE(avgpool_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 4}}; auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 4}};
auto op = migraphx::op::pooling{"average"}; auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
op.lengths = {2}; op.lengths = {2};
op.padding = {1}; op.padding = {1};
op.stride = {2}; op.stride = {2};
...@@ -439,7 +439,7 @@ TEST_CASE(avgpool_test) ...@@ -439,7 +439,7 @@ TEST_CASE(avgpool_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 3, 3}}; auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 3, 3}};
auto op = migraphx::op::pooling{"average"}; auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
op.lengths = {2, 2, 2}; op.lengths = {2, 2, 2};
op.padding = {0, 0, 0}; op.padding = {0, 0, 0};
op.stride = {1, 1, 1}; op.stride = {1, 1, 1};
...@@ -1658,7 +1658,7 @@ TEST_CASE(globalavgpool_test) ...@@ -1658,7 +1658,7 @@ TEST_CASE(globalavgpool_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}};
auto op = migraphx::op::pooling{"average"}; auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
auto lens = s.lens(); auto lens = s.lens();
op.lengths = {lens[2], lens[3]}; op.lengths = {lens[2], lens[3]};
...@@ -1679,7 +1679,7 @@ TEST_CASE(globalmaxpool_test) ...@@ -1679,7 +1679,7 @@ TEST_CASE(globalmaxpool_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}};
auto op = migraphx::op::pooling{"max"}; auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
auto lens = s.lens(); auto lens = s.lens();
op.lengths = {lens[2], lens[3]}; op.lengths = {lens[2], lens[3]};
...@@ -2582,11 +2582,12 @@ TEST_CASE(maxpool_test) ...@@ -2582,11 +2582,12 @@ TEST_CASE(maxpool_test)
0.52119428, 2.07681108, 0.88494766, 1.51522756, 0.54275119, 0.6629802}; 0.52119428, 2.07681108, 0.88494766, 1.51522756, 0.54275119, 0.6629802};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 6, 6}}; migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 6, 6}};
auto al = mm->add_literal(migraphx::literal{a_shape, a}); auto al = mm->add_literal(migraphx::literal{a_shape, a});
mm->add_instruction( mm->add_instruction(migraphx::make_op("pooling",
migraphx::make_op( {{"mode", migraphx::op::pooling_mode::max},
"pooling", {"padding", {0, 0}},
{{"mode", "max"}, {"padding", {0, 0}}, {"stride", {2, 2}}, {"lengths", {3, 2}}}), {"stride", {2, 2}},
al); {"lengths", {3, 2}}}),
al);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector(36); std::vector<float> results_vector(36);
...@@ -2601,7 +2602,7 @@ TEST_CASE(maxpool_test_1D_3D) ...@@ -2601,7 +2602,7 @@ TEST_CASE(maxpool_test_1D_3D)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}};
auto op = migraphx::op::pooling{"max"}; auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
op.lengths = {2}; op.lengths = {2};
op.padding = {0}; op.padding = {0};
op.stride = {1}; op.stride = {1};
...@@ -2623,7 +2624,7 @@ TEST_CASE(maxpool_test_1D_3D) ...@@ -2623,7 +2624,7 @@ TEST_CASE(maxpool_test_1D_3D)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 5}}; auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 5}};
auto op = migraphx::op::pooling{"max"}; auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
op.lengths = {2}; op.lengths = {2};
op.padding = {0}; op.padding = {0};
op.stride = {2}; op.stride = {2};
...@@ -2647,7 +2648,7 @@ TEST_CASE(maxpool_test_1D_3D) ...@@ -2647,7 +2648,7 @@ TEST_CASE(maxpool_test_1D_3D)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 5}}; auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 5}};
auto op = migraphx::op::pooling{"max"}; auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
op.lengths = {2}; op.lengths = {2};
op.padding = {0}; op.padding = {0};
op.stride = {2}; op.stride = {2};
...@@ -2683,7 +2684,7 @@ TEST_CASE(maxpool_test_1D_3D) ...@@ -2683,7 +2684,7 @@ TEST_CASE(maxpool_test_1D_3D)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 3, 3}}; auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 3, 3}};
auto op = migraphx::op::pooling{"max"}; auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
op.lengths = {2, 2, 2}; op.lengths = {2, 2, 2};
op.padding = {0, 0, 0}; op.padding = {0, 0, 0};
op.stride = {2, 2, 2}; op.stride = {2, 2, 2};
...@@ -4037,9 +4038,10 @@ TEST_CASE(roialign_out_of_bound_test) ...@@ -4037,9 +4038,10 @@ TEST_CASE(roialign_out_of_bound_test)
TEST_CASE(roialign_test) TEST_CASE(roialign_test)
{ {
auto create_program = [](const std::string& trans_mode = "half_pixel", auto create_program = [](const std::string& trans_mode = "half_pixel",
const std::string& pooling_mode = "avg", const migraphx::op::pooling_mode pooling_mode =
int64_t sampling_ratio = 2) { migraphx::op::pooling_mode::average,
int64_t sampling_ratio = 2) {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape x_s{migraphx::shape::float_type, {1, 1, 10, 10}}; migraphx::shape x_s{migraphx::shape::float_type, {1, 1, 10, 10}};
...@@ -4125,7 +4127,7 @@ TEST_CASE(roialign_test) ...@@ -4125,7 +4127,7 @@ TEST_CASE(roialign_test)
} }
{ {
auto p = create_program("output_half_pixel", "max", 0); auto p = create_program("output_half_pixel", migraphx::op::pooling_mode::max, 0);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector; std::vector<float> results_vector;
......
#include <migraphx/rewrite_pooling.hpp> #include <migraphx/rewrite_pooling.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/ref/target.hpp>
...@@ -22,7 +23,7 @@ static void opt_pooling(migraphx::module& m) ...@@ -22,7 +23,7 @@ static void opt_pooling(migraphx::module& m)
TEST_CASE(rewrite_pooling_test) TEST_CASE(rewrite_pooling_test)
{ {
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}}; migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}};
auto pooling_program = [&](const std::string& mode) { auto pooling_program = [&](const migraphx::op::pooling_mode mode) {
migraphx::module m; migraphx::module m;
auto input = m.add_parameter("x", s); auto input = m.add_parameter("x", s);
auto ret = m.add_instruction(migraphx::make_op("pooling", auto ret = m.add_instruction(migraphx::make_op("pooling",
...@@ -46,15 +47,16 @@ TEST_CASE(rewrite_pooling_test) ...@@ -46,15 +47,16 @@ TEST_CASE(rewrite_pooling_test)
return m; return m;
}; };
auto test_rewrite = [&](const std::string& mode, const migraphx::operation& op) { auto test_rewrite = [&](const migraphx::op::pooling_mode mode, const migraphx::operation& op) {
migraphx::module m1 = pooling_program(mode); migraphx::module m1 = pooling_program(mode);
migraphx::module m2 = opt_program(op); migraphx::module m2 = opt_program(op);
opt_pooling(m1); opt_pooling(m1);
EXPECT(m1 == m2); EXPECT(m1 == m2);
}; };
test_rewrite("average", migraphx::make_op("reduce_mean", {{"axes", {1}}})); test_rewrite(migraphx::op::pooling_mode::average,
test_rewrite("max", migraphx::make_op("reduce_max", {{"axes", {1}}})); migraphx::make_op("reduce_mean", {{"axes", {1}}}));
test_rewrite(migraphx::op::pooling_mode::max, migraphx::make_op("reduce_max", {{"axes", {1}}}));
} }
TEST_CASE(rewrite_avepooling_na1_test) TEST_CASE(rewrite_avepooling_na1_test)
...@@ -64,12 +66,13 @@ TEST_CASE(rewrite_avepooling_na1_test) ...@@ -64,12 +66,13 @@ TEST_CASE(rewrite_avepooling_na1_test)
migraphx::module m; migraphx::module m;
auto input = m.add_parameter("x", s); auto input = m.add_parameter("x", s);
auto ret = m.add_instruction(migraphx::make_op("pooling", auto ret =
{{"mode", "average"}, m.add_instruction(migraphx::make_op("pooling",
{"padding", {0, 1, 0}}, {{"mode", migraphx::op::pooling_mode::average},
{"stride", {1, 1, 1}}, {"padding", {0, 1, 0}},
{"lengths", {3, 4, 5}}}), {"stride", {1, 1, 1}},
input); {"lengths", {3, 4, 5}}}),
input);
m.add_return({ret}); m.add_return({ret});
return m; return m;
}; };
...@@ -88,12 +91,13 @@ TEST_CASE(rewrite_avepooling_na2_test) ...@@ -88,12 +91,13 @@ TEST_CASE(rewrite_avepooling_na2_test)
migraphx::module m; migraphx::module m;
auto input = m.add_parameter("x", s); auto input = m.add_parameter("x", s);
auto ret = m.add_instruction(migraphx::make_op("pooling", auto ret =
{{"mode", "average"}, m.add_instruction(migraphx::make_op("pooling",
{"padding", {0, 0, 0}}, {{"mode", migraphx::op::pooling_mode::average},
{"stride", {1, 2, 1}}, {"padding", {0, 0, 0}},
{"lengths", {3, 4, 5}}}), {"stride", {1, 2, 1}},
input); {"lengths", {3, 4, 5}}}),
input);
m.add_return({ret}); m.add_return({ret});
return m; return m;
}; };
...@@ -113,7 +117,7 @@ TEST_CASE(rewrite_avepooling_na3_test) ...@@ -113,7 +117,7 @@ TEST_CASE(rewrite_avepooling_na3_test)
auto input = m.add_parameter("x", s); auto input = m.add_parameter("x", s);
auto ret = m.add_instruction(migraphx::make_op("pooling", auto ret = m.add_instruction(migraphx::make_op("pooling",
{{"mode", "max"}, {{"mode", migraphx::op::pooling_mode::max},
{"padding", {0, 0, 0}}, {"padding", {0, 0, 0}},
{"stride", {1, 1, 1}}, {"stride", {1, 1, 1}},
{"lengths", {3, 3, 5}}}), {"lengths", {3, 3, 5}}}),
...@@ -135,7 +139,7 @@ TEST_CASE(literal_rewrite_pooling_test) ...@@ -135,7 +139,7 @@ TEST_CASE(literal_rewrite_pooling_test)
std::vector<float> data(s.elements()); std::vector<float> data(s.elements());
std::iota(data.begin(), data.end(), 1.0f); std::iota(data.begin(), data.end(), 1.0f);
auto pooling_program = [&](const std::string& mode) { auto pooling_program = [&](const migraphx::op::pooling_mode mode) {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
...@@ -163,7 +167,8 @@ TEST_CASE(literal_rewrite_pooling_test) ...@@ -163,7 +167,8 @@ TEST_CASE(literal_rewrite_pooling_test)
return p; return p;
}; };
auto test_rewrite_pooling = [&](const std::string& mode, const migraphx::operation& op) { auto test_rewrite_pooling = [&](const migraphx::op::pooling_mode mode,
const migraphx::operation& op) {
migraphx::program p1 = pooling_program(mode); migraphx::program p1 = pooling_program(mode);
migraphx::program p2 = opt_program(op); migraphx::program p2 = opt_program(op);
p1.compile(migraphx::ref::target{}); p1.compile(migraphx::ref::target{});
...@@ -174,8 +179,10 @@ TEST_CASE(literal_rewrite_pooling_test) ...@@ -174,8 +179,10 @@ TEST_CASE(literal_rewrite_pooling_test)
result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); }); result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); });
}; };
test_rewrite_pooling("max", migraphx::make_op("reduce_max", {{"axes", {1}}})); test_rewrite_pooling(migraphx::op::pooling_mode::max,
test_rewrite_pooling("average", migraphx::make_op("reduce_mean", {{"axes", {1}}})); migraphx::make_op("reduce_max", {{"axes", {1}}}));
test_rewrite_pooling(migraphx::op::pooling_mode::average,
migraphx::make_op("reduce_mean", {{"axes", {1}}}));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <test.hpp> #include <test.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/matcher.hpp> #include <migraphx/matcher.hpp>
...@@ -462,14 +463,15 @@ TEST_CASE(conv_pooling_dot) ...@@ -462,14 +463,15 @@ TEST_CASE(conv_pooling_dot)
d1); d1);
auto bc1 = m1.add_instruction( auto bc1 = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m1.add_instruction(migraphx::make_op("add"), c1, bc1); auto a1 = m1.add_instruction(migraphx::make_op("add"), c1, bc1);
auto ap = m1.add_instruction(migraphx::make_op("pooling", auto ap =
{{"mode", "average"}, m1.add_instruction(migraphx::make_op("pooling",
{"padding", {0, 0, 0, 0}}, {{"mode", migraphx::op::pooling_mode::average},
{"stride", {1, 1}}, {"padding", {0, 0, 0, 0}},
{"lengths", {7, 7}}, {"stride", {1, 1}},
{"ceil_mode", 0}}), {"lengths", {7, 7}},
a1); {"ceil_mode", 0}}),
a1);
auto fl = m1.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap); auto fl = m1.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap);
auto q4 = add_quantize_op(m1, "quantizelinear", fl, scale, zero); auto q4 = add_quantize_op(m1, "quantizelinear", fl, scale, zero);
auto d8 = add_quantize_op(m1, "dequantizelinear", q4, scale, zero); auto d8 = add_quantize_op(m1, "dequantizelinear", q4, scale, zero);
...@@ -508,14 +510,15 @@ TEST_CASE(conv_pooling_dot) ...@@ -508,14 +510,15 @@ TEST_CASE(conv_pooling_dot)
auto d5 = add_quantize_op(m2, "dequantizelinear", c1, scale1); auto d5 = add_quantize_op(m2, "dequantizelinear", c1, scale1);
auto bc1 = m2.add_instruction( auto bc1 = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m2.add_instruction(migraphx::make_op("add"), d5, bc1); auto a1 = m2.add_instruction(migraphx::make_op("add"), d5, bc1);
auto ap = m2.add_instruction(migraphx::make_op("pooling", auto ap =
{{"mode", "average"}, m2.add_instruction(migraphx::make_op("pooling",
{"padding", {0, 0, 0, 0}}, {{"mode", migraphx::op::pooling_mode::average},
{"stride", {1, 1}}, {"padding", {0, 0, 0, 0}},
{"lengths", {7, 7}}, {"stride", {1, 1}},
{"ceil_mode", 0}}), {"lengths", {7, 7}},
a1); {"ceil_mode", 0}}),
a1);
auto fl = m2.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap); auto fl = m2.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap);
auto q4 = add_quantize_op(m2, "quantizelinear", fl, scale, zero); auto q4 = add_quantize_op(m2, "quantizelinear", fl, scale, zero);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q4, db); auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q4, db);
...@@ -564,16 +567,17 @@ TEST_CASE(mobilenet_snippet) ...@@ -564,16 +567,17 @@ TEST_CASE(mobilenet_snippet)
d1); d1);
auto bc1 = mm.add_instruction( auto bc1 = mm.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = mm.add_instruction(migraphx::make_op("add"), c1, bc1); auto a1 = mm.add_instruction(migraphx::make_op("add"), c1, bc1);
auto q2 = add_quantize_op(mm, "quantizelinear", a1, scale, zero); auto q2 = add_quantize_op(mm, "quantizelinear", a1, scale, zero);
auto d6 = add_quantize_op(mm, "dequantizelinear", q2, scale, zero); auto d6 = add_quantize_op(mm, "dequantizelinear", q2, scale, zero);
auto ap = mm.add_instruction(migraphx::make_op("pooling", auto ap =
{{"mode", "average"}, mm.add_instruction(migraphx::make_op("pooling",
{"padding", {0, 0, 0, 0}}, {{"mode", migraphx::op::pooling_mode::average},
{"stride", {1, 1}}, {"padding", {0, 0, 0, 0}},
{"lengths", {7, 7}}, {"stride", {1, 1}},
{"ceil_mode", 0}}), {"lengths", {7, 7}},
d6); {"ceil_mode", 0}}),
d6);
auto q3 = add_quantize_op(mm, "quantizelinear", ap, scale, zero); auto q3 = add_quantize_op(mm, "quantizelinear", ap, scale, zero);
auto d7 = add_quantize_op(mm, "dequantizelinear", q3, scale, zero); auto d7 = add_quantize_op(mm, "dequantizelinear", q3, scale, zero);
auto rs = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1}}}), d7); auto rs = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1}}}), d7);
......
...@@ -649,8 +649,8 @@ TEST_CASE(pooling_test) ...@@ -649,8 +649,8 @@ TEST_CASE(pooling_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
migraphx::op::pooling avg_pool_op{"average"}; migraphx::op::pooling avg_pool_op{migraphx::op::pooling_mode::average};
migraphx::op::pooling max_pool_op{"max"}; migraphx::op::pooling max_pool_op{migraphx::op::pooling_mode::max};
avg_pool_op.stride = {2, 2}; avg_pool_op.stride = {2, 2};
max_pool_op.stride = {2, 2}; max_pool_op.stride = {2, 2};
avg_pool_op.lengths = {2, 2}; avg_pool_op.lengths = {2, 2};
......
...@@ -12,7 +12,7 @@ struct test_avg_pooling_1d : verify_program<test_avg_pooling_1d> ...@@ -12,7 +12,7 @@ struct test_avg_pooling_1d : verify_program<test_avg_pooling_1d>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 5}}); mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 5}});
auto op = migraphx::op::pooling{"average", {0}, {1}, {3}}; auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average, {0}, {1}, {3}};
mm->add_instruction(op, input); mm->add_instruction(op, input);
return p; return p;
} }
......
...@@ -12,7 +12,8 @@ struct test_avg_pooling_3d : verify_program<test_avg_pooling_3d> ...@@ -12,7 +12,8 @@ struct test_avg_pooling_3d : verify_program<test_avg_pooling_3d>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 5, 5, 5}}); mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 5, 5, 5}});
auto op = migraphx::op::pooling{"average", {1, 1, 1}, {3, 3, 3}, {3, 3, 3}}; auto op = migraphx::op::pooling{
migraphx::op::pooling_mode::average, {1, 1, 1}, {3, 3, 3}, {3, 3, 3}};
mm->add_instruction(op, input); mm->add_instruction(op, input);
return p; return p;
} }
......
...@@ -12,7 +12,8 @@ struct test_avg_pooling_3d_opt : verify_program<test_avg_pooling_3d_opt> ...@@ -12,7 +12,8 @@ struct test_avg_pooling_3d_opt : verify_program<test_avg_pooling_3d_opt>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 2, 3, 3, 3}}); mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 2, 3, 3, 3}});
auto op = migraphx::op::pooling{"average", {0, 0, 0}, {1, 1, 1}, {3, 3, 3}}; auto op = migraphx::op::pooling{
migraphx::op::pooling_mode::average, {0, 0, 0}, {1, 1, 1}, {3, 3, 3}};
mm->add_instruction(op, input); mm->add_instruction(op, input);
return p; return p;
} }
......
...@@ -10,9 +10,11 @@ struct test_avg_pooling_ceil_3d : verify_program<test_avg_pooling_ceil_3d> ...@@ -10,9 +10,11 @@ struct test_avg_pooling_ceil_3d : verify_program<test_avg_pooling_ceil_3d>
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 5, 5, 5}}); mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 5, 5, 5}});
auto op = migraphx::op::pooling{"average", {1, 1, 1}, {3, 3, 3}, {3, 3, 3}, true}; auto op = migraphx::op::pooling{
migraphx::op::pooling_mode::average, {1, 1, 1}, {3, 3, 3}, {3, 3, 3}, true};
mm->add_instruction(op, input); mm->add_instruction(op, input);
return p; return p;
} }
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/pooling.hpp>
struct test_concat_pooling : verify_program<test_concat_pooling> struct test_concat_pooling : verify_program<test_concat_pooling>
{ {
...@@ -18,12 +19,13 @@ struct test_concat_pooling : verify_program<test_concat_pooling> ...@@ -18,12 +19,13 @@ struct test_concat_pooling : verify_program<test_concat_pooling>
auto concat_t = mm->add_instruction( auto concat_t = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), concat); migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), concat);
auto pooling = mm->add_instruction(migraphx::make_op("pooling", auto pooling =
{{"mode", "average"}, mm->add_instruction(migraphx::make_op("pooling",
{"padding", {0, 0}}, {{"mode", migraphx::op::pooling_mode::average},
{"stride", {1, 1}}, {"padding", {0, 0}},
{"lengths", {8, 8}}}), {"stride", {1, 1}},
concat_t); {"lengths", {8, 8}}}),
concat_t);
mm->add_instruction(migraphx::make_op("relu"), pooling); mm->add_instruction(migraphx::make_op("relu"), pooling);
return p; return p;
} }
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling> struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling>
{ {
...@@ -29,7 +30,7 @@ struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling> ...@@ -29,7 +30,7 @@ struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling>
migraphx::make_op("batch_norm_inference"), conv, scale, bias, mean, variance); migraphx::make_op("batch_norm_inference"), conv, scale, bias, mean, variance);
auto relu = mm->add_instruction(migraphx::make_op("relu"), bn); auto relu = mm->add_instruction(migraphx::make_op("relu"), bn);
mm->add_instruction(migraphx::make_op("pooling", mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "average"}, {{"mode", migraphx::op::pooling_mode::average},
{"padding", {1, 1}}, {"padding", {1, 1}},
{"stride", {2, 2}}, {"stride", {2, 2}},
{"lengths", {3, 3}}}), {"lengths", {3, 3}}}),
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2> struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2>
{ {
...@@ -47,7 +48,7 @@ struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2> ...@@ -47,7 +48,7 @@ struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2>
auto add = mm->add_instruction(migraphx::make_op("add"), bn1, bn2); auto add = mm->add_instruction(migraphx::make_op("add"), bn1, bn2);
auto relu = mm->add_instruction(migraphx::make_op("relu"), add); auto relu = mm->add_instruction(migraphx::make_op("relu"), add);
mm->add_instruction(migraphx::make_op("pooling", mm->add_instruction(migraphx::make_op("pooling",
{{"mode", "average"}, {{"mode", migraphx::op::pooling_mode::average},
{"padding", {1, 1}}, {"padding", {1, 1}},
{"stride", {2, 2}}, {"stride", {2, 2}},
{"lengths", {3, 3}}}), {"lengths", {3, 3}}}),
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
struct test_conv_pooling : verify_program<test_conv_pooling> struct test_conv_pooling : verify_program<test_conv_pooling>
{ {
...@@ -15,7 +16,8 @@ struct test_conv_pooling : verify_program<test_conv_pooling> ...@@ -15,7 +16,8 @@ struct test_conv_pooling : verify_program<test_conv_pooling>
auto weights = auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
auto pooling = mm->add_instruction(migraphx::make_op("pooling", {{"mode", "max"}}), conv); auto pooling = mm->add_instruction(
migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::max}}), conv);
mm->add_instruction(migraphx::make_op("relu"), pooling); mm->add_instruction(migraphx::make_op("relu"), pooling);
return p; return p;
} }
......
...@@ -13,7 +13,7 @@ struct test_global_avg_pooling : verify_program<test_global_avg_pooling> ...@@ -13,7 +13,7 @@ struct test_global_avg_pooling : verify_program<test_global_avg_pooling>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto op = migraphx::op::pooling{"average"}; auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
auto lens = input->get_shape().lens(); auto lens = input->get_shape().lens();
op.lengths = {lens[2], lens[3]}; op.lengths = {lens[2], lens[3]};
mm->add_instruction(op, input); mm->add_instruction(op, input);
......
...@@ -13,7 +13,7 @@ struct test_global_max_pooling : verify_program<test_global_max_pooling> ...@@ -13,7 +13,7 @@ struct test_global_max_pooling : verify_program<test_global_max_pooling>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto op = migraphx::op::pooling{"max"}; auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max};
auto lens = input->get_shape().lens(); auto lens = input->get_shape().lens();
op.lengths = {lens[2], lens[3]}; op.lengths = {lens[2], lens[3]};
mm->add_instruction(op, input); mm->add_instruction(op, input);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment