Commit c7888300 authored by Paul's avatar Paul
Browse files

Formatting

parent a919e88a
......@@ -34,12 +34,10 @@ auto fix(F f)
return fix<void>(f);
}
template<class... Ts>
template <class... Ts>
auto make_sequence(Ts... xs)
{
return [=](auto f) {
return f(xs...);
};
return [=](auto f) { return f(xs...); };
}
} // namespace migraph
......
......@@ -18,7 +18,7 @@ auto nary(argument result, Arguments... args)
nary_standard(result, args...)(f);
else
nary_nonstandard(result, args...)(f);
};
}
......@@ -28,18 +28,21 @@ auto nary_nonstandard(argument result, Arguments... args)
return [=](auto f) {
auto output_shape = result.get_shape();
visit_all(result, args...)([&](auto output, auto... inputs) {
visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
auto data = make_sequence(std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape().lens(), inputs.get_shape().strides()}, inputs.data())...);
hip_tensor_descriptor<ndim> out_desc(output_shape.lens(), output_shape.strides());
auto* outp = output.data();
gs_launch(output_shape.elements())([=](auto i) {
data([&](auto... ps) {
auto outidx = out_desc.multi(i);
outp[i] = f(ps.second[ps.first.linear(outidx)]...);
visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
auto data = make_sequence(
std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape().lens(),
inputs.get_shape().strides()},
inputs.data())...);
hip_tensor_descriptor<ndim> out_desc(output_shape.lens(), output_shape.strides());
auto* outp = output.data();
gs_launch(output_shape.elements())([=](auto i) {
data([&](auto... ps) {
auto outidx = out_desc.multi(i);
outp[i] = f(ps.second[ps.first.linear(outidx)]...);
});
});
});
});
});
};
}
......@@ -50,13 +53,10 @@ auto nary_standard(argument result, Arguments... args)
// assert(x.get_shape().elements() == y.get_shape().elements());
auto output_shape = result.get_shape();
visit_all(result, args...)([&](auto output, auto... inputs) {
auto data = make_sequence(inputs.data()...);
auto data = make_sequence(inputs.data()...);
auto* outp = output.data();
gs_launch(output_shape.elements())([=](auto i) {
data([&](auto... xps) {
outp[i] = f(xps[i]...);
});
});
gs_launch(output_shape.elements())(
[=](auto i) { data([&](auto... xps) { outp[i] = f(xps[i]...); }); });
});
};
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment