simplify_reshapes.cpp 6.02 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
4
#include <migraphx/op/as_shape.hpp>
5
#include <migraphx/op/transpose.hpp>
Paul's avatar
Paul committed
6
#include <migraphx/op/concat.hpp>
Paul's avatar
Paul committed
7
8
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
Paul's avatar
Paul committed
9
10
#include <unordered_set>

Paul's avatar
Paul committed
11
namespace migraphx {
Paul's avatar
Paul committed
12
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
13

Paul's avatar
Paul committed
14
// Returns true when `ins` is one of the operators this pass treats as
// reshape-like; chains of such instructions are collapsed in
// simplify_reshapes::apply.
bool is_reshaper(instruction_ref ins)
{
    // clang-format off
    static const std::unordered_set<std::string> reshaper_names = {
        "reshape",
        "contiguous",
        "squeeze",
        "unsqueeze"
    };
    // clang-format on
    return reshaper_names.count(ins->name()) != 0;
}

// Returns true when `ins` has exactly one consumer and that consumer is a
// transpose, possibly reached through a chain of contiguous instructions that
// each also have exactly one consumer.
bool is_transpose_output(instruction_ref ins)
{
    auto current = ins;
    while(current->outputs().size() == 1)
    {
        auto consumer = current->outputs().front();
        if(consumer->name() != "contiguous")
            return consumer->name() == "transpose";
        // Look through the contiguous and keep walking downstream.
        current = consumer;
    }
    return false;
}

// Walks upstream from `ins` through single-input contiguous instructions and
// returns the first transpose found as the sole input. If no transpose is
// found, returns the last instruction visited: `ins` itself, or the last
// contiguous on the path.
instruction_ref find_transpose_input(instruction_ref ins)
{
    auto current = ins;
    while(current->inputs().size() == 1)
    {
        auto producer = current->inputs().front();
        if(producer->name() == "transpose")
            return producer;
        if(producer->name() != "contiguous")
            break;
        // Look through the contiguous and keep walking upstream.
        current = producer;
    }
    return current;
}

47
48
49
50
51
52
53
54
55
// Returns a copy of the permutation stored on the transpose instruction `ins`.
// The copy is taken within one full expression so it never refers to the
// operator object returned by get_operator() after that object is gone.
auto get_transpose_dims(instruction_ref ins)
{
    auto permutation = any_cast<const op::transpose&>(ins->get_operator()).dims;
    return permutation;
}

// Gathers `dims` through `permutation`: result[i] = dims[permutation[i]].
//
// dims:        values to reorder (axis ids or per-axis values)
// permutation: indices into `dims`; must have the same size as `dims`
// returns:     the reordered vector
//
// Both arguments are taken by const reference: this is called once per
// transpose in a chain, and copying both vectors on every call was pure waste.
std::vector<int64_t> reorder_dims(const std::vector<int64_t>& dims,
                                  const std::vector<int64_t>& permutation)
{
    assert(dims.size() == permutation.size());
    std::vector<int64_t> result(dims.size());
    std::transform(permutation.begin(),
                   permutation.end(),
                   result.begin(),
                   [&](int64_t axis) { return dims[axis]; });
    return result;
}

Paul's avatar
Paul committed
63
64
65
66
67
// Returns the inverse of `permutation`, i.e. the vector `inv` such that
// inv[permutation[i]] == i for every i. `permutation` must contain each of
// 0..n-1 exactly once.
//
// The previous implementation returned reorder_dims(permutation, permutation),
// which is permutation composed with itself (p[p[i]]). That equals the inverse
// only when p^3 == id. For example, {1,2,3,0} produced {2,3,0,1} instead of
// {3,0,1,2}, so a concat axis remapped through it could be wrong.
std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation)
{
    std::vector<int64_t> result(permutation.size());
    for(std::size_t i = 0; i < permutation.size(); i++)
    {
        assert(permutation[i] >= 0 and
               static_cast<std::size_t>(permutation[i]) < permutation.size());
        result[permutation[i]] = static_cast<int64_t>(i);
    }
    return result;
}

68
69
// Returns true when `dims` is the identity permutation [0, 1, ..., n-1]
// (an empty vector counts as the identity), i.e. applying it as a transpose
// would not move any axis.
bool is_no_transpose(const std::vector<int64_t>& dims)
{
    for(std::size_t axis = 0; axis < dims.size(); ++axis)
    {
        if(dims[axis] != static_cast<int64_t>(axis))
            return false;
    }
    return true;
}

Paul's avatar
Paul committed
78
// Returns the indices 0..data.size()-1 ordered so that `op` holds between
// consecutive referenced elements (argsort).
//
// data: indexable container whose elements are compared by `op`
// op:   strict weak ordering over the elements of `data`
//
// std::stable_sort is used instead of std::sort. Elements that compare equal
// (e.g. repeated strides coming from find_permutation) then keep their
// original axis order. The result is deterministic and leans toward the
// identity permutation instead of an implementation-dependent shuffle.
template <class Vector, class Op>
std::vector<int64_t> sort_permutation(const Vector& data, Op op)
{
    std::vector<std::int64_t> result(data.size());
    std::iota(result.begin(), result.end(), 0);
    std::stable_sort(result.begin(), result.end(), [&](auto x, auto y) {
        return op(data[x], data[y]);
    });
    return result;
}

// Returns the axes of `s` ordered by decreasing stride, i.e. the permutation
// that lists the outermost-in-memory axis first. Tie order is whatever
// sort_permutation produces.
std::vector<int64_t> find_permutation(const shape& s)
{
    const auto largest_stride_first = std::greater<>{};
    return sort_permutation(s.strides(), largest_stride_first);
}

Paul's avatar
Paul committed
92
93
void simplify_reshapes::apply(program& p) const
{
Paul's avatar
Paul committed
94
    auto end = std::prev(p.end());
Paul's avatar
Paul committed
95
96
    for(auto ins : iterator_for(p))
    {
Paul's avatar
Paul committed
97
        if(ins == end and ins->name() == "contiguous")
Paul's avatar
Paul committed
98
99
            continue;
        // Skip possible dead instructions
Paul's avatar
Paul committed
100
        if(ins->outputs().empty() and ins != end)
Paul's avatar
Paul committed
101
            continue;
Paul's avatar
Paul committed
102
        if(is_reshaper(ins))
Paul's avatar
Paul committed
103
        {
Paul's avatar
Paul committed
104
105
106
107
108
109
110
111
112
113
114
            if(std::any_of(ins->outputs().begin(), ins->outputs().end(), &is_reshaper))
                continue;
            // Gather reshapes
            std::vector<instruction_ref> reshapes{ins};
            while(is_reshaper(reshapes.back()))
            {
                assert(!reshapes.back()->inputs().empty());
                assert(p.has_instruction(reshapes.back()->inputs().front()));
                auto input = reshapes.back()->inputs().front();
                reshapes.push_back(input);
            }
Paul's avatar
Paul committed
115

Paul's avatar
Paul committed
116
117
            std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
            for(auto start : iterator_for(reshapes))
Paul's avatar
Paul committed
118
            {
Paul's avatar
Paul committed
119
120
121
122
123
124
125
126
127
128
129
130
                auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
                    return i->get_shape() == (*start)->get_shape() and i != (*start);
                });
                if(last != reshapes.rend())
                {
                    r = std::make_pair(*start, *last);
                    break;
                }
            }
            if(r.first != r.second)
            {
                p.replace_instruction(r.first, r.second);
Paul's avatar
Paul committed
131
132
            }
        }
Paul's avatar
Paul committed
133
        else if(ins->name() == "transpose")
Paul's avatar
Paul committed
134
        {
Paul's avatar
Paul committed
135
            if(is_transpose_output(ins))
Paul's avatar
Paul committed
136
137
138
                continue;
            auto x = ins;
            auto t = ins;
139
140
            std::vector<std::int64_t> dims(ins->get_shape().lens().size());
            std::iota(dims.begin(), dims.end(), 0);
Paul's avatar
Paul committed
141
142
            do
            {
143
                dims = reorder_dims(get_transpose_dims(t), dims);
Paul's avatar
Paul committed
144
145
                x    = t;
                t    = find_transpose_input(x);
Paul's avatar
Paul committed
146
            } while(x != t and t->name() == "transpose");
Paul's avatar
Paul committed
147
            if(t == ins or t->name() != "transpose")
Paul's avatar
Paul committed
148
                continue;
Paul's avatar
Paul committed
149
            if(is_no_transpose(dims))
150
151
152
153
154
155
156
            {
                p.replace_instruction(ins, t->inputs().front());
            }
            else
            {
                p.replace_instruction(ins, op::transpose{{dims}}, t->inputs().front());
            }
Paul's avatar
Paul committed
157
        }
Paul's avatar
Paul committed
158
159
        else if(ins->name() == "concat")
        {
Paul's avatar
Paul committed
160
            if(ins->inputs().empty())
Paul's avatar
Paul committed
161
162
                continue;
            auto s = ins->inputs().front()->get_shape();
Paul's avatar
Paul committed
163
164
            if(none_of(ins->inputs(), [&](auto i) { return i->get_shape().transposed(); }) or
               none_of(ins->inputs(), [&](auto i) { return i->get_shape() == s; }))
Paul's avatar
Paul committed
165
                continue;
Paul's avatar
Paul committed
166
            auto op          = any_cast<op::concat>(ins->get_operator());
Paul's avatar
Paul committed
167
168
            auto permutation = find_permutation(s);
            auto ipermutaion = invert_permutation(permutation);
Paul's avatar
Paul committed
169
            op.axis          = ipermutaion[op.axis];
Paul's avatar
Paul committed
170
171

            std::vector<instruction_ref> inputs;
Paul's avatar
Paul committed
172
173
174
175
176
            std::transform(
                ins->inputs().begin(),
                ins->inputs().end(),
                std::back_inserter(inputs),
                [&](auto i) { return p.insert_instruction(ins, op::transpose{permutation}, i); });
Paul's avatar
Paul committed
177
            auto concat = p.insert_instruction(ins, op, inputs);
Paul's avatar
Paul committed
178
            auto t      = p.insert_instruction(ins, op::transpose{ipermutaion}, concat);
Paul's avatar
Paul committed
179
180
            p.replace_instruction(ins, t);
        }
Paul's avatar
Paul committed
181
182
183
    }
}

Paul's avatar
Paul committed
184
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
185
} // namespace migraphx