"src/targets/vscode:/vscode.git/clone" did not exist on "c1a76b61c8956046ba4ecaad5d37cbe6ba71a528"
permutation.hpp 1.38 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#ifndef MIGRAPHX_GUARD_RTGLIB_PERMUTATION_HPP
#define MIGRAPHX_GUARD_RTGLIB_PERMUTATION_HPP

#include <migraphx/config.hpp>
#include <migraphx/shape.hpp>
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

/// Gathers the elements of `dims` in the order given by `permutation`.
///
/// Element `i` of the returned container is `dims[permutation[i]]`.
///
/// @param dims         Container to read from; must support `size()` and `operator[]`.
/// @param permutation  Source index for every output position. It is expected to
///                     have the same number of entries as `dims` (checked with
///                     `assert` only).
/// @return A new `Vector` of the same size as `dims` holding the gathered elements.
template <class Vector>
inline Vector reorder_dims(const Vector& dims, const std::vector<int64_t>& permutation)
{
    Vector reordered(dims.size());
    assert(dims.size() == permutation.size());
    // Walk the output slots in order; `axis` tracks the position being filled.
    std::size_t axis = 0;
    for(auto&& slot : reordered)
    {
        slot = dims[permutation[axis]];
        ++axis;
    }
    return reordered;
}

/// Builds a shape from `s` with its lengths and strides both reordered by
/// `permutation` (via `reorder_dims`); the element type is carried over unchanged.
///
/// @param s            Shape whose `type()`, `lens()` and `strides()` are read.
/// @param permutation  Source index for every output dimension.
/// @return The shape constructed from the type, reordered lengths and reordered strides.
inline shape reorder_shape(const shape& s, const std::vector<int64_t>& permutation)
{
    // The same reordering is applied to the lengths and to the strides.
    const auto permute = [&](const auto& dims) { return reorder_dims(dims, permutation); };
    return {s.type(), permute(s.lens()), permute(s.strides())};
}

/// Computes the index permutation that orders `data` according to `op`.
///
/// The result `r` lists the indices `0 .. data.size()-1` arranged so that
/// `data[r[0]], data[r[1]], ...` is sorted with respect to `op`.
///
/// Indices whose elements compare equivalent under `op` are kept in ascending
/// index order, so the result is deterministic even when `data` contains ties.
///
/// @param data  Container to rank; must support `size()` and `operator[]`.
/// @param op    Strict-weak-ordering comparator applied to elements of `data`.
/// @return The sorting permutation; empty when `data` is empty.
template <class Vector, class Op>
inline std::vector<int64_t> sort_permutation(const Vector& data, Op op)
{
    std::vector<std::int64_t> result(data.size());
    std::iota(result.begin(), result.end(), 0);
    // A stable sort is required here: with std::sort the relative order of
    // equal keys is unspecified and can differ between standard libraries.
    std::stable_sort(
        result.begin(), result.end(), [&](auto x, auto y) { return op(data[x], data[y]); });
    return result;
}

/// Returns the indices of `permutation` ordered by ascending entry value.
///
/// When `permutation` holds each of `0 .. n-1` exactly once, this is its
/// inverse: position `permutation[i]` of the result holds `i`.
///
/// @param permutation  Index list to invert.
/// @return The index list produced by `sort_permutation` with an ascending comparator.
inline std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation)
{
    const std::less<> ascending{};
    auto inverse = sort_permutation(permutation, ascending);
    return inverse;
}

/// Returns the dimension indices of `s` ordered by descending stride.
///
/// @param s  Shape whose `strides()` are ranked.
/// @return The index list produced by `sort_permutation` over the strides with
///         a descending (`std::greater`) comparator.
inline std::vector<int64_t> find_permutation(const shape& s)
{
    const auto& strides = s.strides();
    const std::greater<> descending{};
    return sort_permutation(strides, descending);
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif