Commit 711aff2b authored by Khalique's avatar Khalique
Browse files

revert to array

parent 051338be
...@@ -182,7 +182,7 @@ struct pooling ...@@ -182,7 +182,7 @@ struct pooling
std::string mode = "average"; std::string mode = "average";
std::array<std::size_t, 2> padding = {{0, 0}}; std::array<std::size_t, 2> padding = {{0, 0}};
std::array<std::size_t, 2> stride = {{1, 1}}; std::array<std::size_t, 2> stride = {{1, 1}};
std::vector<std::size_t> lengths = {{1, 1}}; std::array<std::size_t, 2> lengths = {{1, 1}};
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
......
...@@ -150,14 +150,11 @@ struct onnx_parser ...@@ -150,14 +150,11 @@ struct onnx_parser
attribute_map attributes, attribute_map attributes,
std::vector<instruction_ref> args) std::vector<instruction_ref> args)
{ {
op::pooling op{name == "MaxPool" or name == "GlobalMaxPool" ? "max" : "average"}; op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"};
if(name == "GlobalMaxPool" or name == "GlobalAveragePool") if(starts_with(name, "Global"))
{ {
auto lens = args.front()->get_shape().lens(); auto lens = args.front()->get_shape().lens();
auto num_lengths = lens.size() - 2; // ignore N and C values in lens op.lengths = {lens[2], lens[3]};
assert(num_lengths > 0);
op.lengths = std::vector<std::size_t>(num_lengths);
std::copy_n(lens.begin() + 2, num_lengths, op.lengths.begin());
} }
if(contains(attributes, "pads")) if(contains(attributes, "pads"))
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment