// permutation.hpp
#ifndef MIGRAPHX_GUARD_RTGLIB_PERMUTATION_HPP
#define MIGRAPHX_GUARD_RTGLIB_PERMUTATION_HPP

#include <migraphx/config.hpp>
#include <migraphx/shape.hpp>
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

/// Returns a copy of `dims` rearranged so that result[i] == dims[permutation[i]].
///
/// @param dims         Container to reorder; must support size(), operator[] and
///                     construction from an element count.
/// @param permutation  For each output position, the index in `dims` to read from.
///                     Must have the same length as `dims`, and every entry must be
///                     a valid index into `dims` (checked with assert in debug builds).
/// @return A new container of the same type and length as `dims`.
template <class Vector>
inline Vector reorder_dims(const Vector& dims, const std::vector<int64_t>& permutation)
{
    assert(dims.size() == permutation.size());
    Vector result(dims.size());
    // An unsigned index matches the type of dims.size(), avoiding a signed/unsigned
    // comparison (-Wsign-compare) and int overflow on very large containers.
    for(std::size_t i = 0; i < dims.size(); ++i)
    {
        assert(permutation[i] >= 0 && static_cast<std::size_t>(permutation[i]) < dims.size());
        result[i] = dims[permutation[i]];
    }
    return result;
}

23
shape reorder_shape(const shape& s, const std::vector<int64_t>& permutation);
24
25
26
27
28
29

/// Computes the index permutation that stably sorts `data` under the comparator `op`.
///
/// @param data  Indexable container of values to order; it is only read.
/// @param op    Comparator invoked as op(data[a], data[b]); expected to be a
///              strict weak ordering, as required by std::stable_sort.
/// @return Indices into `data` such that data[result[0]], data[result[1]], ...
///         is ordered by `op`. Elements that compare equal keep their original
///         relative order.
template <class Vector, class Op>
inline std::vector<std::int64_t> sort_permutation(const Vector& data, Op op)
{
    std::vector<std::int64_t> result(data.size());
    // Start from the identity permutation 0, 1, ..., n-1.
    std::iota(result.begin(), result.end(), 0);
    // Order the indices by the values they refer to; stable_sort keeps ties in index order.
    std::stable_sort(
        result.begin(), result.end(), [&](auto x, auto y) { return op(data[x], data[y]); });
    return result;
}

/// Declared here; the definition lives out of line and is not visible in this header.
/// NOTE(review): by its name, presumably returns the inverse permutation q such that
/// q[permutation[i]] == i — confirm against the definition.
std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation);

37
38
std::vector<int64_t> find_permutation(const shape& s);
std::vector<int64_t> find_permutation(const std::vector<shape>& shapes);
39
40
41
42
43

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif