// simplify_reshapes.cpp
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/as_shape.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/make_op.hpp>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <map>
#include <numeric>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
Paul's avatar
Paul committed
17

Paul's avatar
Paul committed
18
namespace migraphx {
Paul's avatar
Paul committed
19
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
20

Paul's avatar
Paul committed
21
const auto& reshaper_names()
{
    // Names of the operators this pass treats as "reshapers". The set is consulted by
    // is_reshaper, find_reshaper and find_nop_reshapes. It is built once, on first use,
    // and the same instance is returned on every call.
    // clang-format off
    static const std::unordered_set<std::string> reshaper_ops{
        "flatten",
        "reshape",
        "contiguous",
        "squeeze",
        "unsqueeze"
    };
    // clang-format on
    return reshaper_ops;
}

Paul's avatar
Paul committed
35
bool is_reshaper(instruction_ref ins)
{
    // True when the instruction's operator name appears in reshaper_names().
    const auto& names = reshaper_names();
    return contains(names, ins->name());
}
Paul's avatar
Paul committed
36
37
38

// Return the "transpose" instruction feeding `ins`, looking through any chain of
// "contiguous" instructions that sits between them.
//
// When no such transpose exists, `ins` itself is returned. This happens when `ins` (or a
// contiguous on the way) does not have exactly one input, or when the chain ends at
// anything other than a transpose. Callers can therefore detect "not found" with
// `result == ins`. find_transpose relies on exactly that to stop its walk and then merge
// the transposes it has collected.
// (Returning the last "contiguous" here instead made find_transpose bail out on chains
// like transpose <- transpose <- contiguous <- x, leaving the two transposes unmerged.)
instruction_ref find_transpose_input(instruction_ref ins)
{
    auto current = ins;
    while(current->inputs().size() == 1)
    {
        auto input = current->inputs().front();
        if(input->name() == "transpose")
            return input;
        if(input->name() != "contiguous")
            break;
        // Step over the contiguous and keep looking for a transpose above it.
        current = input;
    }
    return ins;
}

48
49
50
51
52
53
54
auto get_transpose_dims(instruction_ref ins)
{
    // Return a copy of the permutation stored in the transpose operator held by `ins`.
    // `ins` must hold an op::transpose for the any_cast to succeed.
    auto perm = any_cast<const op::transpose&>(ins->get_operator()).dims;
    return perm;
}

bool is_no_transpose(const std::vector<int64_t>& dims)
{
    // A permutation is a no-op exactly when it is the identity: dims[i] == i for every i.
    // An empty permutation counts as the identity.
    for(std::size_t pos = 0; pos < dims.size(); ++pos)
    {
        if(dims[pos] != static_cast<int64_t>(pos))
            return false;
    }
    return true;
}

Paul's avatar
Paul committed
63
struct find_reshaper
Paul's avatar
Paul committed
64
{
Paul's avatar
Paul committed
65
    auto matcher() const
Paul's avatar
Paul committed
66
    {
Paul's avatar
Paul committed
67
68
        return match::name(reshaper_names())(
            match::any_of[match::outputs()](match::name(reshaper_names())));
Paul's avatar
Paul committed
69
70
    }

71
    void apply(module& p, const match::matcher_result& mr) const
Paul's avatar
Paul committed
72
73
74
75
    {
        auto ins = mr.result;
        std::vector<instruction_ref> reshapes{ins};
        while(is_reshaper(reshapes.back()))
Paul's avatar
Paul committed
76
        {
Paul's avatar
Paul committed
77
78
79
80
81
            assert(!reshapes.back()->inputs().empty());
            assert(p.has_instruction(reshapes.back()->inputs().front()));
            auto input = reshapes.back()->inputs().front();
            reshapes.push_back(input);
        }
Paul's avatar
Paul committed
82

Paul's avatar
Paul committed
83
84
85
86
87
88
89
        std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
        for(auto start : iterator_for(reshapes))
        {
            auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
                return i->get_shape() == (*start)->get_shape() and i != (*start);
            });
            if(last != reshapes.rend())
Paul's avatar
Paul committed
90
            {
Paul's avatar
Paul committed
91
92
                r = std::make_pair(*start, *last);
                break;
Paul's avatar
Paul committed
93
94
            }
        }
Paul's avatar
Paul committed
95
        if(r.first != r.second)
Paul's avatar
Paul committed
96
        {
Paul's avatar
Paul committed
97
            p.replace_instruction(r.first, r.second);
Paul's avatar
Paul committed
98
        }
Paul's avatar
Paul committed
99
100
101
    }
};

Paul's avatar
Paul committed
102
103
104
105
106
struct find_nop_reshapes
{
    auto matcher() const
    {
        auto reshapes = reshaper_names();
107
108
109
        reshapes.insert("as_shape");
        reshapes.insert("broadcast");
        reshapes.insert("concat");
Paul Fultz II's avatar
Paul Fultz II committed
110
        reshapes.insert("convert");
111
112
        reshapes.insert("multibroadcast");
        reshapes.insert("pad");
Paul's avatar
Paul committed
113
        reshapes.insert("slice");
114
        reshapes.insert("transpose");
Paul's avatar
Paul committed
115
        return match::name(reshapes)(match::same_shape(match::arg(0)));
Paul's avatar
Paul committed
116
117
    }

118
    void apply(module& p, const match::matcher_result& mr) const
Paul's avatar
Paul committed
119
120
121
122
123
124
    {
        auto ins = mr.result;
        p.replace_instruction(ins, ins->inputs().front());
    }
};

Paul's avatar
Paul committed
125
126
127
128
struct find_transpose
{
    auto matcher() const
    {
Paul's avatar
Paul committed
129
130
        return match::name("transpose")(match::none_of(
            match::skip_output(match::name("contiguous"))(match::name("transpose"))));
Paul's avatar
Paul committed
131
132
    }

133
    void apply(module& p, const match::matcher_result& mr) const
Paul's avatar
Paul committed
134
135
    {
        auto ins = mr.result;
Paul's avatar
Paul committed
136
137
        auto x   = ins;
        auto t   = ins;
Paul's avatar
Paul committed
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
        std::vector<std::int64_t> dims(ins->get_shape().lens().size());
        std::iota(dims.begin(), dims.end(), 0);
        do
        {
            dims = reorder_dims(get_transpose_dims(t), dims);
            x    = t;
            t    = find_transpose_input(x);
        } while(x != t and t->name() == "transpose");
        if(t == ins or t->name() != "transpose")
            return;
        if(is_no_transpose(dims))
        {
            p.replace_instruction(ins, t->inputs().front());
        }
        else
Paul's avatar
Paul committed
153
        {
154
            p.replace_instruction(ins, make_op("transpose", {{"dims", dims}}), t->inputs().front());
Paul's avatar
Paul committed
155
        }
Paul's avatar
Paul committed
156
    }
Paul's avatar
Paul committed
157
158
};

159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
struct find_nested_slice
{
    auto matcher() const { return match::name("slice")(match::arg(0)(match::name("slice"))); }

    using axes_map = std::map<std::size_t, std::pair<std::size_t, std::size_t>>;

    static axes_map get_axes(instruction_ref ins)
    {
        axes_map result;
        auto op = any_cast<op::slice>(ins->get_operator());
        for(std::size_t i = 0; i < op.axes.size(); i++)
        {
            result[op.axes[i]] = std::make_pair(op.starts[i], op.ends[i]);
        }
        return result;
    }

    static axes_map merge(const axes_map& m1, const axes_map& m2)
    {
        axes_map result;
        // Non overlapping
        for(auto&& p : m1)
        {
            if(contains(m2, p.first))
                continue;
            result[p.first] = p.second;
        }
        for(auto&& p : m2)
        {
            if(contains(m1, p.first))
                continue;
            result[p.first] = p.second;
        }
        // Overlapping
        for(auto&& p1 : m1)
        {
            if(not contains(m2, p1.first))
                continue;
            auto&& v1        = p1.second;
            auto&& v2        = m2.at(p1.first);
            auto start       = v1.first + v2.first;
            auto end         = start + (v2.second - v2.first);
            result[p1.first] = std::make_pair(start, end);
        }
        return result;
    }

206
    void apply(module& p, const match::matcher_result& mr) const
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
    {
        auto ins   = mr.result;
        auto slice = ins->inputs().front();
        auto input = slice->inputs().front();

        auto a1 = get_axes(ins);
        auto a2 = get_axes(slice);

        auto axes = merge(a2, a1);

        auto op = op::slice{};
        for(auto&& pp : axes)
        {
            op.axes.push_back(pp.first);
            op.starts.push_back(pp.second.first);
            op.ends.push_back(pp.second.second);
        }
        p.replace_instruction(ins, op, input);
    }
};

Paul's avatar
Paul committed
228
229
230
231
struct find_concat_transpose
{
    auto matcher() const
    {
232
        return match::name("concat")(match::all_of[match::inputs()](match::transpose_shape()));
Paul's avatar
Paul committed
233
234
    }

235
    void apply(module& p, const match::matcher_result& mr) const
Paul's avatar
Paul committed
236
    {
Shucai Xiao's avatar
Shucai Xiao committed
237
238
239
        auto ins          = mr.result;
        auto trans_inputs = ins->inputs();
        auto s            = trans_inputs.front()->get_shape();
Paul's avatar
Paul committed
240
        assert(s.transposed());
Shucai Xiao's avatar
Shucai Xiao committed
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
        auto op          = any_cast<op::concat>(ins->get_operator());
        auto permutation = find_permutation(s);

        // permutation should be the same for all inputs
        if(!std::all_of(trans_inputs.begin(), trans_inputs.end(), [&](auto in) {
               return (find_permutation(in->get_shape()) == permutation);
           }))
        {
            return;
        }

        // axis could be a negative value
        int64_t n_dim = static_cast<int64_t>(s.lens().size());
        op.axis       = (op.axis < 0) ? (op.axis + n_dim) : op.axis;

Paul's avatar
Paul committed
256
        auto ipermutation = invert_permutation(permutation);
Paul's avatar
Paul committed
257
        op.axis           = ipermutation[op.axis];
Paul's avatar
Paul committed
258
259
260

        std::vector<instruction_ref> inputs;
        std::transform(
Paul's avatar
Paul committed
261
            ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
262
                return p.insert_instruction(ins, make_op("transpose", {{"dims", permutation}}), i);
Paul's avatar
Paul committed
263
            });
Paul's avatar
Paul committed
264
        auto concat = p.insert_instruction(ins, op, inputs);
265
        auto t = p.insert_instruction(ins, make_op("transpose", {{"dims", ipermutation}}), concat);
Paul's avatar
Paul committed
266
        assert(ins->get_shape().lens() == t->get_shape().lens());
Paul's avatar
Paul committed
267
268
269
270
        p.replace_instruction(ins, t);
    }
};

Paul Fultz II's avatar
Paul Fultz II committed
271
272
273
274
275
276
277
278
279
280
281
282
283
struct find_nested_concat
{
    auto matcher() const
    {
        return match::name("concat")(match::any_of[match::inputs()](match::name("concat")));
    }

    static std::size_t get_axis(instruction_ref ins)
    {
        auto op = any_cast<op::concat>(ins->get_operator());
        return op.axis;
    }

284
    void apply(module& p, const match::matcher_result& mr) const
Paul Fultz II's avatar
Paul Fultz II committed
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
    {
        auto ins  = mr.result;
        auto axis = get_axis(ins);
        std::vector<instruction_ref> args;
        fix([&](auto self, auto&& inputs) {
            for(auto&& i : inputs)
            {
                if(i->name() == "concat" and get_axis(i) == axis and i->outputs().size() == 1)
                    self(i->inputs());
                else
                    args.push_back(i);
            }

        })(ins->inputs());
        p.replace_instruction(ins, ins->get_operator(), args);
    }
};

303
void simplify_reshapes::apply(module& p) const
{
    // Run every rewrite over the module for a fixed two rounds, presumably so rewrites from
    // the first round can enable further matches in the second. Dead instructions left
    // behind by each round are removed before the next one. Matcher order is kept as-is.
    const int rounds = 2;
    for(int round = 0; round < rounds; ++round)
    {
        match::find_matches(p,
                            find_nop_reshapes{},
                            find_reshaper{},
                            find_transpose{},
                            find_concat_transpose{},
                            find_nested_slice{},
                            find_nested_concat{});
        dead_code_elimination{}.apply(p);
    }
}

Paul's avatar
Paul committed
318
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
319
} // namespace migraphx