Unverified Commit a2e90b5d authored by bpickrel's avatar bpickrel Committed by GitHub
Browse files

Mode as enum for pooling and roi_align (#1091)

Changed the pooling values for two structures from strings to specialized enum classes. Many test and operator parsing changes to support this. Introduces one new source file, op_enums.cpp.
parent d9d17a11
......@@ -38,6 +38,7 @@ add_library(migraphx
msgpack.cpp
normalize_attributes.cpp
normalize_ops.cpp
op_enums.cpp
operation.cpp
opt/memory_coloring.cpp
opt/memory_coloring_impl.cpp
......
......@@ -61,7 +61,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::relu relu19;
auto mx19 = mm->add_instruction(relu19, mx18);
migraphx::op::pooling pooling20;
pooling20.mode = "max";
pooling20.mode = migraphx::op::pooling_mode::max;
pooling20.padding = {0, 0};
pooling20.stride = {2, 2};
pooling20.lengths = {3, 3};
......@@ -81,7 +81,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::relu relu24;
auto mx24 = mm->add_instruction(relu24, mx23);
migraphx::op::pooling pooling25;
pooling25.mode = "max";
pooling25.mode = migraphx::op::pooling_mode::max;
pooling25.padding = {0, 0};
pooling25.stride = {2, 2};
pooling25.lengths = {3, 3};
......@@ -129,7 +129,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::relu relu37;
auto mx37 = mm->add_instruction(relu37, mx36);
migraphx::op::pooling pooling38;
pooling38.mode = "max";
pooling38.mode = migraphx::op::pooling_mode::max;
pooling38.padding = {0, 0};
pooling38.stride = {2, 2};
pooling38.lengths = {3, 3};
......
......@@ -995,7 +995,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu492;
auto mx492 = mm->add_instruction(relu492, mx491);
migraphx::op::pooling pooling493;
pooling493.mode = "max";
pooling493.mode = migraphx::op::pooling_mode::max;
pooling493.padding = {0, 0};
pooling493.stride = {2, 2};
pooling493.lengths = {3, 3};
......@@ -1025,7 +1025,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu499;
auto mx499 = mm->add_instruction(relu499, mx498);
migraphx::op::pooling pooling500;
pooling500.mode = "max";
pooling500.mode = migraphx::op::pooling_mode::max;
pooling500.padding = {0, 0};
pooling500.stride = {2, 2};
pooling500.lengths = {3, 3};
......@@ -1103,7 +1103,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu518;
auto mx518 = mm->add_instruction(relu518, mx517);
migraphx::op::pooling pooling519;
pooling519.mode = "average";
pooling519.mode = migraphx::op::pooling_mode::average;
pooling519.padding = {1, 1};
pooling519.stride = {1, 1};
pooling519.lengths = {3, 3};
......@@ -1196,7 +1196,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu541;
auto mx541 = mm->add_instruction(relu541, mx540);
migraphx::op::pooling pooling542;
pooling542.mode = "average";
pooling542.mode = migraphx::op::pooling_mode::average;
pooling542.padding = {1, 1};
pooling542.stride = {1, 1};
pooling542.lengths = {3, 3};
......@@ -1289,7 +1289,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu564;
auto mx564 = mm->add_instruction(relu564, mx563);
migraphx::op::pooling pooling565;
pooling565.mode = "average";
pooling565.mode = migraphx::op::pooling_mode::average;
pooling565.padding = {1, 1};
pooling565.stride = {1, 1};
pooling565.lengths = {3, 3};
......@@ -1358,7 +1358,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu581;
auto mx581 = mm->add_instruction(relu581, mx580);
migraphx::op::pooling pooling582;
pooling582.mode = "max";
pooling582.mode = migraphx::op::pooling_mode::max;
pooling582.padding = {0, 0};
pooling582.stride = {2, 2};
pooling582.lengths = {3, 3};
......@@ -1475,7 +1475,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu610;
auto mx610 = mm->add_instruction(relu610, mx609);
migraphx::op::pooling pooling611;
pooling611.mode = "average";
pooling611.mode = migraphx::op::pooling_mode::average;
pooling611.padding = {1, 1};
pooling611.stride = {1, 1};
pooling611.lengths = {3, 3};
......@@ -1604,7 +1604,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu642;
auto mx642 = mm->add_instruction(relu642, mx641);
migraphx::op::pooling pooling643;
pooling643.mode = "average";
pooling643.mode = migraphx::op::pooling_mode::average;
pooling643.padding = {1, 1};
pooling643.stride = {1, 1};
pooling643.lengths = {3, 3};
......@@ -1733,7 +1733,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu674;
auto mx674 = mm->add_instruction(relu674, mx673);
migraphx::op::pooling pooling675;
pooling675.mode = "average";
pooling675.mode = migraphx::op::pooling_mode::average;
pooling675.padding = {1, 1};
pooling675.stride = {1, 1};
pooling675.lengths = {3, 3};
......@@ -1862,7 +1862,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu706;
auto mx706 = mm->add_instruction(relu706, mx705);
migraphx::op::pooling pooling707;
pooling707.mode = "average";
pooling707.mode = migraphx::op::pooling_mode::average;
pooling707.padding = {1, 1};
pooling707.stride = {1, 1};
pooling707.lengths = {3, 3};
......@@ -1955,7 +1955,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu729;
auto mx729 = mm->add_instruction(relu729, mx728);
migraphx::op::pooling pooling730;
pooling730.mode = "max";
pooling730.mode = migraphx::op::pooling_mode::max;
pooling730.padding = {0, 0};
pooling730.stride = {2, 2};
pooling730.lengths = {3, 3};
......@@ -2066,7 +2066,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat757.axis = 1;
auto mx757 = mm->add_instruction(concat757, mx753, mx756);
migraphx::op::pooling pooling758;
pooling758.mode = "average";
pooling758.mode = migraphx::op::pooling_mode::average;
pooling758.padding = {1, 1};
pooling758.stride = {1, 1};
pooling758.lengths = {3, 3};
......@@ -2189,7 +2189,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat788.axis = 1;
auto mx788 = mm->add_instruction(concat788, mx784, mx787);
migraphx::op::pooling pooling789;
pooling789.mode = "average";
pooling789.mode = migraphx::op::pooling_mode::average;
pooling789.padding = {1, 1};
pooling789.stride = {1, 1};
pooling789.lengths = {3, 3};
......@@ -2210,7 +2210,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat793.axis = 1;
auto mx793 = mm->add_instruction(concat793, mx765, mx775, mx788, mx792);
migraphx::op::pooling pooling794;
pooling794.mode = "average";
pooling794.mode = migraphx::op::pooling_mode::average;
pooling794.padding = {0, 0};
pooling794.stride = {8, 8};
pooling794.lengths = {8, 8};
......
......@@ -561,7 +561,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::relu relu269;
auto mx269 = mm->add_instruction(relu269, mx268);
migraphx::op::pooling pooling270;
pooling270.mode = "max";
pooling270.mode = migraphx::op::pooling_mode::max;
pooling270.padding = {1, 1};
pooling270.stride = {2, 2};
pooling270.lengths = {3, 3};
......@@ -1215,7 +1215,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::relu relu438;
auto mx438 = mm->add_instruction(relu438, mx437);
migraphx::op::pooling pooling439;
pooling439.mode = "average";
pooling439.mode = migraphx::op::pooling_mode::average;
pooling439.padding = {0, 0};
pooling439.stride = {1, 1};
pooling439.lengths = {7, 7};
......
......@@ -44,7 +44,7 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins,
static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m)
{
auto op = any_cast<op::pooling>(ins->get_operator());
if(op.mode == "average")
if(op.mode == op::pooling_mode::average)
{
return;
}
......
#ifndef MIGRAPHX_GUARD_OPERATORS_COMMON_HPP
#define MIGRAPHX_GUARD_OPERATORS_COMMON_HPP
#include <ostream>
#include <vector>
#include <migraphx/config.hpp>
#include <utility>
......@@ -15,6 +17,14 @@ enum padding_mode_t
valid
};
// The pooling modes must correspond 1-1 to the operators defined for struct parse_pooling.
// Used in pooling and roialign operators.
enum class pooling_mode
{
average,
max
};
// indicate rnn computation direction
enum class rnn_direction
{
......@@ -23,6 +33,7 @@ enum class rnn_direction
bidirectional,
};
std::ostream& operator<<(std::ostream& os, pooling_mode v);
std::ostream& operator<<(std::ostream& os, rnn_direction v);
} // namespace op
......
......@@ -16,12 +16,13 @@
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct pooling
{
std::string mode = "average";
pooling_mode mode = {pooling_mode::average};
std::vector<std::size_t> padding = {0, 0};
std::vector<std::size_t> stride = {1, 1};
std::vector<std::size_t> lengths = {1, 1};
......
......@@ -3,6 +3,7 @@
#include <limits>
#include <migraphx/check_shapes.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
......@@ -21,7 +22,7 @@ namespace op {
struct roialign
{
std::string coord_trans_mode = "half_pixel";
std::string mode = "avg";
pooling_mode mode = {pooling_mode::average};
int64_t output_height = 1;
int64_t output_width = 1;
int64_t sampling_ratio = 0;
......@@ -241,7 +242,8 @@ struct roialign
in_dims[0] * in_dims[1]);
double output_val;
std::tie(output_val, vec_index[c]) =
(mode == "avg") ? this->calc_pooling(offset_bottom_data,
(mode == migraphx::op::pooling_mode::average)
? this->calc_pooling(offset_bottom_data,
bin_grid_size,
pre_calc,
vec_index[c],
......
......@@ -44,7 +44,7 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins,
static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m)
{
auto op = any_cast<op::pooling>(ins->get_operator());
if(op.mode == "average")
if(op.mode == op::pooling_mode::average)
{
return;
}
......
......@@ -3,6 +3,7 @@
#include <migraphx/pad_calc.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -94,7 +95,7 @@ void tune_padding_size(const value& v,
std::vector<int64_t>& s_start)
{
// maxpooling or count_include_pad is 1, no change is required.
if(v.at("mode").to<std::string>() == "max" or count_include_pad == 1)
if(v.at("mode").to<op::pooling_mode>() == op::pooling_mode::max or count_include_pad == 1)
{
return;
}
......
......@@ -2,6 +2,7 @@
#include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/padding.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
......@@ -27,7 +28,13 @@ struct parse_pooling : op_parser<parse_pooling>
std::vector<instruction_ref> args) const
{
std::string mode = opd.op_name;
operation op = make_op("pooling", {{"mode", mode}});
if(mode != "max" && mode != "average")
{
MIGRAPHX_THROW("onnx pooling mode must be \"max\" or \"average\"");
}
operation op = make_op(
"pooling",
{{"mode", mode == "average" ? op::pooling_mode::average : op::pooling_mode::max}});
value values = op.to_value();
auto l0 = args[0];
auto in_lens = l0->get_shape().lens();
......@@ -72,6 +79,7 @@ struct parse_pooling : op_parser<parse_pooling>
std::vector<int64_t> paddings;
float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f);
if(contains(info.attributes, "pads"))
{
values["padding"].clear();
......
#include <migraphx/op/common.hpp>
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
......@@ -28,10 +29,14 @@ struct parse_roialign : op_parser<parse_roialign>
"\": invalid value!");
}
std::string mode = "avg";
migraphx::op::pooling_mode rmode(migraphx::op::pooling_mode::average);
if(contains(info.attributes, "mode"))
{
mode = info.attributes.at("mode").s();
// read mode; default is "avg"
if(info.attributes.at("mode").s() == "max")
{
rmode = migraphx::op::pooling_mode::max;
}
}
int64_t output_height = 1;
......@@ -57,10 +62,9 @@ struct parse_roialign : op_parser<parse_roialign>
{
spatial_scale = info.attributes.at("spatial_scale").f();
}
return info.add_instruction(make_op("roialign",
{{"coordinate_transformation_mode", coord_trans_mode},
{"mode", mode},
{"mode", rmode},
{"output_height", output_height},
{"output_width", output_width},
{"sampling_ratio", sampling_ratio},
......
//
// Supporting functions for enum values used in operator parameters.
// These values are declared as "enum class" and should include << streaming operators
// to be able to write their values in human-readable format so users can
// save and edit model files.
//
#include <sstream>
#include <migraphx/op/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
std::ostream& operator<<(std::ostream& os, pooling_mode v)
{
// the strings for the enum are the same as the values used for onnx parsing
// but this enum is not onnx-specific: strings must be converted when parsing tf
static const std::vector<std::string> pooling_mode_str = {"average", "max"};
os << pooling_mode_str[static_cast<std::underlying_type<pooling_mode>::type>(v)];
return os;
}
std::ostream& operator<<(std::ostream& os, rnn_direction v)
{
static const std::vector<std::string> rnn_direction_str = {
"forward", "reverse", "bidirectional"};
os << rnn_direction_str[static_cast<std::underlying_type<rnn_direction>::type>(v)];
return os;
}
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -38,7 +38,7 @@ void rewrite_pooling::apply(module& prog) const
instruction_ref pooling{};
// average pooling
if(op.mode == "average")
if(op.mode == op::pooling_mode::average)
{
pooling =
prog.insert_instruction(ins, make_op("reduce_mean", {{"axes", {1}}}), reshape);
......
......@@ -1426,14 +1426,5 @@ instruction_ref rewrite_rnn::pad_hidden_states(module& prog,
return hs_padded;
}
namespace op {
std::ostream& operator<<(std::ostream& os, rnn_direction v)
{
std::vector<std::string> rnn_direction_str = {"forward", "reverse", "bidirectional"};
os << rnn_direction_str[static_cast<std::underlying_type<rnn_direction>::type>(v)];
return os;
}
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -460,10 +460,10 @@ struct cpu_apply
if(has_op("dnnl::pooling") and ins->get_shape().type() == shape::type_t::float_type and
not v["ceil_mode"].to<bool>())
return replace(ins, make_op("dnnl::pooling", op.to_value()));
std::string mode = v["mode"].to<std::string>();
if(mode == "max")
op::pooling_mode mode = v["mode"].to<op::pooling_mode>();
if(mode == op::pooling_mode::max)
return replace(ins, make_op("cpu::pooling_max", v));
else if(mode == "average")
else if(mode == op::pooling_mode::average)
return replace(ins, make_op("cpu::pooling_average", v));
return ins;
}
......
......@@ -129,7 +129,8 @@ struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::po
dnnl::pooling_forward::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{
auto algo = op.mode == "max" ? dnnl::algorithm::pooling_max : dnnl::algorithm::pooling_avg;
auto algo = op.mode == op::pooling_mode::max ? dnnl::algorithm::pooling_max
: dnnl::algorithm::pooling_avg;
auto kdims = op.kdims();
std::vector<size_t> padding_l(op.padding.begin(), op.padding.begin() + kdims);
std::vector<size_t> padding_r(op.padding.begin() + kdims, op.padding.end());
......@@ -145,5 +146,6 @@ struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::po
};
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -59,8 +59,8 @@ operation compile_roialign(context&, const std::vector<shape>& io_shapes, const
// pooling_mode
assert(val.contains("mode"));
auto mode = val.at("mode").to<std::string>();
bool is_avg_pooling = (mode == "avg");
auto mode = val.at("mode").to<migraphx::op::pooling_mode>();
bool is_avg_pooling = (mode == migraphx::op::pooling_mode::average);
options.params += " -DIS_AVG_POOLING=" + std::to_string(static_cast<int>(is_avg_pooling));
// coord_trans_mode
......
......@@ -9,6 +9,8 @@
#include <miopen/miopen.h>
#include <migraphx/config.hpp>
#include <sstream>
#ifdef HAS_FIND_MODE_API
extern "C" miopenStatus_t miopenHiddenSetConvolutionFindMode(miopenConvolutionDescriptor_t convDesc,
int findMode);
......@@ -132,12 +134,16 @@ inline convolution_descriptor make_deconv(const T& op)
inline pooling_descriptor make_pooling(const migraphx::op::pooling& op)
{
miopenPoolingMode_t mode;
if(op.mode == "max")
if(op.mode == op::pooling_mode::max)
mode = miopenPoolingMax;
else if(op.mode == "average")
else if(op.mode == op::pooling_mode::average)
mode = miopenPoolingAverage;
else
MIGRAPHX_THROW("Unknown mode for pooling: " + op.mode);
{
std::stringstream ss("Unknown mode for pooling: ");
ss << op.mode;
MIGRAPHX_THROW(ss.str());
}
auto p = make_obj<pooling_descriptor>(&miopenCreatePoolingDescriptor);
int kdims = op.kdims();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment