Commit d4d2335a authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge changes from pr propogate_const

parents 2820afff a81af777
...@@ -13,11 +13,16 @@ struct binary : op_name<Derived> ...@@ -13,11 +13,16 @@ struct binary : op_name<Derived>
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(2).same_type().same_dims(); check_shapes{inputs}.has(2).same_type().same_dims();
const auto& s = inputs.front(); if(inputs.at(0) == inputs.at(1) and inputs.at(0).packed())
if(s.scalar() and s.elements() == 1) {
return {s.type()}; return inputs.at(0);
return {s.type(), s.lens()}; }
else
{
return {inputs.at(0).type(), inputs.at(0).lens()};
}
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
......
...@@ -13,7 +13,15 @@ struct unary : op_name<Derived> ...@@ -13,7 +13,15 @@ struct unary : op_name<Derived>
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1);
return inputs.at(0); auto s = inputs.at(0);
if(s.packed())
{
return s;
}
else
{
return {s.type(), s.lens()};
}
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment