shape_for_each.hpp 911 Bytes
Newer Older
Paul's avatar
Paul committed
1
2
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
Paul's avatar
Paul committed
3

Paul's avatar
Paul committed
4
#include <migraph/shape.hpp>
Paul's avatar
Paul committed
5
6
#include <algorithm>

Paul's avatar
Paul committed
7
namespace migraph {
Paul's avatar
Paul committed
8

Paul's avatar
Paul committed
9
template <class F>
Paul's avatar
Paul committed
10
void shape_for_each(const migraph::shape& s, F f)
Paul's avatar
Paul committed
11
12
13
14
{
    // Ensure calls to f use const ref to vector
    auto call = [&f](const std::vector<std::size_t>& i) { f(i); };
    std::vector<std::size_t> indices(s.lens().size());
Paul's avatar
Paul committed
15
16
    for(std::size_t i = 0; i < s.elements(); i++)
    {
Paul's avatar
Paul committed
17
        std::transform(s.strides().begin(),
Paul's avatar
Paul committed
18
19
20
                       s.strides().end(),
                       s.lens().begin(),
                       indices.begin(),
Paul's avatar
Paul committed
21
22
23
24
                       [&](std::size_t stride, std::size_t len) {
                           assert(len > 0 and stride > 0);
                           return (i / stride) % len;
                       });
Paul's avatar
Paul committed
25
26
27
28
        call(indices);
    }
}

Paul's avatar
Paul committed
29
} // namespace migraph
Paul's avatar
Paul committed
30
31

#endif