"test/vscode:/vscode.git/clone" did not exist on "7353ec0c25468d754ad5dd786e979a3bbade0a47"
Commit 1110ef29 authored by Scott Thornton's avatar Scott Thornton
Browse files

Added flatten to ONNX

parent 39151d27
...@@ -422,7 +422,28 @@ struct neg : unary ...@@ -422,7 +422,28 @@ struct neg : unary
struct flatten struct flatten
{ {
uint64_t axis = 0;
std::string name() const { return "flatten"; } std::string name() const { return "flatten"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
if (axis == 0)
{
return {inputs.at(0).type(), {1, inputs.at(0).elements()}};
}
if (axis == 1)
{
return {inputs.at(0).type(), {inputs.at(0).elements(), 1}};
}
else
{
MIGRAPH_THROW("axis can only be either 0 or 1");
}
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {output_shape, std::move(args.front().data)};
}
}; };
struct broadcast struct broadcast
......
...@@ -62,6 +62,7 @@ struct onnx_parser ...@@ -62,6 +62,7 @@ struct onnx_parser
add_mem_op("Conv", &onnx_parser::parse_conv); add_mem_op("Conv", &onnx_parser::parse_conv);
add_mem_op("MaxPool", &onnx_parser::parse_pooling); add_mem_op("MaxPool", &onnx_parser::parse_pooling);
add_mem_op("Reshape", &onnx_parser::parse_reshape); add_mem_op("Reshape", &onnx_parser::parse_reshape);
add_mem_op("Flatten", &onnx_parser::parse_flatten);
add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm); add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
} }
...@@ -161,6 +162,17 @@ struct onnx_parser ...@@ -161,6 +162,17 @@ struct onnx_parser
return prog.add_instruction(op, args[0]); return prog.add_instruction(op, args[0]);
} }
instruction_ref
parse_flatten(std::string, attribute_map attributes, std::vector<instruction_ref> args)
{
uint64_t axis = 0;
if(contains(attributes, "axis"))
{
axis = parse_value(attributes.at("axis")).at<int>();
}
return prog.add_instruction(flatten{axis}, args[0]);
}
instruction_ref instruction_ref
parse_constant(std::string, attribute_map attributes, std::vector<instruction_ref>) parse_constant(std::string, attribute_map attributes, std::vector<instruction_ref>)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment