"docs/_removed/TrialExample/MnistExamples.rst" did not exist on "7b2cac912cb6d6efb4fa1c7b729624d5dc0f0d69"
common_dims.cpp 3.78 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#include <migraphx/common_dims.hpp>
#include <migraphx/ranges.hpp>
#include <algorithm>
#include <cassert>
#include <functional>
#include <iterator>
#include <numeric>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

// Walk the dims in [start, last) keeping a running product and return the
// iterator to the first dim at which that product becomes greater than `dim`
// (i.e. the end of the longest prefix whose product is <= `dim`).
// If the product of the whole range never exceeds `dim`:
//   - it returns `last` when the total product equals `dim`;
//   - it returns `start` (an empty prefix) when the total is smaller than `dim`.
// The same running product also decides the result when the first element
// alone exceeds `dim`: `start` is returned.
template <class Iterator>
static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
{
    std::size_t product = 1;
    Iterator pos        = start;
    for(; pos != last; ++pos)
    {
        product *= *pos;
        if(product > dim)
            break;
    }
    // Not enough elements to reach `dim`: report an empty prefix.
    return product < dim ? start : pos;
}

// Number of elements described by a sequence of dims: the product of all
// values in `r`, accumulated as std::size_t (1 for an empty range).
template <class Range>
static auto elements(const Range& r)
{
    std::size_t total = 1;
    for(auto d : r)
        total *= d;
    return total;
}

struct common_dim_state
{
Paul's avatar
Format  
Paul committed
31
32
33
34
35
36
    common_dim_state(const std::vector<std::size_t>& pdims,
                     std::vector<std::vector<std::size_t>>& paxes_map)
        : dims(&pdims), axes_map(&paxes_map), it(dims->begin())
    {
    }
    const std::vector<std::size_t>* dims            = nullptr;
Paul's avatar
Paul committed
37
38
    std::vector<std::vector<std::size_t>>* axes_map = nullptr;
    std::vector<std::size_t>::const_iterator it{};
Paul's avatar
Paul committed
39
    std::size_t rem = 1;
Paul's avatar
Paul committed
40
    std::size_t get() const { return *it / rem; }
Paul's avatar
Format  
Paul committed
41
42
    bool is_end() const { return it == dims->end(); }
    void next(std::size_t i = 1) { it += i; }
Paul's avatar
Paul committed
43
44
45
46
47
    auto dims_for(std::size_t d) const
    {
        auto dim_end = compute_end_dim(it, dims->end(), d);
        return range(it, dim_end);
    }
Paul's avatar
Paul committed
48
49
50
51
52
53
54
55
56
    void add_axes(std::size_t naxes, std::size_t start)
    {
        auto axes = compute_axes(naxes, start);
        axes_map->push_back(std::move(axes));
    }

    void add_multi_axes(std::size_t naxes, std::size_t start)
    {
        auto axes = compute_axes(naxes, start);
Paul's avatar
Format  
Paul committed
57
58
59
60
        std::transform(axes.begin(),
                       axes.end(),
                       std::back_inserter(*axes_map),
                       [&](auto axis) -> std::vector<std::size_t> { return {axis}; });
Paul's avatar
Paul committed
61
62
63
    }
    std::vector<std::size_t> compute_axes(std::size_t naxes, std::size_t start) const
    {
Paul's avatar
Format  
Paul committed
64
65
        if(rem != 1)
        {
Paul's avatar
Paul committed
66
67
68
69
70
71
72
73
            assert(start > 0);
            naxes++;
            start--;
        }
        std::vector<std::size_t> axes(naxes);
        std::iota(axes.begin(), axes.end(), start);
        return axes;
    }
Paul's avatar
Paul committed
74
};
Paul's avatar
Paul committed
75

Paul's avatar
Format  
Paul committed
76
77
78
static bool commpute_common_dim(std::vector<std::size_t>& cd_dims,
                                common_dim_state& state1,
                                common_dim_state& state2)
Paul's avatar
Paul committed
79
80
{
    assert(state1.get() <= state2.get());
Paul's avatar
Format  
Paul committed
81
82
83
    auto d2    = state2.get();
    auto dims  = state1.dims_for(d2);
    auto n     = elements(dims);
Paul's avatar
Paul committed
84
85
86
87
88
89
90
    auto naxes = distance(dims);
    // If not divisible then we can't compute a common dim
    if((d2 % n) != 0)
        return true;
    auto rem = d2 / n;
    state1.add_multi_axes(naxes, cd_dims.size());
    state2.add_axes(rem != 1 ? naxes + 1 : naxes, cd_dims.size());
Paul's avatar
Format  
Paul committed
91

Paul's avatar
Paul committed
92
93
94
95
96
97
98
99
100
101
102
103
104
    state1.rem = rem;
    state2.rem = 1;

    cd_dims.insert(cd_dims.end(), dims.begin(), dims.end());
    if(state1.rem != 1)
        cd_dims.push_back(state1.rem);
    state1.next(distance(dims));
    state2.next();
    return false;
}

// Split dims1 and dims2 into a shared sequence of dims (result.dims) while
// recording in axes_map1 / axes_map2 which common axes each input dim was
// mapped to. Returns a default-constructed common_dims as soon as a matching
// step reports that the shapes cannot be reconciled.
// NOTE(review): the loop stops once either side is exhausted; any dims still
// left on the other side (e.g. trailing size-1 dims) get no axes_map entry —
// confirm callers tolerate this.
common_dims common_dims::compute(const std::vector<std::size_t>& dims1,
                                 const std::vector<std::size_t>& dims2)
{
    assert(elements(dims1) == elements(dims2));
    common_dims result;
    common_dim_state lhs{dims1, result.axes_map1};
    common_dim_state rhs{dims2, result.axes_map2};
    while(not lhs.is_end() and not rhs.is_end())
    {
        // The side whose current (unmatched) dim is smaller is split to fit
        // into the other side's dim; ties go to dims1.
        const bool lhs_first = lhs.get() <= rhs.get();
        auto& smaller_side   = lhs_first ? lhs : rhs;
        auto& larger_side    = lhs_first ? rhs : lhs;
        if(commpute_common_dim(result.dims, smaller_side, larger_side))
            return {};
    }
    assert(elements(dims1) == elements(result.dims));
    return result;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx