// simplify_reshapes.cpp
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/as_shape.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/dead_code_elimination.hpp>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <map>
#include <numeric>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Operator names this pass treats as "reshapers". find_reshaper collapses
// chains of them, and find_nop_reshapes starts from this set when looking for
// shape-preserving instructions to drop.
const auto& reshaper_names()
{
    // Built once on first call; the returned reference stays valid for the
    // rest of the program's lifetime.
    // clang-format off
    static const std::unordered_set<std::string> names{
        "flatten", "reshape", "contiguous", "squeeze", "unsqueeze"
    };
    // clang-format on
    return names;
}

// True when the operator of `ins` is one of the names in reshaper_names().
bool is_reshaper(instruction_ref ins)
{
    const auto& names = reshaper_names();
    return contains(names, ins->name());
}
// Follow the single-input chain above `ins`, looking through any number of
// "contiguous" instructions, to find the transpose feeding it.
//
// Returns:
//  - the transpose, if one is reached through zero or more contiguous ops;
//  - otherwise the last instruction visited. That is `ins` itself when it does
//    not have exactly one input or its input is neither contiguous nor
//    transpose. Otherwise it is the deepest contiguous in the chain.
instruction_ref find_transpose_input(instruction_ref ins)
{
    auto current = ins;
    while(current->inputs().size() == 1)
    {
        auto producer = current->inputs().front();
        if(producer->name() == "transpose")
            return producer;
        if(producer->name() != "contiguous")
            break;
        // Skip over the contiguous and keep walking up.
        current = producer;
    }
    return current;
}

// Copy of the permutation stored in the op::transpose held by `ins`.
// `ins` is expected to hold a transpose; any_cast presumably throws otherwise.
auto get_transpose_dims(instruction_ref ins)
{
    const auto& transpose_op = any_cast<const op::transpose&>(ins->get_operator());
    return transpose_op.dims;
}

// True when `dims` is the identity permutation [0, 1, ..., n-1], i.e. a
// transpose with these dims would leave the axes in place. An empty `dims`
// counts as the identity.
bool is_no_transpose(const std::vector<int64_t>& dims)
{
    for(std::size_t axis = 0; axis < dims.size(); ++axis)
    {
        if(dims[axis] != static_cast<std::int64_t>(axis))
            return false;
    }
    return true;
}

struct find_reshaper
Paul's avatar
Paul committed
62
{
Paul's avatar
Paul committed
63
    auto matcher() const
Paul's avatar
Paul committed
64
    {
Paul's avatar
Paul committed
65
66
        return match::name(reshaper_names())(
            match::any_of[match::outputs()](match::name(reshaper_names())));
Paul's avatar
Paul committed
67
68
    }

Paul's avatar
Paul committed
69
    void apply(program& p, const match::matcher_result& mr) const
Paul's avatar
Paul committed
70
71
72
73
    {
        auto ins = mr.result;
        std::vector<instruction_ref> reshapes{ins};
        while(is_reshaper(reshapes.back()))
Paul's avatar
Paul committed
74
        {
Paul's avatar
Paul committed
75
76
77
78
79
            assert(!reshapes.back()->inputs().empty());
            assert(p.has_instruction(reshapes.back()->inputs().front()));
            auto input = reshapes.back()->inputs().front();
            reshapes.push_back(input);
        }
Paul's avatar
Paul committed
80

Paul's avatar
Paul committed
81
82
83
84
85
86
87
        std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
        for(auto start : iterator_for(reshapes))
        {
            auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
                return i->get_shape() == (*start)->get_shape() and i != (*start);
            });
            if(last != reshapes.rend())
Paul's avatar
Paul committed
88
            {
Paul's avatar
Paul committed
89
90
                r = std::make_pair(*start, *last);
                break;
Paul's avatar
Paul committed
91
92
            }
        }
Paul's avatar
Paul committed
93
        if(r.first != r.second)
Paul's avatar
Paul committed
94
        {
Paul's avatar
Paul committed
95
            p.replace_instruction(r.first, r.second);
Paul's avatar
Paul committed
96
        }
Paul's avatar
Paul committed
97
98
99
    }
};

struct find_nop_reshapes
{
    auto matcher() const
    {
        auto reshapes = reshaper_names();
105
106
107
108
109
        reshapes.insert("as_shape");
        reshapes.insert("broadcast");
        reshapes.insert("concat");
        reshapes.insert("multibroadcast");
        reshapes.insert("pad");
Paul's avatar
Paul committed
110
        reshapes.insert("slice");
111
        reshapes.insert("transpose");
Paul's avatar
Paul committed
112
        return match::name(reshapes)(match::same_shape(match::arg(0)));
Paul's avatar
Paul committed
113
114
    }

Paul's avatar
Paul committed
115
    void apply(program& p, const match::matcher_result& mr) const
Paul's avatar
Paul committed
116
117
118
119
120
121
    {
        auto ins = mr.result;
        p.replace_instruction(ins, ins->inputs().front());
    }
};

struct find_transpose
{
    auto matcher() const
    {
Paul's avatar
Paul committed
126
127
        return match::name("transpose")(match::none_of(
            match::skip_output(match::name("contiguous"))(match::name("transpose"))));
Paul's avatar
Paul committed
128
129
    }

Paul's avatar
Paul committed
130
    void apply(program& p, const match::matcher_result& mr) const
Paul's avatar
Paul committed
131
132
    {
        auto ins = mr.result;
Paul's avatar
Paul committed
133
134
        auto x   = ins;
        auto t   = ins;
Paul's avatar
Paul committed
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
        std::vector<std::int64_t> dims(ins->get_shape().lens().size());
        std::iota(dims.begin(), dims.end(), 0);
        do
        {
            dims = reorder_dims(get_transpose_dims(t), dims);
            x    = t;
            t    = find_transpose_input(x);
        } while(x != t and t->name() == "transpose");
        if(t == ins or t->name() != "transpose")
            return;
        if(is_no_transpose(dims))
        {
            p.replace_instruction(ins, t->inputs().front());
        }
        else
Paul's avatar
Paul committed
150
        {
Paul's avatar
Paul committed
151
            p.replace_instruction(ins, op::transpose{{dims}}, t->inputs().front());
Paul's avatar
Paul committed
152
        }
Paul's avatar
Paul committed
153
    }
Paul's avatar
Paul committed
154
155
};

struct find_nested_slice
{
    auto matcher() const { return match::name("slice")(match::arg(0)(match::name("slice"))); }

    using axes_map = std::map<std::size_t, std::pair<std::size_t, std::size_t>>;

    static axes_map get_axes(instruction_ref ins)
    {
        axes_map result;
        auto op = any_cast<op::slice>(ins->get_operator());
        for(std::size_t i = 0; i < op.axes.size(); i++)
        {
            result[op.axes[i]] = std::make_pair(op.starts[i], op.ends[i]);
        }
        return result;
    }

    static axes_map merge(const axes_map& m1, const axes_map& m2)
    {
        axes_map result;
        // Non overlapping
        for(auto&& p : m1)
        {
            if(contains(m2, p.first))
                continue;
            result[p.first] = p.second;
        }
        for(auto&& p : m2)
        {
            if(contains(m1, p.first))
                continue;
            result[p.first] = p.second;
        }
        // Overlapping
        for(auto&& p1 : m1)
        {
            if(not contains(m2, p1.first))
                continue;
            auto&& v1        = p1.second;
            auto&& v2        = m2.at(p1.first);
            auto start       = v1.first + v2.first;
            auto end         = start + (v2.second - v2.first);
            result[p1.first] = std::make_pair(start, end);
        }
        return result;
    }

    void apply(program& p, const match::matcher_result& mr) const
    {
        auto ins   = mr.result;
        auto slice = ins->inputs().front();
        auto input = slice->inputs().front();

        auto a1 = get_axes(ins);
        auto a2 = get_axes(slice);

        auto axes = merge(a2, a1);

        auto op = op::slice{};
        for(auto&& pp : axes)
        {
            op.axes.push_back(pp.first);
            op.starts.push_back(pp.second.first);
            op.ends.push_back(pp.second.second);
        }
        p.replace_instruction(ins, op, input);
    }
};

struct find_concat_transpose
{
    auto matcher() const
    {
229
        return match::name("concat")(match::all_of[match::inputs()](match::transpose_shape()));
Paul's avatar
Paul committed
230
231
    }

Paul's avatar
Paul committed
232
    void apply(program& p, const match::matcher_result& mr) const
Paul's avatar
Paul committed
233
234
    {
        auto ins = mr.result;
Paul's avatar
Paul committed
235
        auto s   = ins->inputs().front()->get_shape();
Paul's avatar
Paul committed
236
        assert(s.transposed());
Paul's avatar
Paul committed
237
238
        auto op           = any_cast<op::concat>(ins->get_operator());
        auto permutation  = find_permutation(s);
Paul's avatar
Paul committed
239
        auto ipermutation = invert_permutation(permutation);
Paul's avatar
Paul committed
240
        op.axis           = ipermutation[op.axis];
Paul's avatar
Paul committed
241
242
243

        std::vector<instruction_ref> inputs;
        std::transform(
Paul's avatar
Paul committed
244
245
246
            ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
                return p.insert_instruction(ins, op::transpose{permutation}, i);
            });
Paul's avatar
Paul committed
247
        auto concat = p.insert_instruction(ins, op, inputs);
Paul's avatar
Paul committed
248
249
        auto t      = p.insert_instruction(ins, op::transpose{ipermutation}, concat);
        assert(ins->get_shape().lens() == t->get_shape().lens());
Paul's avatar
Paul committed
250
251
252
253
        p.replace_instruction(ins, t);
    }
};

struct find_nested_concat
{
    auto matcher() const
    {
        return match::name("concat")(match::any_of[match::inputs()](match::name("concat")));
    }

    static std::size_t get_axis(instruction_ref ins)
    {
        auto op = any_cast<op::concat>(ins->get_operator());
        return op.axis;
    }

    void apply(program& p, const match::matcher_result& mr) const
    {
        auto ins  = mr.result;
        auto axis = get_axis(ins);
        std::vector<instruction_ref> args;
        fix([&](auto self, auto&& inputs) {
            for(auto&& i : inputs)
            {
                if(i->name() == "concat" and get_axis(i) == axis and i->outputs().size() == 1)
                    self(i->inputs());
                else
                    args.push_back(i);
            }

        })(ins->inputs());
        p.replace_instruction(ins, ins->get_operator(), args);
    }
};

// Runs all reshape-related rewrites over `p` in two rounds, each followed by
// dead code elimination. Presumably the second round catches patterns exposed
// by the first.
void simplify_reshapes::apply(program& p) const
{
    for(int iteration = 0; iteration != 2; ++iteration)
    {
        match::find_matches(p,
                            find_nop_reshapes{},
                            find_reshaper{},
                            find_transpose{},
                            find_concat_transpose{},
                            find_nested_slice{},
                            find_nested_concat{});
        dead_code_elimination{}.apply(p);
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx