Commit fd0f647e authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fixed a bug in the unary and binary compute implementation

parent 970ac115
......@@ -33,14 +33,16 @@ struct binary : op_name<Derived>
if(s1 == s2 and s1.packed())
{
shape std_shape{s1.type(), s1.lens()};
auto input1 = make_view(std_shape, args[0].data());
auto input2 = make_view(std_shape, args[1].data());
auto output = make_view(std_shape, result.data());
std::transform(input1.begin(),
input1.end(),
input2.begin(),
output.begin(),
static_cast<const Derived&>(*this).apply());
argument std_result{std_shape, result.data()};
argument std_arg0{std_shape, args[0].data()};
argument std_arg1{std_shape, args[1].data()};
visit_all(std_result, std_arg0, std_arg1)([&](auto output, auto input1, auto input2) {
std::transform(input1.begin(),
input1.end(),
input2.begin(),
output.begin(),
static_cast<const Derived&>(*this).apply());
});
}
else
{
......
......@@ -32,12 +32,17 @@ struct unary : op_name<Derived>
{
shape std_in_shape{in_shape.type(), in_shape.lens()};
shape std_out_shape{output_shape.type(), output_shape.lens()};
auto input = make_view(std_in_shape, args[0].cast());
auto output = make_view(std_out_shape, result.cast());
std::transform(input.begin(),
input.end(),
output.begin(),
static_cast<const Derived&>(*this).apply());
argument arg_in{std_in_shape, args[0].data()};
argument arg_out{std_out_shape, result.data()};
arg_out.visit([&](auto output) {
arg_in.visit([&](auto input) {
std::transform(input.begin(),
input.end(),
output.begin(),
static_cast<const Derived&>(*this).apply());
});
});
}
else
{
......@@ -47,8 +52,6 @@ struct unary : op_name<Derived>
output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()(
input(idx.begin(), idx.end()));
});
return result;
});
});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment