simplify_reshapes.cpp 9.83 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
4
#include <migraphx/op/as_shape.hpp>
5
#include <migraphx/op/transpose.hpp>
Paul's avatar
Paul committed
6
#include <migraphx/op/concat.hpp>
7
#include <migraphx/op/slice.hpp>
Paul's avatar
Paul committed
8
9
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
Paul's avatar
Paul committed
10
#include <migraphx/matcher.hpp>
11
#include <migraphx/permutation.hpp>
12
#include <migraphx/dead_code_elimination.hpp>
Paul's avatar
Paul committed
13
#include <unordered_set>
14
#include <migraphx/make_op.hpp>
15
#include <migraphx/tune_axis.hpp>
16

17
#include <map>
Paul's avatar
Paul committed
18

Paul's avatar
Paul committed
19
namespace migraphx {
Paul's avatar
Paul committed
20
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
21

Paul's avatar
Paul committed
22
// Names of the operators this pass treats as "reshapers". The set is used by
// is_reshaper and by find_reshaper's matcher, and find_nop_reshapes extends a
// copy of it with additional layout-only operators.
const auto& reshaper_names()
{
    static const std::unordered_set<std::string> reshaper_ops{
        "flatten", "reshape", "contiguous", "squeeze", "unsqueeze"};
    return reshaper_ops;
}

Paul's avatar
Paul committed
36
// True when ins is one of the operators listed by reshaper_names().
bool is_reshaper(instruction_ref ins) { return reshaper_names().count(ins->name()) != 0; }
Paul's avatar
Paul committed
37
38
39

// Walk upward from ins through single-input "contiguous" producers looking for
// a "transpose".
// Returns the transpose if one is reached. Otherwise returns the last
// instruction visited: ins itself, or the deepest contiguous in the walked
// chain. The walk stops as soon as an instruction does not have exactly one
// input.
instruction_ref find_transpose_input(instruction_ref ins)
{
    auto current = ins;
    while(current->inputs().size() == 1)
    {
        auto parent = current->inputs().front();
        if(parent->name() == "transpose")
            return parent;
        if(parent->name() != "contiguous")
            break;
        // Look through the contiguous and keep searching above it.
        current = parent;
    }
    return current;
}

49
50
51
52
53
54
55
// Return (by value) the permutation stored in the transpose operator of ins.
// ins is expected to be a transpose; any_cast rejects anything else.
auto get_transpose_dims(instruction_ref ins)
{
    auto permutation = any_cast<const op::transpose&>(ins->get_operator()).dims;
    return permutation;
}

// True when dims is the identity permutation (0, 1, 2, ...), i.e. a transpose
// with these dims would leave the layout unchanged. An empty permutation is
// treated as the identity.
bool is_no_transpose(const std::vector<int64_t>& dims)
{
    for(std::size_t pos = 0; pos < dims.size(); ++pos)
    {
        if(dims[pos] != static_cast<int64_t>(pos))
            return false;
    }
    return true;
}

Paul's avatar
Paul committed
64
struct find_reshaper
Paul's avatar
Paul committed
65
{
Paul's avatar
Paul committed
66
    auto matcher() const
Paul's avatar
Paul committed
67
    {
Paul's avatar
Paul committed
68
69
        return match::name(reshaper_names())(
            match::any_of[match::outputs()](match::name(reshaper_names())));
Paul's avatar
Paul committed
70
71
    }

72
    void apply(module& p, const match::matcher_result& mr) const
Paul's avatar
Paul committed
73
74
75
76
    {
        auto ins = mr.result;
        std::vector<instruction_ref> reshapes{ins};
        while(is_reshaper(reshapes.back()))
Paul's avatar
Paul committed
77
        {
Paul's avatar
Paul committed
78
79
80
81
82
            assert(!reshapes.back()->inputs().empty());
            assert(p.has_instruction(reshapes.back()->inputs().front()));
            auto input = reshapes.back()->inputs().front();
            reshapes.push_back(input);
        }
Paul's avatar
Paul committed
83

Paul's avatar
Paul committed
84
85
86
87
88
89
90
        std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
        for(auto start : iterator_for(reshapes))
        {
            auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
                return i->get_shape() == (*start)->get_shape() and i != (*start);
            });
            if(last != reshapes.rend())
Paul's avatar
Paul committed
91
            {
Paul's avatar
Paul committed
92
93
                r = std::make_pair(*start, *last);
                break;
Paul's avatar
Paul committed
94
95
            }
        }
Paul's avatar
Paul committed
96
        if(r.first != r.second)
Paul's avatar
Paul committed
97
        {
Paul's avatar
Paul committed
98
            p.replace_instruction(r.first, r.second);
Paul's avatar
Paul committed
99
        }
Paul's avatar
Paul committed
100
101
102
    }
};

Paul's avatar
Paul committed
103
104
105
106
107
struct find_nop_reshapes
{
    auto matcher() const
    {
        auto reshapes = reshaper_names();
108
109
110
        reshapes.insert("as_shape");
        reshapes.insert("broadcast");
        reshapes.insert("concat");
Paul Fultz II's avatar
Paul Fultz II committed
111
        reshapes.insert("convert");
112
113
        reshapes.insert("multibroadcast");
        reshapes.insert("pad");
Paul's avatar
Paul committed
114
        reshapes.insert("slice");
115
        reshapes.insert("transpose");
Paul's avatar
Paul committed
116
        return match::name(reshapes)(match::same_shape(match::arg(0)));
Paul's avatar
Paul committed
117
118
    }

119
    void apply(module& p, const match::matcher_result& mr) const
Paul's avatar
Paul committed
120
121
122
123
124
125
    {
        auto ins = mr.result;
        p.replace_instruction(ins, ins->inputs().front());
    }
};

Paul's avatar
Paul committed
126
127
128
129
struct find_transpose
{
    auto matcher() const
    {
Paul's avatar
Paul committed
130
131
        return match::name("transpose")(match::none_of(
            match::skip_output(match::name("contiguous"))(match::name("transpose"))));
Paul's avatar
Paul committed
132
133
    }

134
    void apply(module& p, const match::matcher_result& mr) const
Paul's avatar
Paul committed
135
136
    {
        auto ins = mr.result;
Paul's avatar
Paul committed
137
138
        auto x   = ins;
        auto t   = ins;
Paul's avatar
Paul committed
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
        std::vector<std::int64_t> dims(ins->get_shape().lens().size());
        std::iota(dims.begin(), dims.end(), 0);
        do
        {
            dims = reorder_dims(get_transpose_dims(t), dims);
            x    = t;
            t    = find_transpose_input(x);
        } while(x != t and t->name() == "transpose");
        if(t == ins or t->name() != "transpose")
            return;
        if(is_no_transpose(dims))
        {
            p.replace_instruction(ins, t->inputs().front());
        }
        else
Paul's avatar
Paul committed
154
        {
155
            p.replace_instruction(ins, make_op("transpose", {{"dims", dims}}), t->inputs().front());
Paul's avatar
Paul committed
156
        }
Paul's avatar
Paul committed
157
    }
Paul's avatar
Paul committed
158
159
};

160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
// Replaces convert(convert(x)) with x when the outer convert's result shape
// equals x's shape.
// NOTE(review): only the outer shape is compared with x; the intermediate
// convert's type is never inspected. If that intermediate conversion can lose
// information (e.g. float -> int -> float), this rewrite is not
// value-preserving. Consider restricting it to widening intermediates.
struct find_nested_convert
{
    auto matcher() const { return match::name("convert")(match::arg(0)(match::name("convert"))); }

    void apply(module& m, const match::matcher_result& mr) const
    {
        auto outer  = mr.result;
        auto source = outer->inputs().front()->inputs().front();
        if(outer->get_shape() == source->get_shape())
            m.replace_instruction(outer, source);
    }
};

177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
struct find_nested_slice
{
    auto matcher() const { return match::name("slice")(match::arg(0)(match::name("slice"))); }

    using axes_map = std::map<std::size_t, std::pair<std::size_t, std::size_t>>;

    static axes_map get_axes(instruction_ref ins)
    {
        axes_map result;
        auto op = any_cast<op::slice>(ins->get_operator());
        for(std::size_t i = 0; i < op.axes.size(); i++)
        {
            result[op.axes[i]] = std::make_pair(op.starts[i], op.ends[i]);
        }
        return result;
    }

    static axes_map merge(const axes_map& m1, const axes_map& m2)
    {
        axes_map result;
        // Non overlapping
        for(auto&& p : m1)
        {
            if(contains(m2, p.first))
                continue;
            result[p.first] = p.second;
        }
        for(auto&& p : m2)
        {
            if(contains(m1, p.first))
                continue;
            result[p.first] = p.second;
        }
        // Overlapping
        for(auto&& p1 : m1)
        {
            if(not contains(m2, p1.first))
                continue;
            auto&& v1        = p1.second;
            auto&& v2        = m2.at(p1.first);
            auto start       = v1.first + v2.first;
            auto end         = start + (v2.second - v2.first);
            result[p1.first] = std::make_pair(start, end);
        }
        return result;
    }

224
    void apply(module& p, const match::matcher_result& mr) const
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
    {
        auto ins   = mr.result;
        auto slice = ins->inputs().front();
        auto input = slice->inputs().front();

        auto a1 = get_axes(ins);
        auto a2 = get_axes(slice);

        auto axes = merge(a2, a1);

        auto op = op::slice{};
        for(auto&& pp : axes)
        {
            op.axes.push_back(pp.first);
            op.starts.push_back(pp.second.first);
            op.ends.push_back(pp.second.second);
        }
        p.replace_instruction(ins, op, input);
    }
};

Paul's avatar
Paul committed
246
247
248
249
struct find_concat_transpose
{
    auto matcher() const
    {
250
        return match::name("concat")(match::all_of[match::inputs()](match::transpose_shape()));
Paul's avatar
Paul committed
251
252
    }

253
    void apply(module& p, const match::matcher_result& mr) const
Paul's avatar
Paul committed
254
    {
Shucai Xiao's avatar
Shucai Xiao committed
255
256
257
        auto ins          = mr.result;
        auto trans_inputs = ins->inputs();
        auto s            = trans_inputs.front()->get_shape();
Paul's avatar
Paul committed
258
        assert(s.transposed());
Shucai Xiao's avatar
Shucai Xiao committed
259
260
261
262
263
264
265
266
267
268
269
270
271
        auto op          = any_cast<op::concat>(ins->get_operator());
        auto permutation = find_permutation(s);

        // permutation should be the same for all inputs
        if(!std::all_of(trans_inputs.begin(), trans_inputs.end(), [&](auto in) {
               return (find_permutation(in->get_shape()) == permutation);
           }))
        {
            return;
        }

        // axis could be a negative value
        int64_t n_dim = static_cast<int64_t>(s.lens().size());
272
        op.axis       = tune_axis(n_dim, op.axis, op.name());
Shucai Xiao's avatar
Shucai Xiao committed
273

Paul's avatar
Paul committed
274
        auto ipermutation = invert_permutation(permutation);
Paul's avatar
Paul committed
275
        op.axis           = ipermutation[op.axis];
Paul's avatar
Paul committed
276
277
278

        std::vector<instruction_ref> inputs;
        std::transform(
Paul's avatar
Paul committed
279
            ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
280
                return p.insert_instruction(ins, make_op("transpose", {{"dims", permutation}}), i);
Paul's avatar
Paul committed
281
            });
Paul's avatar
Paul committed
282
        auto concat = p.insert_instruction(ins, op, inputs);
283
        auto t = p.insert_instruction(ins, make_op("transpose", {{"dims", ipermutation}}), concat);
Paul's avatar
Paul committed
284
        assert(ins->get_shape().lens() == t->get_shape().lens());
Paul's avatar
Paul committed
285
286
287
288
        p.replace_instruction(ins, t);
    }
};

Paul Fultz II's avatar
Paul Fultz II committed
289
290
291
292
293
294
295
296
297
298
299
300
301
struct find_nested_concat
{
    auto matcher() const
    {
        return match::name("concat")(match::any_of[match::inputs()](match::name("concat")));
    }

    static std::size_t get_axis(instruction_ref ins)
    {
        auto op = any_cast<op::concat>(ins->get_operator());
        return op.axis;
    }

302
    void apply(module& p, const match::matcher_result& mr) const
Paul Fultz II's avatar
Paul Fultz II committed
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
    {
        auto ins  = mr.result;
        auto axis = get_axis(ins);
        std::vector<instruction_ref> args;
        fix([&](auto self, auto&& inputs) {
            for(auto&& i : inputs)
            {
                if(i->name() == "concat" and get_axis(i) == axis and i->outputs().size() == 1)
                    self(i->inputs());
                else
                    args.push_back(i);
            }

        })(ins->inputs());
        p.replace_instruction(ins, ins->get_operator(), args);
    }
};

321
// Run all reshape-simplification matchers over the module twice, removing dead
// instructions after each sweep. The second sweep presumably catches patterns
// exposed by rewrites in the first.
void simplify_reshapes::apply(module& p) const
{
    constexpr int sweeps = 2;
    for(int sweep = 0; sweep != sweeps; ++sweep)
    {
        match::find_matches(p,
                            find_nop_reshapes{},
                            find_reshaper{},
                            find_transpose{},
                            find_concat_transpose{},
                            find_nested_convert{},
                            find_nested_slice{},
                            find_nested_concat{});
        dead_code_elimination{}.apply(p);
    }
}

Paul's avatar
Paul committed
337
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
338
} // namespace migraphx