Commit fd0f647e authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fixed a bug in the unary and binary compute implementation

parent 970ac115
...@@ -33,14 +33,16 @@ struct binary : op_name<Derived> ...@@ -33,14 +33,16 @@ struct binary : op_name<Derived>
if(s1 == s2 and s1.packed()) if(s1 == s2 and s1.packed())
{ {
shape std_shape{s1.type(), s1.lens()}; shape std_shape{s1.type(), s1.lens()};
auto input1 = make_view(std_shape, args[0].data()); argument std_result{std_shape, result.data()};
auto input2 = make_view(std_shape, args[1].data()); argument std_arg0{std_shape, args[0].data()};
auto output = make_view(std_shape, result.data()); argument std_arg1{std_shape, args[1].data()};
std::transform(input1.begin(), visit_all(std_result, std_arg0, std_arg1)([&](auto output, auto input1, auto input2) {
input1.end(), std::transform(input1.begin(),
input2.begin(), input1.end(),
output.begin(), input2.begin(),
static_cast<const Derived&>(*this).apply()); output.begin(),
static_cast<const Derived&>(*this).apply());
});
} }
else else
{ {
......
...@@ -32,12 +32,17 @@ struct unary : op_name<Derived> ...@@ -32,12 +32,17 @@ struct unary : op_name<Derived>
{ {
shape std_in_shape{in_shape.type(), in_shape.lens()}; shape std_in_shape{in_shape.type(), in_shape.lens()};
shape std_out_shape{output_shape.type(), output_shape.lens()}; shape std_out_shape{output_shape.type(), output_shape.lens()};
auto input = make_view(std_in_shape, args[0].cast()); argument arg_in{std_in_shape, args[0].data()};
auto output = make_view(std_out_shape, result.cast()); argument arg_out{std_out_shape, result.data()};
std::transform(input.begin(), arg_out.visit([&](auto output) {
input.end(), arg_in.visit([&](auto input) {
output.begin(), std::transform(input.begin(),
static_cast<const Derived&>(*this).apply()); input.end(),
output.begin(),
static_cast<const Derived&>(*this).apply());
});
});
} }
else else
{ {
...@@ -47,8 +52,6 @@ struct unary : op_name<Derived> ...@@ -47,8 +52,6 @@ struct unary : op_name<Derived>
output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()( output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()(
input(idx.begin(), idx.end())); input(idx.begin(), idx.end()));
}); });
return result;
}); });
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment