shape_for_each.hpp 1011 Bytes
Newer Older
Paul's avatar
Paul committed
1
2
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
Paul's avatar
Paul committed
3

Paul's avatar
Paul committed
4
#include <migraph/shape.hpp>
#include <migraph/config.hpp>

#include <algorithm>
#include <cstddef>
#include <vector>

8
9
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
Paul's avatar
Paul committed
10

Paul's avatar
Paul committed
11
template <class F>
Paul's avatar
Paul committed
12
void shape_for_each(const migraph::shape& s, F f)
Paul's avatar
Paul committed
13
14
15
16
{
    // Ensure calls to f use const ref to vector
    auto call = [&f](const std::vector<std::size_t>& i) { f(i); };
    std::vector<std::size_t> indices(s.lens().size());
Paul's avatar
Paul committed
17
18
    for(std::size_t i = 0; i < s.elements(); i++)
    {
Paul's avatar
Paul committed
19
        std::transform(s.strides().begin(),
Paul's avatar
Paul committed
20
21
22
                       s.strides().end(),
                       s.lens().begin(),
                       indices.begin(),
Paul's avatar
Paul committed
23
24
25
26
                       [&](std::size_t stride, std::size_t len) {
                           assert(len > 0 and stride > 0);
                           return (i / stride) % len;
                       });
Paul's avatar
Paul committed
27
28
29
30
        call(indices);
    }
}

31
} // namespace MIGRAPH_INLINE_NS
Paul's avatar
Paul committed
32
} // namespace migraph
Paul's avatar
Paul committed
33
34

#endif