Commit 4d0fdcd5 authored by Paul's avatar Paul
Browse files

Add flatten operator

parent a29f3d94
...@@ -422,9 +422,29 @@ struct neg : unary ...@@ -422,9 +422,29 @@ struct neg : unary
struct flatten struct flatten
{ {
uint64_t axis = 0;
std::string name() const { return "flatten"; } std::string name() const { return "flatten"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
if(axis == 0)
{
return {inputs.at(0).type(), {1, inputs.at(0).elements()}};
}
if(axis == 1)
{
return {inputs.at(0).type(), {inputs.at(0).elements(), 1}};
}
else
{
MIGRAPH_THROW("axis for flatten can only be either 0 or 1");
}
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {output_shape, std::move(args.front().data)};
}
}; };
struct broadcast struct broadcast
{ {
uint64_t axis = 0; uint64_t axis = 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment