Unverified Commit a2e90b5d authored by bpickrel's avatar bpickrel Committed by GitHub
Browse files

Mode as enum for pooling and roi_align (#1091)

Changed the pooling values for two structures from strings to specialized enum classes. Many test and operator parsing changes to support this. Introduces one new source file, op_enums.cpp.
parent d9d17a11
...@@ -38,6 +38,7 @@ add_library(migraphx ...@@ -38,6 +38,7 @@ add_library(migraphx
msgpack.cpp msgpack.cpp
normalize_attributes.cpp normalize_attributes.cpp
normalize_ops.cpp normalize_ops.cpp
op_enums.cpp
operation.cpp operation.cpp
opt/memory_coloring.cpp opt/memory_coloring.cpp
opt/memory_coloring_impl.cpp opt/memory_coloring_impl.cpp
......
...@@ -61,7 +61,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size) ...@@ -61,7 +61,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::relu relu19; migraphx::op::relu relu19;
auto mx19 = mm->add_instruction(relu19, mx18); auto mx19 = mm->add_instruction(relu19, mx18);
migraphx::op::pooling pooling20; migraphx::op::pooling pooling20;
pooling20.mode = "max"; pooling20.mode = migraphx::op::pooling_mode::max;
pooling20.padding = {0, 0}; pooling20.padding = {0, 0};
pooling20.stride = {2, 2}; pooling20.stride = {2, 2};
pooling20.lengths = {3, 3}; pooling20.lengths = {3, 3};
...@@ -81,7 +81,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size) ...@@ -81,7 +81,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::relu relu24; migraphx::op::relu relu24;
auto mx24 = mm->add_instruction(relu24, mx23); auto mx24 = mm->add_instruction(relu24, mx23);
migraphx::op::pooling pooling25; migraphx::op::pooling pooling25;
pooling25.mode = "max"; pooling25.mode = migraphx::op::pooling_mode::max;
pooling25.padding = {0, 0}; pooling25.padding = {0, 0};
pooling25.stride = {2, 2}; pooling25.stride = {2, 2};
pooling25.lengths = {3, 3}; pooling25.lengths = {3, 3};
...@@ -129,7 +129,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size) ...@@ -129,7 +129,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::relu relu37; migraphx::op::relu relu37;
auto mx37 = mm->add_instruction(relu37, mx36); auto mx37 = mm->add_instruction(relu37, mx36);
migraphx::op::pooling pooling38; migraphx::op::pooling pooling38;
pooling38.mode = "max"; pooling38.mode = migraphx::op::pooling_mode::max;
pooling38.padding = {0, 0}; pooling38.padding = {0, 0};
pooling38.stride = {2, 2}; pooling38.stride = {2, 2};
pooling38.lengths = {3, 3}; pooling38.lengths = {3, 3};
......
...@@ -995,7 +995,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -995,7 +995,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu492; migraphx::op::relu relu492;
auto mx492 = mm->add_instruction(relu492, mx491); auto mx492 = mm->add_instruction(relu492, mx491);
migraphx::op::pooling pooling493; migraphx::op::pooling pooling493;
pooling493.mode = "max"; pooling493.mode = migraphx::op::pooling_mode::max;
pooling493.padding = {0, 0}; pooling493.padding = {0, 0};
pooling493.stride = {2, 2}; pooling493.stride = {2, 2};
pooling493.lengths = {3, 3}; pooling493.lengths = {3, 3};
...@@ -1025,7 +1025,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -1025,7 +1025,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu499; migraphx::op::relu relu499;
auto mx499 = mm->add_instruction(relu499, mx498); auto mx499 = mm->add_instruction(relu499, mx498);
migraphx::op::pooling pooling500; migraphx::op::pooling pooling500;
pooling500.mode = "max"; pooling500.mode = migraphx::op::pooling_mode::max;
pooling500.padding = {0, 0}; pooling500.padding = {0, 0};
pooling500.stride = {2, 2}; pooling500.stride = {2, 2};
pooling500.lengths = {3, 3}; pooling500.lengths = {3, 3};
...@@ -1103,7 +1103,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -1103,7 +1103,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu518; migraphx::op::relu relu518;
auto mx518 = mm->add_instruction(relu518, mx517); auto mx518 = mm->add_instruction(relu518, mx517);
migraphx::op::pooling pooling519; migraphx::op::pooling pooling519;
pooling519.mode = "average"; pooling519.mode = migraphx::op::pooling_mode::average;
pooling519.padding = {1, 1}; pooling519.padding = {1, 1};
pooling519.stride = {1, 1}; pooling519.stride = {1, 1};
pooling519.lengths = {3, 3}; pooling519.lengths = {3, 3};
...@@ -1196,7 +1196,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -1196,7 +1196,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu541; migraphx::op::relu relu541;
auto mx541 = mm->add_instruction(relu541, mx540); auto mx541 = mm->add_instruction(relu541, mx540);
migraphx::op::pooling pooling542; migraphx::op::pooling pooling542;
pooling542.mode = "average"; pooling542.mode = migraphx::op::pooling_mode::average;
pooling542.padding = {1, 1}; pooling542.padding = {1, 1};
pooling542.stride = {1, 1}; pooling542.stride = {1, 1};
pooling542.lengths = {3, 3}; pooling542.lengths = {3, 3};
...@@ -1289,7 +1289,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -1289,7 +1289,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu564; migraphx::op::relu relu564;
auto mx564 = mm->add_instruction(relu564, mx563); auto mx564 = mm->add_instruction(relu564, mx563);
migraphx::op::pooling pooling565; migraphx::op::pooling pooling565;
pooling565.mode = "average"; pooling565.mode = migraphx::op::pooling_mode::average;
pooling565.padding = {1, 1}; pooling565.padding = {1, 1};
pooling565.stride = {1, 1}; pooling565.stride = {1, 1};
pooling565.lengths = {3, 3}; pooling565.lengths = {3, 3};
...@@ -1358,7 +1358,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -1358,7 +1358,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu581; migraphx::op::relu relu581;
auto mx581 = mm->add_instruction(relu581, mx580); auto mx581 = mm->add_instruction(relu581, mx580);
migraphx::op::pooling pooling582; migraphx::op::pooling pooling582;
pooling582.mode = "max"; pooling582.mode = migraphx::op::pooling_mode::max;
pooling582.padding = {0, 0}; pooling582.padding = {0, 0};
pooling582.stride = {2, 2}; pooling582.stride = {2, 2};
pooling582.lengths = {3, 3}; pooling582.lengths = {3, 3};
...@@ -1475,7 +1475,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -1475,7 +1475,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu610; migraphx::op::relu relu610;
auto mx610 = mm->add_instruction(relu610, mx609); auto mx610 = mm->add_instruction(relu610, mx609);
migraphx::op::pooling pooling611; migraphx::op::pooling pooling611;
pooling611.mode = "average"; pooling611.mode = migraphx::op::pooling_mode::average;
pooling611.padding = {1, 1}; pooling611.padding = {1, 1};
pooling611.stride = {1, 1}; pooling611.stride = {1, 1};
pooling611.lengths = {3, 3}; pooling611.lengths = {3, 3};
...@@ -1604,7 +1604,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -1604,7 +1604,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu642; migraphx::op::relu relu642;
auto mx642 = mm->add_instruction(relu642, mx641); auto mx642 = mm->add_instruction(relu642, mx641);
migraphx::op::pooling pooling643; migraphx::op::pooling pooling643;
pooling643.mode = "average"; pooling643.mode = migraphx::op::pooling_mode::average;
pooling643.padding = {1, 1}; pooling643.padding = {1, 1};
pooling643.stride = {1, 1}; pooling643.stride = {1, 1};
pooling643.lengths = {3, 3}; pooling643.lengths = {3, 3};
...@@ -1733,7 +1733,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -1733,7 +1733,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu674; migraphx::op::relu relu674;
auto mx674 = mm->add_instruction(relu674, mx673); auto mx674 = mm->add_instruction(relu674, mx673);
migraphx::op::pooling pooling675; migraphx::op::pooling pooling675;
pooling675.mode = "average"; pooling675.mode = migraphx::op::pooling_mode::average;
pooling675.padding = {1, 1}; pooling675.padding = {1, 1};
pooling675.stride = {1, 1}; pooling675.stride = {1, 1};
pooling675.lengths = {3, 3}; pooling675.lengths = {3, 3};
...@@ -1862,7 +1862,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -1862,7 +1862,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu706; migraphx::op::relu relu706;
auto mx706 = mm->add_instruction(relu706, mx705); auto mx706 = mm->add_instruction(relu706, mx705);
migraphx::op::pooling pooling707; migraphx::op::pooling pooling707;
pooling707.mode = "average"; pooling707.mode = migraphx::op::pooling_mode::average;
pooling707.padding = {1, 1}; pooling707.padding = {1, 1};
pooling707.stride = {1, 1}; pooling707.stride = {1, 1};
pooling707.lengths = {3, 3}; pooling707.lengths = {3, 3};
...@@ -1955,7 +1955,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -1955,7 +1955,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu729; migraphx::op::relu relu729;
auto mx729 = mm->add_instruction(relu729, mx728); auto mx729 = mm->add_instruction(relu729, mx728);
migraphx::op::pooling pooling730; migraphx::op::pooling pooling730;
pooling730.mode = "max"; pooling730.mode = migraphx::op::pooling_mode::max;
pooling730.padding = {0, 0}; pooling730.padding = {0, 0};
pooling730.stride = {2, 2}; pooling730.stride = {2, 2};
pooling730.lengths = {3, 3}; pooling730.lengths = {3, 3};
...@@ -2066,7 +2066,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -2066,7 +2066,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat757.axis = 1; concat757.axis = 1;
auto mx757 = mm->add_instruction(concat757, mx753, mx756); auto mx757 = mm->add_instruction(concat757, mx753, mx756);
migraphx::op::pooling pooling758; migraphx::op::pooling pooling758;
pooling758.mode = "average"; pooling758.mode = migraphx::op::pooling_mode::average;
pooling758.padding = {1, 1}; pooling758.padding = {1, 1};
pooling758.stride = {1, 1}; pooling758.stride = {1, 1};
pooling758.lengths = {3, 3}; pooling758.lengths = {3, 3};
...@@ -2189,7 +2189,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -2189,7 +2189,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat788.axis = 1; concat788.axis = 1;
auto mx788 = mm->add_instruction(concat788, mx784, mx787); auto mx788 = mm->add_instruction(concat788, mx784, mx787);
migraphx::op::pooling pooling789; migraphx::op::pooling pooling789;
pooling789.mode = "average"; pooling789.mode = migraphx::op::pooling_mode::average;
pooling789.padding = {1, 1}; pooling789.padding = {1, 1};
pooling789.stride = {1, 1}; pooling789.stride = {1, 1};
pooling789.lengths = {3, 3}; pooling789.lengths = {3, 3};
...@@ -2210,7 +2210,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -2210,7 +2210,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat793.axis = 1; concat793.axis = 1;
auto mx793 = mm->add_instruction(concat793, mx765, mx775, mx788, mx792); auto mx793 = mm->add_instruction(concat793, mx765, mx775, mx788, mx792);
migraphx::op::pooling pooling794; migraphx::op::pooling pooling794;
pooling794.mode = "average"; pooling794.mode = migraphx::op::pooling_mode::average;
pooling794.padding = {0, 0}; pooling794.padding = {0, 0};
pooling794.stride = {8, 8}; pooling794.stride = {8, 8};
pooling794.lengths = {8, 8}; pooling794.lengths = {8, 8};
......
...@@ -87,6 +87,6 @@ target get_target(bool gpu) ...@@ -87,6 +87,6 @@ target get_target(bool gpu)
void compile_program(program& p, bool gpu) { p.compile(get_target(gpu)); } void compile_program(program& p, bool gpu) { p.compile(get_target(gpu)); }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace driver } // namespace driver
} // namespace migraphx } // namespace migraphx
...@@ -561,7 +561,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size) ...@@ -561,7 +561,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::relu relu269; migraphx::op::relu relu269;
auto mx269 = mm->add_instruction(relu269, mx268); auto mx269 = mm->add_instruction(relu269, mx268);
migraphx::op::pooling pooling270; migraphx::op::pooling pooling270;
pooling270.mode = "max"; pooling270.mode = migraphx::op::pooling_mode::max;
pooling270.padding = {1, 1}; pooling270.padding = {1, 1};
pooling270.stride = {2, 2}; pooling270.stride = {2, 2};
pooling270.lengths = {3, 3}; pooling270.lengths = {3, 3};
...@@ -1215,7 +1215,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size) ...@@ -1215,7 +1215,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::relu relu438; migraphx::op::relu relu438;
auto mx438 = mm->add_instruction(relu438, mx437); auto mx438 = mm->add_instruction(relu438, mx437);
migraphx::op::pooling pooling439; migraphx::op::pooling pooling439;
pooling439.mode = "average"; pooling439.mode = migraphx::op::pooling_mode::average;
pooling439.padding = {0, 0}; pooling439.padding = {0, 0};
pooling439.stride = {1, 1}; pooling439.stride = {1, 1};
pooling439.lengths = {7, 7}; pooling439.lengths = {7, 7};
......
...@@ -44,7 +44,7 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins, ...@@ -44,7 +44,7 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins,
static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m) static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m)
{ {
auto op = any_cast<op::pooling>(ins->get_operator()); auto op = any_cast<op::pooling>(ins->get_operator());
if(op.mode == "average") if(op.mode == op::pooling_mode::average)
{ {
return; return;
} }
......
#ifndef MIGRAPHX_GUARD_OPERATORS_COMMON_HPP #ifndef MIGRAPHX_GUARD_OPERATORS_COMMON_HPP
#define MIGRAPHX_GUARD_OPERATORS_COMMON_HPP #define MIGRAPHX_GUARD_OPERATORS_COMMON_HPP
#include <ostream>
#include <vector>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <utility> #include <utility>
...@@ -15,6 +17,14 @@ enum padding_mode_t ...@@ -15,6 +17,14 @@ enum padding_mode_t
valid valid
}; };
// The pooling modes must correspond 1-1 to the operators defined for struct parse_pooling.
// Used in pooling and roialign operators.
enum class pooling_mode
{
average,
max
};
// indicate rnn computation direction // indicate rnn computation direction
enum class rnn_direction enum class rnn_direction
{ {
...@@ -23,6 +33,7 @@ enum class rnn_direction ...@@ -23,6 +33,7 @@ enum class rnn_direction
bidirectional, bidirectional,
}; };
std::ostream& operator<<(std::ostream& os, pooling_mode v);
std::ostream& operator<<(std::ostream& os, rnn_direction v); std::ostream& operator<<(std::ostream& os, rnn_direction v);
} // namespace op } // namespace op
......
...@@ -16,12 +16,13 @@ ...@@ -16,12 +16,13 @@
#include <utility> #include <utility>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct pooling struct pooling
{ {
std::string mode = "average"; pooling_mode mode = {pooling_mode::average};
std::vector<std::size_t> padding = {0, 0}; std::vector<std::size_t> padding = {0, 0};
std::vector<std::size_t> stride = {1, 1}; std::vector<std::size_t> stride = {1, 1};
std::vector<std::size_t> lengths = {1, 1}; std::vector<std::size_t> lengths = {1, 1};
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <limits> #include <limits>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp> #include <migraphx/par_for.hpp>
...@@ -21,7 +22,7 @@ namespace op { ...@@ -21,7 +22,7 @@ namespace op {
struct roialign struct roialign
{ {
std::string coord_trans_mode = "half_pixel"; std::string coord_trans_mode = "half_pixel";
std::string mode = "avg"; pooling_mode mode = {pooling_mode::average};
int64_t output_height = 1; int64_t output_height = 1;
int64_t output_width = 1; int64_t output_width = 1;
int64_t sampling_ratio = 0; int64_t sampling_ratio = 0;
...@@ -241,16 +242,17 @@ struct roialign ...@@ -241,16 +242,17 @@ struct roialign
in_dims[0] * in_dims[1]); in_dims[0] * in_dims[1]);
double output_val; double output_val;
std::tie(output_val, vec_index[c]) = std::tie(output_val, vec_index[c]) =
(mode == "avg") ? this->calc_pooling(offset_bottom_data, (mode == migraphx::op::pooling_mode::average)
bin_grid_size, ? this->calc_pooling(offset_bottom_data,
pre_calc, bin_grid_size,
vec_index[c], pre_calc,
avg_pool{}) vec_index[c],
: this->calc_pooling(offset_bottom_data, avg_pool{})
bin_grid_size, : this->calc_pooling(offset_bottom_data,
pre_calc, bin_grid_size,
vec_index[c], pre_calc,
max_pool{}); vec_index[c],
max_pool{});
output(n, c, ph, pw) = output_val; output(n, c, ph, pw) = output_val;
}); });
}); });
......
...@@ -44,7 +44,7 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins, ...@@ -44,7 +44,7 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins,
static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m) static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m)
{ {
auto op = any_cast<op::pooling>(ins->get_operator()); auto op = any_cast<op::pooling>(ins->get_operator());
if(op.mode == "average") if(op.mode == op::pooling_mode::average)
{ {
return; return;
} }
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/pad_calc.hpp> #include <migraphx/pad_calc.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -94,7 +95,7 @@ void tune_padding_size(const value& v, ...@@ -94,7 +95,7 @@ void tune_padding_size(const value& v,
std::vector<int64_t>& s_start) std::vector<int64_t>& s_start)
{ {
// maxpooling or count_include_pad is 1, no change is required. // maxpooling or count_include_pad is 1, no change is required.
if(v.at("mode").to<std::string>() == "max" or count_include_pad == 1) if(v.at("mode").to<op::pooling_mode>() == op::pooling_mode::max or count_include_pad == 1)
{ {
return; return;
} }
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <migraphx/onnx/checks.hpp> #include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/padding.hpp> #include <migraphx/onnx/padding.hpp>
#include <migraphx/op/pad.hpp> #include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
...@@ -27,10 +28,16 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -27,10 +28,16 @@ struct parse_pooling : op_parser<parse_pooling>
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
std::string mode = opd.op_name; std::string mode = opd.op_name;
operation op = make_op("pooling", {{"mode", mode}}); if(mode != "max" && mode != "average")
value values = op.to_value(); {
auto l0 = args[0]; MIGRAPHX_THROW("onnx pooling mode must be \"max\" or \"average\"");
auto in_lens = l0->get_shape().lens(); }
operation op = make_op(
"pooling",
{{"mode", mode == "average" ? op::pooling_mode::average : op::pooling_mode::max}});
value values = op.to_value();
auto l0 = args[0];
auto in_lens = l0->get_shape().lens();
assert(in_lens.size() > 2); assert(in_lens.size() > 2);
auto kdims = in_lens.size() - 2; auto kdims = in_lens.size() - 2;
...@@ -72,6 +79,7 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -72,6 +79,7 @@ struct parse_pooling : op_parser<parse_pooling>
std::vector<int64_t> paddings; std::vector<int64_t> paddings;
float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f); float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f);
if(contains(info.attributes, "pads")) if(contains(info.attributes, "pads"))
{ {
values["padding"].clear(); values["padding"].clear();
......
#include <migraphx/op/common.hpp>
#include <migraphx/onnx/op_parser.hpp> #include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp> #include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
...@@ -28,10 +29,14 @@ struct parse_roialign : op_parser<parse_roialign> ...@@ -28,10 +29,14 @@ struct parse_roialign : op_parser<parse_roialign>
"\": invalid value!"); "\": invalid value!");
} }
std::string mode = "avg"; migraphx::op::pooling_mode rmode(migraphx::op::pooling_mode::average);
if(contains(info.attributes, "mode")) if(contains(info.attributes, "mode"))
{ {
mode = info.attributes.at("mode").s(); // read mode; default is "avg"
if(info.attributes.at("mode").s() == "max")
{
rmode = migraphx::op::pooling_mode::max;
}
} }
int64_t output_height = 1; int64_t output_height = 1;
...@@ -57,10 +62,9 @@ struct parse_roialign : op_parser<parse_roialign> ...@@ -57,10 +62,9 @@ struct parse_roialign : op_parser<parse_roialign>
{ {
spatial_scale = info.attributes.at("spatial_scale").f(); spatial_scale = info.attributes.at("spatial_scale").f();
} }
return info.add_instruction(make_op("roialign", return info.add_instruction(make_op("roialign",
{{"coordinate_transformation_mode", coord_trans_mode}, {{"coordinate_transformation_mode", coord_trans_mode},
{"mode", mode}, {"mode", rmode},
{"output_height", output_height}, {"output_height", output_height},
{"output_width", output_width}, {"output_width", output_width},
{"sampling_ratio", sampling_ratio}, {"sampling_ratio", sampling_ratio},
......
//
// Supporting functions for enum values used in operator parameters.
// These values are declared as "enum class" and should include << streaming operators
// to be able to write their values in human-readable format so users can
// save and edit model files.
//
#include <sstream>
#include <migraphx/op/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
std::ostream& operator<<(std::ostream& os, pooling_mode v)
{
// the strings for the enum are the same as the values used for onnx parsing
// but this enum is not onnx-specific: strings must be converted when parsing tf
static const std::vector<std::string> pooling_mode_str = {"average", "max"};
os << pooling_mode_str[static_cast<std::underlying_type<pooling_mode>::type>(v)];
return os;
}
std::ostream& operator<<(std::ostream& os, rnn_direction v)
{
static const std::vector<std::string> rnn_direction_str = {
"forward", "reverse", "bidirectional"};
os << rnn_direction_str[static_cast<std::underlying_type<rnn_direction>::type>(v)];
return os;
}
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -38,7 +38,7 @@ void rewrite_pooling::apply(module& prog) const ...@@ -38,7 +38,7 @@ void rewrite_pooling::apply(module& prog) const
instruction_ref pooling{}; instruction_ref pooling{};
// average pooling // average pooling
if(op.mode == "average") if(op.mode == op::pooling_mode::average)
{ {
pooling = pooling =
prog.insert_instruction(ins, make_op("reduce_mean", {{"axes", {1}}}), reshape); prog.insert_instruction(ins, make_op("reduce_mean", {{"axes", {1}}}), reshape);
......
...@@ -1426,14 +1426,5 @@ instruction_ref rewrite_rnn::pad_hidden_states(module& prog, ...@@ -1426,14 +1426,5 @@ instruction_ref rewrite_rnn::pad_hidden_states(module& prog,
return hs_padded; return hs_padded;
} }
namespace op {
std::ostream& operator<<(std::ostream& os, rnn_direction v)
{
std::vector<std::string> rnn_direction_str = {"forward", "reverse", "bidirectional"};
os << rnn_direction_str[static_cast<std::underlying_type<rnn_direction>::type>(v)];
return os;
}
} // namespace op
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -460,10 +460,10 @@ struct cpu_apply ...@@ -460,10 +460,10 @@ struct cpu_apply
if(has_op("dnnl::pooling") and ins->get_shape().type() == shape::type_t::float_type and if(has_op("dnnl::pooling") and ins->get_shape().type() == shape::type_t::float_type and
not v["ceil_mode"].to<bool>()) not v["ceil_mode"].to<bool>())
return replace(ins, make_op("dnnl::pooling", op.to_value())); return replace(ins, make_op("dnnl::pooling", op.to_value()));
std::string mode = v["mode"].to<std::string>(); op::pooling_mode mode = v["mode"].to<op::pooling_mode>();
if(mode == "max") if(mode == op::pooling_mode::max)
return replace(ins, make_op("cpu::pooling_max", v)); return replace(ins, make_op("cpu::pooling_max", v));
else if(mode == "average") else if(mode == op::pooling_mode::average)
return replace(ins, make_op("cpu::pooling_average", v)); return replace(ins, make_op("cpu::pooling_average", v));
return ins; return ins;
} }
......
...@@ -129,7 +129,8 @@ struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::po ...@@ -129,7 +129,8 @@ struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::po
dnnl::pooling_forward::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const dnnl::pooling_forward::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{ {
auto algo = op.mode == "max" ? dnnl::algorithm::pooling_max : dnnl::algorithm::pooling_avg; auto algo = op.mode == op::pooling_mode::max ? dnnl::algorithm::pooling_max
: dnnl::algorithm::pooling_avg;
auto kdims = op.kdims(); auto kdims = op.kdims();
std::vector<size_t> padding_l(op.padding.begin(), op.padding.begin() + kdims); std::vector<size_t> padding_l(op.padding.begin(), op.padding.begin() + kdims);
std::vector<size_t> padding_r(op.padding.begin() + kdims, op.padding.end()); std::vector<size_t> padding_r(op.padding.begin() + kdims, op.padding.end());
...@@ -145,5 +146,6 @@ struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::po ...@@ -145,5 +146,6 @@ struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::po
}; };
} // namespace cpu } // namespace cpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -59,8 +59,8 @@ operation compile_roialign(context&, const std::vector<shape>& io_shapes, const ...@@ -59,8 +59,8 @@ operation compile_roialign(context&, const std::vector<shape>& io_shapes, const
// pooling_mode // pooling_mode
assert(val.contains("mode")); assert(val.contains("mode"));
auto mode = val.at("mode").to<std::string>(); auto mode = val.at("mode").to<migraphx::op::pooling_mode>();
bool is_avg_pooling = (mode == "avg"); bool is_avg_pooling = (mode == migraphx::op::pooling_mode::average);
options.params += " -DIS_AVG_POOLING=" + std::to_string(static_cast<int>(is_avg_pooling)); options.params += " -DIS_AVG_POOLING=" + std::to_string(static_cast<int>(is_avg_pooling));
// coord_trans_mode // coord_trans_mode
......
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
#include <miopen/miopen.h> #include <miopen/miopen.h>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <sstream>
#ifdef HAS_FIND_MODE_API #ifdef HAS_FIND_MODE_API
extern "C" miopenStatus_t miopenHiddenSetConvolutionFindMode(miopenConvolutionDescriptor_t convDesc, extern "C" miopenStatus_t miopenHiddenSetConvolutionFindMode(miopenConvolutionDescriptor_t convDesc,
int findMode); int findMode);
...@@ -132,12 +134,16 @@ inline convolution_descriptor make_deconv(const T& op) ...@@ -132,12 +134,16 @@ inline convolution_descriptor make_deconv(const T& op)
inline pooling_descriptor make_pooling(const migraphx::op::pooling& op) inline pooling_descriptor make_pooling(const migraphx::op::pooling& op)
{ {
miopenPoolingMode_t mode; miopenPoolingMode_t mode;
if(op.mode == "max") if(op.mode == op::pooling_mode::max)
mode = miopenPoolingMax; mode = miopenPoolingMax;
else if(op.mode == "average") else if(op.mode == op::pooling_mode::average)
mode = miopenPoolingAverage; mode = miopenPoolingAverage;
else else
MIGRAPHX_THROW("Unknown mode for pooling: " + op.mode); {
std::stringstream ss("Unknown mode for pooling: ");
ss << op.mode;
MIGRAPHX_THROW(ss.str());
}
auto p = make_obj<pooling_descriptor>(&miopenCreatePoolingDescriptor); auto p = make_obj<pooling_descriptor>(&miopenCreatePoolingDescriptor);
int kdims = op.kdims(); int kdims = op.kdims();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment