Commit 711aff2b authored by Khalique's avatar Khalique
Browse files

revert to array

parent 051338be
......@@ -182,7 +182,7 @@ struct pooling
std::string mode = "average";
std::array<std::size_t, 2> padding = {{0, 0}};
std::array<std::size_t, 2> stride = {{1, 1}};
std::vector<std::size_t> lengths = {{1, 1}};
std::array<std::size_t, 2> lengths = {{1, 1}};
template <class Self, class F>
static auto reflect(Self& self, F f)
......
......@@ -150,14 +150,11 @@ struct onnx_parser
attribute_map attributes,
std::vector<instruction_ref> args)
{
op::pooling op{name == "MaxPool" or name == "GlobalMaxPool" ? "max" : "average"};
if(name == "GlobalMaxPool" or name == "GlobalAveragePool")
{
auto lens = args.front()->get_shape().lens();
auto num_lengths = lens.size() - 2; // ignore N and C values in lens
assert(num_lengths > 0);
op.lengths = std::vector<std::size_t>(num_lengths);
std::copy_n(lens.begin() + 2, num_lengths, op.lengths.begin());
op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"};
if(starts_with(name, "Global"))
{
auto lens = args.front()->get_shape().lens();
op.lengths = {lens[2], lens[3]};
}
if(contains(attributes, "pads"))
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment