simplify_reshapes.cpp 6.32 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/as_shape.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/matcher.hpp>

#include <algorithm>
#include <cassert>
#include <unordered_set>

Paul's avatar
Paul committed
12
namespace migraphx {
Paul's avatar
Paul committed
13
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
14

Paul's avatar
Paul committed
15
// Names of the operators this pass treats as reshapers (see is_reshaper and
// find_reshaper). The set is built once and shared by every caller.
const auto& reshaper_names()
{
    static const std::unordered_set<std::string> reshaper_set{
        "reshape", "contiguous", "squeeze", "unsqueeze"};
    return reshaper_set;
}

Paul's avatar
Paul committed
28
// True when the instruction's operator name is one of reshaper_names().
bool is_reshaper(instruction_ref ins)
{
    return reshaper_names().count(ins->name()) != 0;
}

// Returns the transpose feeding `ins`, looking through a chain of
// "contiguous" instructions that each have exactly one input. When no
// transpose is reached, returns the last instruction visited: `ins` itself,
// or the innermost contiguous of the walked chain.
instruction_ref find_transpose_input(instruction_ref ins)
{
    auto current = ins;
    while(current->inputs().size() == 1)
    {
        auto producer = current->inputs().front();
        if(producer->name() == "transpose")
            return producer;
        if(producer->name() != "contiguous")
            break;
        current = producer;
    }
    return current;
}

44
45
46
47
48
49
50
51
52
// Returns a copy of the permutation stored in a transpose instruction.
// `ins` must hold an op::transpose; the any_cast presumably throws otherwise.
auto get_transpose_dims(instruction_ref ins)
{
    auto permutation = any_cast<const op::transpose&>(ins->get_operator()).dims;
    return permutation;
}

// Gathers `dims` through `permutation`: result[i] = dims[permutation[i]].
// Both vectors must have the same length and every permutation entry must be
// a valid index into `dims`. Taken by const reference: the function only reads
// its arguments, so by-value parameters were pure copies.
std::vector<int64_t> reorder_dims(const std::vector<int64_t>& dims,
                                  const std::vector<int64_t>& permutation)
{
    assert(dims.size() == permutation.size());
    std::vector<int64_t> result(dims.size());
    for(std::size_t i = 0; i < dims.size(); i++)
    {
        assert(permutation[i] >= 0 and static_cast<std::size_t>(permutation[i]) < dims.size());
        result[i] = dims[permutation[i]];
    }
    return result;
}

Paul's avatar
Paul committed
60
61
62
63
64
// Returns the inverse permutation: result[permutation[i]] == i for every i.
// The previous reorder_dims(permutation, permutation) computed p∘p, which only
// equals the inverse for permutations of order <= 3 (e.g. {1,2,3,0} gave
// {2,3,0,1} instead of {3,0,1,2}).
std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation)
{
    std::vector<int64_t> result(permutation.size());
    for(std::size_t i = 0; i < permutation.size(); i++)
    {
        assert(permutation[i] >= 0 and
               static_cast<std::size_t>(permutation[i]) < permutation.size());
        result[permutation[i]] = static_cast<int64_t>(i);
    }
    return result;
}

65
66
// True when `dims` is the identity permutation (dims[i] == i for all i), i.e.
// a transpose with these dims would leave the axis order unchanged. An empty
// vector counts as the identity.
bool is_no_transpose(const std::vector<int64_t>& dims)
{
    for(std::size_t axis = 0; axis < dims.size(); ++axis)
    {
        if(dims[axis] != static_cast<int64_t>(axis))
            return false;
    }
    return true;
}

Paul's avatar
Paul committed
75
// Returns the index permutation that orders `data` by the comparator `op`:
// data[result[0]], data[result[1]], ... is sorted with respect to `op`.
// std::stable_sort keeps tied elements in ascending index order, so equal
// keys (for example equal strides) give a deterministic permutation instead
// of an implementation-dependent one.
template <class Vector, class Op>
std::vector<int64_t> sort_permutation(const Vector& data, Op op)
{
    std::vector<std::int64_t> result(data.size());
    std::iota(result.begin(), result.end(), 0);
    std::stable_sort(
        result.begin(), result.end(), [&](auto x, auto y) { return op(data[x], data[y]); });
    return result;
}

// Orders the axes of `s` from largest to smallest stride (entry i is the axis
// with the i-th largest stride).
std::vector<int64_t> find_permutation(const shape& s)
{
    const auto& strides = s.strides();
    return sort_permutation(strides, std::greater<>{});
}

Paul's avatar
Paul committed
89
struct find_reshaper
Paul's avatar
Paul committed
90
{
Paul's avatar
Paul committed
91
    auto matcher() const
Paul's avatar
Paul committed
92
    {
Paul's avatar
Paul committed
93
94
95
96
97
98
99
100
        return match::name(reshaper_names())(match::any_of[match::outputs()](match::name(reshaper_names())));
    }

    void apply(program& p, match::matcher_result mr) const
    {
        auto ins = mr.result;
        std::vector<instruction_ref> reshapes{ins};
        while(is_reshaper(reshapes.back()))
Paul's avatar
Paul committed
101
        {
Paul's avatar
Paul committed
102
103
104
105
106
            assert(!reshapes.back()->inputs().empty());
            assert(p.has_instruction(reshapes.back()->inputs().front()));
            auto input = reshapes.back()->inputs().front();
            reshapes.push_back(input);
        }
Paul's avatar
Paul committed
107

Paul's avatar
Paul committed
108
109
110
111
112
113
114
        std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
        for(auto start : iterator_for(reshapes))
        {
            auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
                return i->get_shape() == (*start)->get_shape() and i != (*start);
            });
            if(last != reshapes.rend())
Paul's avatar
Paul committed
115
            {
Paul's avatar
Paul committed
116
117
                r = std::make_pair(*start, *last);
                break;
Paul's avatar
Paul committed
118
119
            }
        }
Paul's avatar
Paul committed
120
        if(r.first != r.second)
Paul's avatar
Paul committed
121
        {
Paul's avatar
Paul committed
122
            p.replace_instruction(r.first, r.second);
Paul's avatar
Paul committed
123
        }
Paul's avatar
Paul committed
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
    }
};

// Predicate: true when `start` has a single consumer that is a transpose,
// possibly reached through a chain of contiguous instructions where every
// link also has exactly one consumer. Not referenced by the matchers visible
// in this file.
MIGRAPHX_PRED_MATCHER(is_transpose_output, instruction_ref start)
{
    auto current = start;
    while(current->outputs().size() == 1)
    {
        auto consumer = current->outputs().front();
        if(consumer->name() != "contiguous")
            return consumer->name() == "transpose";
        current = consumer;
    }
    return false;
}

struct find_transpose
{
    auto matcher() const
    {
        return match::name("transpose")(match::none_of(match::skip_output(match::name("contiguous"))(match::name("transpose"))));
    }

    void apply(program& p, match::matcher_result mr) const
    {
        auto ins = mr.result;
        auto x = ins;
        auto t = ins;
        std::vector<std::int64_t> dims(ins->get_shape().lens().size());
        std::iota(dims.begin(), dims.end(), 0);
        do
        {
            dims = reorder_dims(get_transpose_dims(t), dims);
            x    = t;
            t    = find_transpose_input(x);
        } while(x != t and t->name() == "transpose");
        if(t == ins or t->name() != "transpose")
            return;
        if(is_no_transpose(dims))
        {
            p.replace_instruction(ins, t->inputs().front());
        }
        else
Paul's avatar
Paul committed
165
        {
Paul's avatar
Paul committed
166
            p.replace_instruction(ins, op::transpose{{dims}}, t->inputs().front());
Paul's avatar
Paul committed
167
        }
Paul's avatar
Paul committed
168
    }
Paul's avatar
Paul committed
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
};

struct find_concat_transpose
{
    auto matcher() const
    {
        return match::name("concat")(match::same_shapes(), match::all_of[match::inputs()](match::transpose_shape()));
    }

    void apply(program& p, match::matcher_result mr) const
    {
        auto ins = mr.result;
        auto s = ins->inputs().front()->get_shape();

        auto op          = any_cast<op::concat>(ins->get_operator());
        auto permutation = find_permutation(s);
        auto ipermutaion = invert_permutation(permutation);
        op.axis          = ipermutaion[op.axis];

        std::vector<instruction_ref> inputs;
        std::transform(
            ins->inputs().begin(),
            ins->inputs().end(),
            std::back_inserter(inputs),
            [&](auto i) { return p.insert_instruction(ins, op::transpose{permutation}, i); });
        auto concat = p.insert_instruction(ins, op, inputs);
        auto t      = p.insert_instruction(ins, op::transpose{ipermutaion}, concat);
        p.replace_instruction(ins, t);
    }
};

// Runs the reshape, transpose and concat-transpose rewrites over every
// instruction of `p`.
void simplify_reshapes::apply(program& p) const
{
    // std::prev(p.end()) below is undefined on an empty program.
    if(p.begin() == p.end())
        return;
    auto last = std::prev(p.end());
    for(auto ins : iterator_for(p))
    {
        // A trailing contiguous is left alone (presumably so the program's
        // output keeps its layout — confirm).
        if(ins == last and ins->name() == "contiguous")
            continue;
        // Skip possible dead instructions
        if(ins->outputs().empty() and ins != last)
            continue;
        match::find_matches(p, ins, find_reshaper{}, find_transpose{}, find_concat_transpose{});
    }
}

Paul's avatar
Paul committed
218
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
219
} // namespace migraphx