Commit 0856b6e2 authored by Paul

Fix flatten operator

parent 33212f8f
......@@ -427,17 +427,21 @@ struct flatten
 shape compute_shape(std::vector<shape> inputs) const
 {
     check_shapes{inputs}.has(1);
+    auto&& lens = inputs.front().lens();
     if(axis == 0)
     {
         return {inputs.at(0).type(), {1, inputs.at(0).elements()}};
     }
-    if(axis == 1)
+    else if(axis < lens.size())
     {
-        return {inputs.at(0).type(), {inputs.at(0).elements(), 1}};
+        auto x = std::accumulate(lens.begin(), lens.begin()+axis, std::size_t{1}, std::multiplies<>{});
+        auto y = std::accumulate(lens.begin()+axis, lens.end(), std::size_t{1}, std::multiplies<>{});
+        return {inputs.at(0).type(), {x, y}};
     }
     else
     {
-        MIGRAPH_THROW("axis for flatten can only be either 0 or 1");
+        MIGRAPH_THROW("axis for flatten must be less than tensor rank");
     }
 }
 argument compute(context&, shape output_shape, std::vector<argument> args) const
......
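
Note on the new shape rule above: for an input with dimensions d0, ..., dn-1 and flatten axis k, the output is the 2-D shape {d0*...*dk-1, dk*...*dn-1}; axis 0 therefore gives {1, total elements}, matching the special case kept in the code. The standalone sketch below mirrors that reduction with std::accumulate; flatten_dims and the small main are illustrative names only, not part of MIGraphX.

// Minimal sketch of the flatten shape rule (illustrative only, not MIGraphX code).
#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <utility>
#include <vector>

// Collapse dims [0, axis) into the first output dim and [axis, rank) into the second.
std::pair<std::size_t, std::size_t> flatten_dims(const std::vector<std::size_t>& lens,
                                                 std::size_t axis)
{
    auto x = std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
    auto y = std::accumulate(lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
    return {x, y}; // axis == 0 yields {1, total_elements}
}

int main()
{
    // A {2, 3, 4, 5} tensor flattened at axis 2 becomes {6, 20}.
    auto d = flatten_dims({2, 3, 4, 5}, 2);
    std::cout << d.first << " x " << d.second << "\n"; // prints "6 x 20"
}
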
......@@ -60,8 +60,8 @@ struct onnx_parser
add_mem_op("Constant", &onnx_parser::parse_constant);
add_mem_op("Conv", &onnx_parser::parse_conv);
add_mem_op("MaxPool", &onnx_parser::parse_max_pooling);
add_mem_op("AveragePool", &onnx_parser::parse_average_pooling);
add_mem_op("MaxPool", &onnx_parser::parse_pooling);
add_mem_op("AveragePool", &onnx_parser::parse_pooling);
add_mem_op("Reshape", &onnx_parser::parse_reshape);
add_mem_op("Flatten", &onnx_parser::parse_flatten);
add_mem_op("Gemm", &onnx_parser::parse_gemm);
......@@ -129,28 +129,9 @@ struct onnx_parser
 }
 instruction_ref
-parse_max_pooling(std::string, attribute_map attributes, std::vector<instruction_ref> args)
+parse_pooling(std::string name, attribute_map attributes, std::vector<instruction_ref> args)
 {
-    pooling op{"max"};
-    if(contains(attributes, "pads"))
-    {
-        copy(attributes["pads"].ints(), op.padding.begin());
-    }
-    if(contains(attributes, "strides"))
-    {
-        copy(attributes["strides"].ints(), op.stride.begin());
-    }
-    if(contains(attributes, "kernel_shape"))
-    {
-        copy(attributes["kernel_shape"].ints(), op.lengths.begin());
-    }
-    return prog.add_instruction(op, args);
-}
-instruction_ref
-parse_average_pooling(std::string, attribute_map attributes, std::vector<instruction_ref> args)
-{
-    pooling op{"average"};
+    pooling op{name == "MaxPool" ? "max" : "average"};
     if(contains(attributes, "pads"))
     {
         copy(attributes["pads"].ints(), op.padding.begin());
......@@ -187,10 +168,10 @@ struct onnx_parser
 parse_flatten(std::string, attribute_map attributes, std::vector<instruction_ref> args)
 {
     uint64_t axis = 0;
-    // if(contains(attributes, "axis"))
-    // {
-    //     axis = parse_value(attributes.at("axis")).at<int>();
-    // }
+    if(contains(attributes, "axis"))
+    {
+        axis = parse_value(attributes.at("axis")).at<int>();
+    }
     return prog.add_instruction(flatten{axis}, args[0]);
 }
......
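
The parser changes above route both MaxPool and AveragePool through a single parse_pooling that picks the pooling mode from the ONNX op name, and they enable the previously commented-out handling of the Flatten axis attribute. The toy program below sketches only the dispatch idea; toy_parser, add_op, parse, and the handler signature are hypothetical names, not the MIGraphX onnx_parser API.

// Illustrative sketch of name-based dispatch (not MIGraphX code).
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>

struct toy_parser
{
    using handler = std::function<void(const std::string&)>;
    std::unordered_map<std::string, handler> ops;

    void add_op(const std::string& name, handler h) { ops[name] = std::move(h); }

    // The handler receives the op name, so one handler can serve several ops.
    void parse(const std::string& name) const { ops.at(name)(name); }
};

int main()
{
    toy_parser p;
    auto parse_pooling = [](const std::string& name) {
        // Same selection the merged parser uses: the op name decides the mode.
        std::string mode = (name == "MaxPool") ? "max" : "average";
        std::cout << name << " -> pooling{" << mode << "}\n";
    };
    p.add_op("MaxPool", parse_pooling);
    p.add_op("AveragePool", parse_pooling);

    p.parse("MaxPool");     // pooling{max}
    p.parse("AveragePool"); // pooling{average}
}
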
 #include <migraph/cpu/cpu_target.hpp>
 #include <migraph/cpu/cpu_lowering.hpp>
+#include <migraph/auto_contiguous.hpp>
 namespace migraph {
 namespace cpu {
 std::string cpu_target::name() const { return "cpu"; }
-std::vector<pass> cpu_target::get_passes(context&) const { return {cpu_lowering{}}; }
+std::vector<pass> cpu_target::get_passes(context&) const { return {auto_contiguous{}, cpu_lowering{}}; }
 } // namespace cpu
......
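
The cpu_target change runs auto_contiguous before cpu_lowering. A plausible reason, not spelled out in the commit itself, is that shape-rewriting ops such as flatten assume a standard (contiguous) layout, so non-standard views such as transposes need to be repacked first. The standalone sketch below illustrates the point in plain C++; the buffer, strides, and repacking loop are illustrative and unrelated to MIGraphX internals.

// Why a contiguous step matters before flattening a strided view (illustrative only).
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    // Row-major 2x3 buffer: logical [[1,2,3],[4,5,6]].
    std::vector<int> data = {1, 2, 3, 4, 5, 6};

    // Transposed view of the same buffer: shape 3x2 with strides {1, 3}.
    std::size_t lens[2]    = {3, 2};
    std::size_t strides[2] = {1, 3};

    // "Contiguous" step: repack the strided view into standard layout so a later
    // flatten can treat the data as a plain 1-D sequence.
    std::vector<int> packed;
    for(std::size_t i = 0; i < lens[0]; i++)
        for(std::size_t j = 0; j < lens[1]; j++)
            packed.push_back(data[i * strides[0] + j * strides[1]]);

    // Prints the transposed order 1 4 2 5 3 6; flattening the raw buffer without
    // repacking would instead give the original order 1 2 3 4 5 6.
    for(int v : packed)
        std::cout << v << ' ';
    std::cout << '\n';
}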