Commit 042a9437 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

code cleanup

parent ef5d7092
...@@ -13,13 +13,15 @@ struct binary : op_name<Derived> ...@@ -13,13 +13,15 @@ struct binary : op_name<Derived>
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(2).same_type().same_dims(); check_shapes{inputs}.has(2).same_type().same_dims();
if(inputs.at(0) == inputs.at(1) and inputs.at(0).packed()) auto s0 = inputs.at(0);
auto s1 = inputs.at(1);
if(s0 == s1 and s0.packed())
{ {
return inputs.at(0); return s0;
} }
else else
{ {
return {inputs.at(0).type(), inputs.at(0).lens()}; return {s0.type(), s0.lens()};
} }
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
......
...@@ -593,13 +593,15 @@ struct cpu_unary ...@@ -593,13 +593,15 @@ struct cpu_unary
std::string name() const { return op.name(); } std::string name() const { return op.name(); }
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
if(inputs.at(0).packed()) check_shapes{inputs}.has(1);
auto s = inputs.at(0);
if(s.packed())
{ {
return inputs.at(0); return s;
} }
else else
{ {
return {inputs.at(0).type(), inputs.at(0).lens()}; return {s.type(), s.lens()};
} }
} }
...@@ -793,13 +795,16 @@ struct cpu_binary ...@@ -793,13 +795,16 @@ struct cpu_binary
std::string name() const { return "cpu::" + op.name(); } std::string name() const { return "cpu::" + op.name(); }
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
if(inputs.at(0) == inputs.at(1) and inputs.at(0).packed()) check_shapes{inputs}.has(2).same_type().same_dims();
auto s0 = inputs.at(0);
auto s1 = inputs.at(1);
if(s0 == s1 and s0.packed())
{ {
return inputs.at(0); return s0;
} }
else else
{ {
return {inputs.at(0).type(), inputs.at(0).lens()}; return {s0.type(), s0.lens()};
} }
} }
......
...@@ -45,13 +45,14 @@ struct unary_device : oper<Derived> ...@@ -45,13 +45,14 @@ struct unary_device : oper<Derived>
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(2); check_shapes{inputs, *this}.has(2);
if(inputs.at(0).packed()) auto s = inputs.at(0);
if(s.packed())
{ {
return inputs.at(0); return s;
} }
else else
{ {
return {inputs.at(0).type(), inputs.at(0).lens()}; return {s.type(), s.lens()};
} }
} }
...@@ -73,13 +74,15 @@ struct binary_device : oper<Derived> ...@@ -73,13 +74,15 @@ struct binary_device : oper<Derived>
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(3); check_shapes{inputs, *this}.has(3);
if(inputs.at(0) == inputs.at(1) and inputs.at(0).packed()) auto s0 = inputs.at(0);
auto s1 = inputs.at(1);
if(s0 == s1 and s0.packed())
{ {
return inputs.at(0); return s0;
} }
else else
{ {
return {inputs.at(0).type(), inputs.at(0).lens()}; return {s0.type(), s0.lens()};
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment