"vscode:/vscode.git/clone" did not exist on "e59e3058190cb7d7a590a3ed8cb6ca189f198799"
Commit 1110ef29 authored by Scott Thornton's avatar Scott Thornton
Browse files

Added flatten to ONNX

parent 39151d27
......@@ -422,7 +422,28 @@ struct neg : unary
struct flatten
{
uint64_t axis = 0;
std::string name() const { return "flatten"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
if (axis == 0)
{
return {inputs.at(0).type(), {1, inputs.at(0).elements()}};
}
if (axis == 1)
{
return {inputs.at(0).type(), {inputs.at(0).elements(), 1}};
}
else
{
MIGRAPH_THROW("axis can only be either 0 or 1");
}
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {output_shape, std::move(args.front().data)};
}
};
struct broadcast
......
......@@ -62,6 +62,7 @@ struct onnx_parser
add_mem_op("Conv", &onnx_parser::parse_conv);
add_mem_op("MaxPool", &onnx_parser::parse_pooling);
add_mem_op("Reshape", &onnx_parser::parse_reshape);
add_mem_op("Flatten", &onnx_parser::parse_flatten);
add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
}
......@@ -161,6 +162,17 @@ struct onnx_parser
return prog.add_instruction(op, args[0]);
}
instruction_ref
parse_flatten(std::string, attribute_map attributes, std::vector<instruction_ref> args)
{
uint64_t axis = 0;
if(contains(attributes, "axis"))
{
axis = parse_value(attributes.at("axis")).at<int>();
}
return prog.add_instruction(flatten{axis}, args[0]);
}
instruction_ref
parse_constant(std::string, attribute_map attributes, std::vector<instruction_ref>)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment