// par_dfor.hpp
#ifndef MIGRAPHX_GUARD_RTGLIB_PAR_DFOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_PAR_DFOR_HPP

#include <migraphx/par_for.hpp>
#include <migraphx/functional.hpp>
#include <algorithm>
#include <array>
#include <cstddef>
#include <functional>
#include <numeric>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

template <class... Ts>
Paul's avatar
Paul committed
13
14
15
16
17
18
19
auto par_dfor(Ts... xs)
{
    return [=](auto f) {
        using array_type = std::array<std::size_t, sizeof...(Ts)>;
        array_type lens  = {{static_cast<std::size_t>(xs)...}};
        auto n = std::accumulate(lens.begin(), lens.end(), 1, std::multiplies<std::size_t>{});
        const std::size_t min_grain = 8;
Paul's avatar
Paul committed
20
21
        if(n > 2 * min_grain)
        {
Paul's avatar
Paul committed
22
23
            array_type strides;
            strides.fill(1);
Paul's avatar
Paul committed
24
25
26
27
28
29
            std::partial_sum(lens.rbegin(),
                             lens.rend() - 1,
                             strides.rbegin() + 1,
                             std::multiplies<std::size_t>());
            auto size =
                std::accumulate(lens.begin(), lens.end(), 1, std::multiplies<std::size_t>());
Paul's avatar
Paul committed
30
31
32
33
34
35
36
37
38
            par_for(size, min_grain, [&](std::size_t i) {
                array_type indices;
                std::transform(strides.begin(),
                               strides.end(),
                               lens.begin(),
                               indices.begin(),
                               [&](size_t stride, size_t len) { return (i / stride) % len; });
                migraphx::unpack(f, indices);
            });
Paul's avatar
Paul committed
39
40
41
        }
        else
        {
Paul's avatar
Paul committed
42
43
44
45
46
47
48
49
50
51
            dfor(xs...)(f);
        }

    };
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif