Commit 3bf4c1ff authored by Paul
Browse files

Formatting

parent 246c4236
...@@ -94,12 +94,10 @@ constexpr void each_args(F) ...@@ -94,12 +94,10 @@ constexpr void each_args(F)
{ {
} }
template<class F, class T> template <class F, class T>
auto unpack(F f, T& x) auto unpack(F f, T& x)
{ {
return sequence_c<std::tuple_size<T>{}>([&](auto... is) { return sequence_c<std::tuple_size<T>{}>([&](auto... is) { f(std::get<is>(x)...); });
f(std::get<is>(x)...);
});
} }
/// Implements a fix-point combinator /// Implements a fix-point combinator
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template<class... Ts> template <class... Ts>
auto par_dfor(Ts... xs) auto par_dfor(Ts... xs)
{ {
return [=](auto f) { return [=](auto f) {
...@@ -17,12 +17,16 @@ auto par_dfor(Ts... xs) ...@@ -17,12 +17,16 @@ auto par_dfor(Ts... xs)
array_type lens = {{static_cast<std::size_t>(xs)...}}; array_type lens = {{static_cast<std::size_t>(xs)...}};
auto n = std::accumulate(lens.begin(), lens.end(), 1, std::multiplies<std::size_t>{}); auto n = std::accumulate(lens.begin(), lens.end(), 1, std::multiplies<std::size_t>{});
const std::size_t min_grain = 8; const std::size_t min_grain = 8;
if (n > 2*min_grain) { if(n > 2 * min_grain)
{
array_type strides; array_type strides;
strides.fill(1); strides.fill(1);
std::partial_sum( std::partial_sum(lens.rbegin(),
lens.rbegin(), lens.rend() - 1, strides.rbegin() + 1, std::multiplies<std::size_t>()); lens.rend() - 1,
auto size = std::accumulate(lens.begin(), lens.end(), 1, std::multiplies<std::size_t>()); strides.rbegin() + 1,
std::multiplies<std::size_t>());
auto size =
std::accumulate(lens.begin(), lens.end(), 1, std::multiplies<std::size_t>());
par_for(size, min_grain, [&](std::size_t i) { par_for(size, min_grain, [&](std::size_t i) {
array_type indices; array_type indices;
std::transform(strides.begin(), std::transform(strides.begin(),
...@@ -32,7 +36,9 @@ auto par_dfor(Ts... xs) ...@@ -32,7 +36,9 @@ auto par_dfor(Ts... xs)
[&](size_t stride, size_t len) { return (i / stride) % len; }); [&](size_t stride, size_t len) { return (i / stride) % len; });
migraphx::unpack(f, indices); migraphx::unpack(f, indices);
}); });
} else { }
else
{
dfor(xs...)(f); dfor(xs...)(f);
} }
......
...@@ -41,9 +41,7 @@ void par_for_impl(std::size_t n, std::size_t threadsize, F f) ...@@ -41,9 +41,7 @@ void par_for_impl(std::size_t n, std::size_t threadsize, F f)
const std::size_t grainsize = std::ceil(static_cast<double>(n) / threads.size()); const std::size_t grainsize = std::ceil(static_cast<double>(n) / threads.size());
std::size_t work = 0; std::size_t work = 0;
std::generate(threads.begin(), std::generate(threads.begin(), threads.end(), [=, &work] {
threads.end(),
[=, &work] {
auto result = joinable_thread([=] { auto result = joinable_thread([=] {
std::size_t start = work; std::size_t start = work;
std::size_t last = std::min(n, work + grainsize); std::size_t last = std::min(n, work + grainsize);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment