common_dims.cpp 3.62 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#include <migraphx/common_dims.hpp>
#include <migraphx/ranges.hpp>
#include <algorithm>
#include <cassert>
#include <numeric>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

/// Find where a run of dimensions starting at `start` stops fitting into `dim`.
///
/// Walks [start, last) keeping a running product of the visited values and
/// returns an iterator to the first element at which that product becomes
/// strictly greater than `dim` (or `last` if it never does). If the product of
/// the whole range is still smaller than `dim`, `start` is returned instead,
/// i.e. an empty run.
template <class Iterator>
static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
{
    std::size_t product = 1;
    Iterator pos        = start;
    for(; pos != last; ++pos)
    {
        product *= *pos;
        if(product > dim)
            break;
    }
    // Ran out of dimensions before reaching `dim`: report an empty run.
    return (product < dim) ? start : pos;
}

Paul's avatar
Format  
Paul committed
23
/// Product of all values in `r`, accumulated as std::size_t (1 for an empty
/// range).
template <class Range>
static auto elements(const Range& r)
{
    std::size_t total = 1;
    for(auto&& d : r)
        total *= d;
    return total;
}

Paul's avatar
Paul committed
29
30
struct common_dim_state
{
Paul's avatar
Paul committed
31
32
33
34
    common_dim_state(const std::vector<std::size_t>& pdims, std::vector<std::vector<std::size_t>>& paxes_map) : dims(&pdims), axes_map(&paxes_map), it(dims->begin()) {}
    const std::vector<std::size_t>* dims = nullptr;
    std::vector<std::vector<std::size_t>>* axes_map = nullptr;
    std::vector<std::size_t>::const_iterator it{};
Paul's avatar
Paul committed
35
    std::size_t rem = 1;
Paul's avatar
Paul committed
36
    std::size_t get() const { return *it / rem; }
Paul's avatar
Format  
Paul committed
37
38
    bool is_end() const { return it == dims->end(); }
    void next(std::size_t i = 1) { it += i; }
Paul's avatar
Paul committed
39
40
41
42
43
    auto dims_for(std::size_t d) const
    {
        auto dim_end = compute_end_dim(it, dims->end(), d);
        return range(it, dim_end);
    }
Paul's avatar
Paul committed
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
    void add_axes(std::size_t naxes, std::size_t start)
    {
        auto axes = compute_axes(naxes, start);
        axes_map->push_back(std::move(axes));
    }

    void add_multi_axes(std::size_t naxes, std::size_t start)
    {
        auto axes = compute_axes(naxes, start);
        std::transform(axes.begin(), axes.end(), std::back_inserter(*axes_map), [&](auto axis) -> std::vector<std::size_t> {
            return {axis};
        });
    }
    std::vector<std::size_t> compute_axes(std::size_t naxes, std::size_t start) const
    {
        if (rem != 1) {
            assert(start > 0);
            naxes++;
            start--;
        }
        std::vector<std::size_t> axes(naxes);
        std::iota(axes.begin(), axes.end(), start);
        return axes;
    }
Paul's avatar
Paul committed
68
};
Paul's avatar
Paul committed
69

Paul's avatar
Paul committed
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
static bool commpute_common_dim(std::vector<std::size_t>& cd_dims, common_dim_state& state1, common_dim_state& state2)
{
    assert(state1.get() <= state2.get());
    auto d2 = state2.get();
    auto dims    = state1.dims_for(d2);
    auto n       = elements(dims);
    auto naxes = distance(dims);
    // If not divisible then we can't compute a common dim
    if((d2 % n) != 0)
        return true;
    auto rem = d2 / n;
    state1.add_multi_axes(naxes, cd_dims.size());
    state2.add_axes(rem != 1 ? naxes + 1 : naxes, cd_dims.size());
    
    state1.rem = rem;
    state2.rem = 1;

    cd_dims.insert(cd_dims.end(), dims.begin(), dims.end());
    if(state1.rem != 1)
        cd_dims.push_back(state1.rem);
    state1.next(distance(dims));
    state2.next();
    return false;
}

/// Build common dimensions for two shapes with the same number of elements.
///
/// Walks both shapes in lockstep. At each step the side with the smaller
/// remaining dimension consumes dimensions (via commpute_common_dim) to cover
/// the other side's current dimension. The result's `dims` is the common shape.
/// `axes_map1` / `axes_map2` hold, per input dimension, the indices into `dims`
/// assigned to it. A default-constructed `common_dims` is returned when the
/// helper reports that no common shape exists.
common_dims common_dims::compute(const std::vector<std::size_t>& dims1,
                                 const std::vector<std::size_t>& dims2)
{
    assert(elements(dims1) == elements(dims2));
    common_dims cd;
    common_dim_state state1{dims1, cd.axes_map1};
    common_dim_state state2{dims2, cd.axes_map2};
    while(not state1.is_end() and not state2.is_end())
    {
        auto d1 = state1.get();
        auto d2 = state2.get();
        // The side with the smaller remaining dimension is split to cover the other.
        if(d1 <= d2)
        {
            if(commpute_common_dim(cd.dims, state1, state2))
                return {};
        }
        else // if(d1 > d2)
        {
            if(commpute_common_dim(cd.dims, state2, state1))
                return {};
        }
    }
    // When one side runs out first, the other can still hold trailing
    // dimensions (e.g. dims1 = {6}, dims2 = {6, 1}). Previously those got no
    // axes-map entry at all. With equal element counts they can only be 1s:
    // give each its own unit common axis and attach that axis to the last
    // dimension of the exhausted side, so every input dimension is mapped.
    auto append_trailing = [&](common_dim_state& rest,
                               std::vector<std::vector<std::size_t>>& other_map) {
        for(; not rest.is_end(); rest.next())
        {
            // Anything but an untouched size-1 dimension means the shapes differ.
            if(rest.rem != 1 or *rest.it != 1)
                return false;
            const std::size_t axis = cd.dims.size();
            cd.dims.push_back(1);
            rest.axes_map->push_back({axis});
            if(not other_map.empty())
                other_map.back().push_back(axis);
        }
        return true;
    };
    if(not append_trailing(state1, cd.axes_map2) or
       not append_trailing(state2, cd.axes_map1))
        return {};
    assert(elements(dims1) == elements(cd.dims));
    return cd;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx