"...text-generation-inference-dcu.git" did not exist on "55e2c4968b020d42a55eadd7ef2eee80ba610768"
Commit 3bf4c1ff authored by Paul's avatar Paul
Browse files

Formatting

parent 246c4236
...@@ -94,12 +94,10 @@ constexpr void each_args(F) ...@@ -94,12 +94,10 @@ constexpr void each_args(F)
{ {
} }
template<class F, class T> template <class F, class T>
auto unpack(F f, T& x) auto unpack(F f, T& x)
{ {
return sequence_c<std::tuple_size<T>{}>([&](auto... is) { return sequence_c<std::tuple_size<T>{}>([&](auto... is) { f(std::get<is>(x)...); });
f(std::get<is>(x)...);
});
} }
/// Implements a fix-point combinator /// Implements a fix-point combinator
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template<class... Ts> template <class... Ts>
auto par_dfor(Ts... xs) auto par_dfor(Ts... xs)
{ {
return [=](auto f) { return [=](auto f) {
...@@ -17,12 +17,16 @@ auto par_dfor(Ts... xs) ...@@ -17,12 +17,16 @@ auto par_dfor(Ts... xs)
array_type lens = {{static_cast<std::size_t>(xs)...}}; array_type lens = {{static_cast<std::size_t>(xs)...}};
auto n = std::accumulate(lens.begin(), lens.end(), 1, std::multiplies<std::size_t>{}); auto n = std::accumulate(lens.begin(), lens.end(), 1, std::multiplies<std::size_t>{});
const std::size_t min_grain = 8; const std::size_t min_grain = 8;
if (n > 2*min_grain) { if(n > 2 * min_grain)
{
array_type strides; array_type strides;
strides.fill(1); strides.fill(1);
std::partial_sum( std::partial_sum(lens.rbegin(),
lens.rbegin(), lens.rend() - 1, strides.rbegin() + 1, std::multiplies<std::size_t>()); lens.rend() - 1,
auto size = std::accumulate(lens.begin(), lens.end(), 1, std::multiplies<std::size_t>()); strides.rbegin() + 1,
std::multiplies<std::size_t>());
auto size =
std::accumulate(lens.begin(), lens.end(), 1, std::multiplies<std::size_t>());
par_for(size, min_grain, [&](std::size_t i) { par_for(size, min_grain, [&](std::size_t i) {
array_type indices; array_type indices;
std::transform(strides.begin(), std::transform(strides.begin(),
...@@ -32,7 +36,9 @@ auto par_dfor(Ts... xs) ...@@ -32,7 +36,9 @@ auto par_dfor(Ts... xs)
[&](size_t stride, size_t len) { return (i / stride) % len; }); [&](size_t stride, size_t len) { return (i / stride) % len; });
migraphx::unpack(f, indices); migraphx::unpack(f, indices);
}); });
} else { }
else
{
dfor(xs...)(f); dfor(xs...)(f);
} }
......
...@@ -41,20 +41,18 @@ void par_for_impl(std::size_t n, std::size_t threadsize, F f) ...@@ -41,20 +41,18 @@ void par_for_impl(std::size_t n, std::size_t threadsize, F f)
const std::size_t grainsize = std::ceil(static_cast<double>(n) / threads.size()); const std::size_t grainsize = std::ceil(static_cast<double>(n) / threads.size());
std::size_t work = 0; std::size_t work = 0;
std::generate(threads.begin(), std::generate(threads.begin(), threads.end(), [=, &work] {
threads.end(), auto result = joinable_thread([=] {
[=, &work] { std::size_t start = work;
auto result = joinable_thread([=] { std::size_t last = std::min(n, work + grainsize);
std::size_t start = work; for(std::size_t i = start; i < last; i++)
std::size_t last = std::min(n, work + grainsize); {
for(std::size_t i = start; i < last; i++) f(i);
{ }
f(i); });
} work += grainsize;
}); return result;
work += grainsize; });
return result;
});
assert(work >= n); assert(work >= n);
} }
} }
......
...@@ -124,9 +124,9 @@ struct cpu_convolution ...@@ -124,9 +124,9 @@ struct cpu_convolution
auto wei_w = wei[3]; auto wei_w = wei[3];
par_dfor(output_shape.lens()[0], par_dfor(output_shape.lens()[0],
output_shape.lens()[1], output_shape.lens()[1],
output_shape.lens()[2], output_shape.lens()[2],
output_shape.lens()[3])( output_shape.lens()[3])(
[&](std::size_t o, std::size_t w, std::size_t i, std::size_t j) { [&](std::size_t o, std::size_t w, std::size_t i, std::size_t j) {
const int start_x = i * op.stride[0] - op.padding[0]; const int start_x = i * op.stride[0] - op.padding[0];
const int start_y = j * op.stride[1] - op.padding[1]; const int start_y = j * op.stride[1] - op.padding[1];
...@@ -247,9 +247,9 @@ struct cpu_pooling ...@@ -247,9 +247,9 @@ struct cpu_pooling
auto in_w = input.get_shape().lens()[3]; auto in_w = input.get_shape().lens()[3];
par_dfor(output_shape.lens()[0], par_dfor(output_shape.lens()[0],
output_shape.lens()[1], output_shape.lens()[1],
output_shape.lens()[2], output_shape.lens()[2],
output_shape.lens()[3])( output_shape.lens()[3])(
[&](std::size_t o, std::size_t w, std::size_t i, std::size_t j) { [&](std::size_t o, std::size_t w, std::size_t i, std::size_t j) {
const int start_x0 = i * op.stride[0] - op.padding[0]; const int start_x0 = i * op.stride[0] - op.padding[0];
const int start_y0 = j * op.stride[1] - op.padding[1]; const int start_y0 = j * op.stride[1] - op.padding[1];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment