// reduce_dims.cpp
#include <migraphx/reduce_dims.hpp>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <numeric>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

6
bool reduce_dim(std::vector<shape>& shapes, int n)
7
{
8
    std::vector<int> new_lens;
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
    for(const auto& s : shapes)
    {
        assert(n < s.lens().size());
        if((n + 1) >= s.lens().size())
            return false;
        auto astride = s.strides()[n];
        auto alen    = s.lens()[n];
        auto bstride = s.strides()[n + 1];
        auto blen    = s.lens()[n + 1];

        if(astride == bstride * blen)
        {
            new_lens.push_back(alen * blen);
        }
    }
    if(new_lens.size() != shapes.size())
        return false;
26
    int i = 0;
27
28
29
30
31
32
33
34
35
36
37
38
39
    for(auto& s : shapes)
    {
        auto lens    = s.lens();
        auto strides = s.strides();
        lens.erase(lens.begin() + n);
        strides.erase(strides.begin() + n);
        lens[n] = new_lens[i];
        s       = shape{s.type(), lens, strides};
        i++;
    }
    return true;
}

40
int reduce_dim_all(std::vector<shape>& shapes, int n)
41
42
43
44
45
46
47
48
49
{
    while(reduce_dim(shapes, n) and n < shapes.size())
    {
    }

    return n + 1;
}
void reduce_dim_all(std::vector<shape>& shapes)
{
50
    int n = 0;
51
52
53
54
    while(n < shapes.front().lens().size() - 1)
        n = reduce_dim_all(shapes, n);
}

55
std::vector<int> base_lens(const std::vector<shape>& shapes)
56
57
58
{
    return std::accumulate(
        shapes.begin() + 1, shapes.end(), shapes.front().lens(), [](auto&& lens, auto&& s) {
59
            std::vector<int> result;
60
61
62
63
64
65
66
67
68
69
70
71
            const auto* x = &s.lens();
            const auto* y = &lens;
            if(x->size() > y->size())
                std::swap(x, y);
            std::transform(
                x->begin(), x->end(), y->begin(), std::back_inserter(result), [&](auto a, auto b) {
                    return std::max(a, b);
                });
            return result;
        });
}

72
shape mask_shape(const shape& s, const std::vector<int>& lens)
73
74
{
    assert(s.lens().size() == lens.size());
75
76
77
    std::vector<int> rstrides(lens.size());
    int stride = 1;
    for(int i = lens.size() - 1; i < lens.size(); i--)
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
    {
        if(lens[i] == s.lens()[i])
        {
            rstrides[i] = stride;
            stride *= lens[i];
        }
        else if(lens[i] != 1 and s.lens()[i] != 1)
        {
            return shape{};
        }
    }
    return shape{s.type(), lens, rstrides};
}

std::vector<shape> reduce_dims(const std::vector<shape>& shapes)
{
    if(shapes.empty())
        return {};
    auto result = shapes;
    auto base   = base_lens(shapes);
    for(auto&& s : shapes)
    {
        if(s.lens().size() != base.size())
            return shapes;
        if(s.lens() == base)
            continue;
        auto mshape = mask_shape(s, base);
        if(mshape.lens().size() != base.size())
            return shapes;
        result.push_back(mshape);
    }
    reduce_dim_all(result);
    result.erase(result.begin() + shapes.size(), result.end());
    return result;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx