Commit 19bc41fc authored by Shucai Xiao's avatar Shucai Xiao
Browse files

change the return shape to make it consistent in all places.

parent d4c89abc
......@@ -21,7 +21,14 @@ struct binary
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(2).same_type().same_dims();
return inputs.at(0);
if (inputs.at(0) == inputs.at(1) and inputs.at(0).packed())
{
return inputs.at(0);
}
else
{
return {inputs.at(0).type(), inputs.at(0).lens()};
}
}
};
......
......@@ -21,7 +21,15 @@ struct unary
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
return inputs.at(0);
auto s = inputs.at(0);
if (s.packed())
{
return s;
}
else
{
return {s.type(), s.lens()};
}
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment