Commit c7888300 authored by Paul's avatar Paul
Browse files

Formatting

parent a919e88a
...@@ -34,12 +34,10 @@ auto fix(F f) ...@@ -34,12 +34,10 @@ auto fix(F f)
return fix<void>(f); return fix<void>(f);
} }
template<class... Ts> template <class... Ts>
auto make_sequence(Ts... xs) auto make_sequence(Ts... xs)
{ {
return [=](auto f) { return [=](auto f) { return f(xs...); };
return f(xs...);
};
} }
} // namespace migraph } // namespace migraph
......
...@@ -18,7 +18,7 @@ auto nary(argument result, Arguments... args) ...@@ -18,7 +18,7 @@ auto nary(argument result, Arguments... args)
nary_standard(result, args...)(f); nary_standard(result, args...)(f);
else else
nary_nonstandard(result, args...)(f); nary_nonstandard(result, args...)(f);
}; };
} }
...@@ -28,18 +28,21 @@ auto nary_nonstandard(argument result, Arguments... args) ...@@ -28,18 +28,21 @@ auto nary_nonstandard(argument result, Arguments... args)
return [=](auto f) { return [=](auto f) {
auto output_shape = result.get_shape(); auto output_shape = result.get_shape();
visit_all(result, args...)([&](auto output, auto... inputs) { visit_all(result, args...)([&](auto output, auto... inputs) {
visit_tensor_size(output_shape.lens().size(), [&](auto ndim) { visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
auto data = make_sequence(std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape().lens(), inputs.get_shape().strides()}, inputs.data())...); auto data = make_sequence(
hip_tensor_descriptor<ndim> out_desc(output_shape.lens(), output_shape.strides()); std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape().lens(),
auto* outp = output.data(); inputs.get_shape().strides()},
gs_launch(output_shape.elements())([=](auto i) { inputs.data())...);
data([&](auto... ps) { hip_tensor_descriptor<ndim> out_desc(output_shape.lens(), output_shape.strides());
auto outidx = out_desc.multi(i); auto* outp = output.data();
outp[i] = f(ps.second[ps.first.linear(outidx)]...); gs_launch(output_shape.elements())([=](auto i) {
data([&](auto... ps) {
auto outidx = out_desc.multi(i);
outp[i] = f(ps.second[ps.first.linear(outidx)]...);
});
}); });
}); });
}); });
});
}; };
} }
...@@ -50,13 +53,10 @@ auto nary_standard(argument result, Arguments... args) ...@@ -50,13 +53,10 @@ auto nary_standard(argument result, Arguments... args)
// assert(x.get_shape().elements() == y.get_shape().elements()); // assert(x.get_shape().elements() == y.get_shape().elements());
auto output_shape = result.get_shape(); auto output_shape = result.get_shape();
visit_all(result, args...)([&](auto output, auto... inputs) { visit_all(result, args...)([&](auto output, auto... inputs) {
auto data = make_sequence(inputs.data()...); auto data = make_sequence(inputs.data()...);
auto* outp = output.data(); auto* outp = output.data();
gs_launch(output_shape.elements())([=](auto i) { gs_launch(output_shape.elements())(
data([&](auto... xps) { [=](auto i) { data([&](auto... xps) { outp[i] = f(xps[i]...); }); });
outp[i] = f(xps[i]...);
});
});
}); });
}; };
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment