// reduce_dims.cpp
#include <migraphx/reduce_dims.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

/// Attempt to fold dimension `n` into dimension `n + 1` for every shape.
///
/// A shape can fold when its stride at `n` equals the stride at `n + 1` times
/// the length at `n + 1`, or when its length at `n` is 1. The fold is applied
/// only if every shape in `shapes` can fold; otherwise nothing is modified.
///
/// @param shapes Shapes to rewrite in place.
/// @param n      Index of the outer dimension of the pair to fold.
/// @return true if all shapes were rewritten, false if `shapes` was left as is.
bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
{
    // One merged extent per shape that is able to fold.
    std::vector<std::size_t> merged;
    for(const auto& sh : shapes)
    {
        const auto& dims  = sh.lens();
        const auto& steps = sh.strides();
        assert(n < dims.size());
        // There is no dimension after `n` to fold into.
        if(dims.size() <= n + 1)
            return false;
        const auto outer_len  = dims[n];
        const auto inner_len  = dims[n + 1];
        const bool contiguous = steps[n] == steps[n + 1] * inner_len;
        if(contiguous or outer_len == 1)
            merged.push_back(outer_len * inner_len);
    }
    // Fewer entries than shapes means at least one shape could not fold.
    if(merged.size() != shapes.size())
        return false;

    auto next = merged.begin();
    for(auto& sh : shapes)
    {
        auto dims  = sh.lens();
        auto steps = sh.strides();
        // Drop entry `n`; the former `n + 1` entry slides into its place, so
        // the inner stride is kept and only the length needs updating.
        dims.erase(dims.begin() + n);
        steps.erase(steps.begin() + n);
        dims[n] = *next;
        ++next;
        sh = shape{sh.type(), dims, steps};
    }
    return true;
}

38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
/// Remove the last dimension from every shape when it has length 1.
///
/// Nothing is changed unless every shape has at least two dimensions and a
/// trailing length of 1.
///
/// @param shapes Shapes to rewrite in place.
void reduce_dim1(std::vector<shape>& shapes)
{
    for(const auto& sh : shapes)
    {
        const auto& dims     = sh.lens();
        const bool droppable = dims.size() >= 2 and dims.back() == 1;
        if(not droppable)
            return;
    }
    for(auto& sh : shapes)
    {
        auto dims  = sh.lens();
        auto steps = sh.strides();
        dims.pop_back();
        steps.pop_back();
        sh = shape{sh.type(), dims, steps};
    }
}

54
55
std::size_t reduce_dim_all(std::vector<shape>& shapes, std::size_t n)
{
bpickrel's avatar
bpickrel committed
56
    while(reduce_dim(shapes, n) and n < shapes.size()) {}
57
58
59
60
61
62
63
    return n + 1;
}
void reduce_dim_all(std::vector<shape>& shapes)
{
    std::size_t n = 0;
    while(n < shapes.front().lens().size() - 1)
        n = reduce_dim_all(shapes, n);
64
    reduce_dim1(shapes);
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
}

std::vector<std::size_t> base_lens(const std::vector<shape>& shapes)
{
    return std::accumulate(
        shapes.begin() + 1, shapes.end(), shapes.front().lens(), [](auto&& lens, auto&& s) {
            std::vector<std::size_t> result;
            const auto* x = &s.lens();
            const auto* y = &lens;
            if(x->size() > y->size())
                std::swap(x, y);
            std::transform(
                x->begin(), x->end(), y->begin(), std::back_inserter(result), [&](auto a, auto b) {
                    return std::max(a, b);
                });
            return result;
        });
}

/// Build a shape with lengths `lens` whose strides are packed over the
/// dimensions where `s` already has the requested length and are zero
/// elsewhere.
///
/// @param s    Source shape; must have the same number of dimensions as `lens`.
/// @param lens Target lengths.
/// @return The masked shape, or a default-constructed shape when some
///         dimension differs and neither side has length 1 there.
shape mask_shape(const shape& s, const std::vector<std::size_t>& lens)
{
    assert(s.lens().size() == lens.size());
    const auto& own = s.lens();
    // Value-initialised, so dimensions that are skipped below keep stride 0.
    std::vector<std::size_t> packed(lens.size());
    std::size_t running = 1;
    // Walk from the last dimension to the first.
    for(std::size_t k = lens.size(); k > 0; k--)
    {
        const auto axis = k - 1;
        const auto want = lens[axis];
        const auto have = own[axis];
        if(want != have)
        {
            // A mismatch is only tolerated when one side has length 1.
            if(want != 1 and have != 1)
                return shape{};
            continue;
        }
        packed[axis] = running;
        running *= want;
    }
    return shape{s.type(), lens, packed};
}

std::vector<shape> reduce_dims(const std::vector<shape>& shapes)
{
    if(shapes.empty())
        return {};
    auto result = shapes;
    auto base   = base_lens(shapes);
    for(auto&& s : shapes)
    {
        if(s.lens().size() != base.size())
            return shapes;
        if(s.lens() == base)
            continue;
        auto mshape = mask_shape(s, base);
        if(mshape.lens().size() != base.size())
            return shapes;
        result.push_back(mshape);
    }
    reduce_dim_all(result);
    result.erase(result.begin() + shapes.size(), result.end());
    return result;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx