simplify_reshapes.cpp 9.35 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
4
#include <migraphx/op/as_shape.hpp>
5
#include <migraphx/op/transpose.hpp>
Paul's avatar
Paul committed
6
#include <migraphx/op/concat.hpp>
7
#include <migraphx/op/slice.hpp>
Paul's avatar
Paul committed
8
9
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
Paul's avatar
Paul committed
10
#include <migraphx/matcher.hpp>
11
#include <migraphx/permutation.hpp>
12
#include <migraphx/dead_code_elimination.hpp>
Paul's avatar
Paul committed
13
#include <unordered_set>
14
#include <migraphx/make_op.hpp>
15
#include <migraphx/tune_axis.hpp>
16

17
#include <map>
Paul's avatar
Paul committed
18

Paul's avatar
Paul committed
19
namespace migraphx {
Paul's avatar
Paul committed
20
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
21

Paul's avatar
Paul committed
22
// Names of the operators this pass treats as pure reshapes: chains of them are
// collapsed by find_reshaper and they are recognized by is_reshaper().
// The set is built once on first use and shared by every caller.
const auto& reshaper_names()
{
    static const std::unordered_set<std::string> reshapers{
        "contiguous", "flatten", "reshape", "squeeze", "unsqueeze"};
    return reshapers;
}

Paul's avatar
Paul committed
36
// True when the operator of `ins` is one of the names in reshaper_names().
bool is_reshaper(instruction_ref ins)
{
    const auto& names = reshaper_names();
    return names.find(ins->name()) != names.end();
}
Paul's avatar
Paul committed
37
38
39

// Starting at `ins`, follow single-input edges downward, skipping over
// "contiguous" instructions, and return the first "transpose" reached.
// If no transpose is reached, return the last instruction visited: `ins` itself,
// or the deepest "contiguous" that was stepped over.
instruction_ref find_transpose_input(instruction_ref ins)
{
    auto current = ins;
    while(current->inputs().size() == 1)
    {
        auto arg = current->inputs().front();
        if(arg->name() == "transpose")
            return arg;
        if(arg->name() != "contiguous")
            break;
        current = arg;
    }
    return current;
}

49
50
51
52
53
54
55
// Return a copy of the permutation stored in the transpose operator of `ins`.
// Callers must only pass transpose instructions (the any_cast targets op::transpose).
auto get_transpose_dims(instruction_ref ins)
{
    auto permutation = any_cast<const op::transpose&>(ins->get_operator()).dims;
    return permutation;
}

// True when `dims` is the identity permutation 0, 1, ..., n-1 (the empty
// permutation included), i.e. a transpose with these dims is a no-op.
bool is_no_transpose(const std::vector<int64_t>& dims)
{
    for(std::size_t pos = 0; pos < dims.size(); ++pos)
    {
        if(dims[pos] != static_cast<int64_t>(pos))
            return false;
    }
    return true;
}

Paul's avatar
Paul committed
64
struct find_reshaper
Paul's avatar
Paul committed
65
{
Paul's avatar
Paul committed
66
    auto matcher() const
Paul's avatar
Paul committed
67
    {
Paul's avatar
Paul committed
68
69
        return match::name(reshaper_names())(
            match::any_of[match::outputs()](match::name(reshaper_names())));
Paul's avatar
Paul committed
70
71
    }

72
    void apply(module& p, const match::matcher_result& mr) const
Paul's avatar
Paul committed
73
74
75
76
    {
        auto ins = mr.result;
        std::vector<instruction_ref> reshapes{ins};
        while(is_reshaper(reshapes.back()))
Paul's avatar
Paul committed
77
        {
Paul's avatar
Paul committed
78
79
80
81
82
            assert(!reshapes.back()->inputs().empty());
            assert(p.has_instruction(reshapes.back()->inputs().front()));
            auto input = reshapes.back()->inputs().front();
            reshapes.push_back(input);
        }
Paul's avatar
Paul committed
83

Paul's avatar
Paul committed
84
85
86
87
88
89
90
        std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
        for(auto start : iterator_for(reshapes))
        {
            auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
                return i->get_shape() == (*start)->get_shape() and i != (*start);
            });
            if(last != reshapes.rend())
Paul's avatar
Paul committed
91
            {
Paul's avatar
Paul committed
92
93
                r = std::make_pair(*start, *last);
                break;
Paul's avatar
Paul committed
94
95
            }
        }
Paul's avatar
Paul committed
96
        if(r.first != r.second)
Paul's avatar
Paul committed
97
        {
Paul's avatar
Paul committed
98
            p.replace_instruction(r.first, r.second);
Paul's avatar
Paul committed
99
        }
Paul's avatar
Paul committed
100
101
102
    }
};

Paul's avatar
Paul committed
103
104
105
106
107
struct find_nop_reshapes
{
    auto matcher() const
    {
        auto reshapes = reshaper_names();
108
109
110
        reshapes.insert("as_shape");
        reshapes.insert("broadcast");
        reshapes.insert("concat");
Paul Fultz II's avatar
Paul Fultz II committed
111
        reshapes.insert("convert");
112
113
        reshapes.insert("multibroadcast");
        reshapes.insert("pad");
Paul's avatar
Paul committed
114
        reshapes.insert("slice");
115
        reshapes.insert("transpose");
Paul's avatar
Paul committed
116
        return match::name(reshapes)(match::same_shape(match::arg(0)));
Paul's avatar
Paul committed
117
118
    }

119
    void apply(module& p, const match::matcher_result& mr) const
Paul's avatar
Paul committed
120
121
122
123
124
125
    {
        auto ins = mr.result;
        p.replace_instruction(ins, ins->inputs().front());
    }
};

Paul's avatar
Paul committed
126
127
128
129
struct find_transpose
{
    auto matcher() const
    {
Paul's avatar
Paul committed
130
131
        return match::name("transpose")(match::none_of(
            match::skip_output(match::name("contiguous"))(match::name("transpose"))));
Paul's avatar
Paul committed
132
133
    }

134
    void apply(module& p, const match::matcher_result& mr) const
Paul's avatar
Paul committed
135
136
    {
        auto ins = mr.result;
Paul's avatar
Paul committed
137
138
        auto x   = ins;
        auto t   = ins;
Paul's avatar
Paul committed
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
        std::vector<std::int64_t> dims(ins->get_shape().lens().size());
        std::iota(dims.begin(), dims.end(), 0);
        do
        {
            dims = reorder_dims(get_transpose_dims(t), dims);
            x    = t;
            t    = find_transpose_input(x);
        } while(x != t and t->name() == "transpose");
        if(t == ins or t->name() != "transpose")
            return;
        if(is_no_transpose(dims))
        {
            p.replace_instruction(ins, t->inputs().front());
        }
        else
Paul's avatar
Paul committed
154
        {
155
            p.replace_instruction(ins, make_op("transpose", {{"dims", dims}}), t->inputs().front());
Paul's avatar
Paul committed
156
        }
Paul's avatar
Paul committed
157
    }
Paul's avatar
Paul committed
158
159
};

160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
struct find_nested_slice
{
    auto matcher() const { return match::name("slice")(match::arg(0)(match::name("slice"))); }

    using axes_map = std::map<std::size_t, std::pair<std::size_t, std::size_t>>;

    static axes_map get_axes(instruction_ref ins)
    {
        axes_map result;
        auto op = any_cast<op::slice>(ins->get_operator());
        for(std::size_t i = 0; i < op.axes.size(); i++)
        {
            result[op.axes[i]] = std::make_pair(op.starts[i], op.ends[i]);
        }
        return result;
    }

    static axes_map merge(const axes_map& m1, const axes_map& m2)
    {
        axes_map result;
        // Non overlapping
        for(auto&& p : m1)
        {
            if(contains(m2, p.first))
                continue;
            result[p.first] = p.second;
        }
        for(auto&& p : m2)
        {
            if(contains(m1, p.first))
                continue;
            result[p.first] = p.second;
        }
        // Overlapping
        for(auto&& p1 : m1)
        {
            if(not contains(m2, p1.first))
                continue;
            auto&& v1        = p1.second;
            auto&& v2        = m2.at(p1.first);
            auto start       = v1.first + v2.first;
            auto end         = start + (v2.second - v2.first);
            result[p1.first] = std::make_pair(start, end);
        }
        return result;
    }

207
    void apply(module& p, const match::matcher_result& mr) const
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
    {
        auto ins   = mr.result;
        auto slice = ins->inputs().front();
        auto input = slice->inputs().front();

        auto a1 = get_axes(ins);
        auto a2 = get_axes(slice);

        auto axes = merge(a2, a1);

        auto op = op::slice{};
        for(auto&& pp : axes)
        {
            op.axes.push_back(pp.first);
            op.starts.push_back(pp.second.first);
            op.ends.push_back(pp.second.second);
        }
        p.replace_instruction(ins, op, input);
    }
};

Paul's avatar
Paul committed
229
230
231
232
struct find_concat_transpose
{
    auto matcher() const
    {
233
        return match::name("concat")(match::all_of[match::inputs()](match::transpose_shape()));
Paul's avatar
Paul committed
234
235
    }

236
    void apply(module& p, const match::matcher_result& mr) const
Paul's avatar
Paul committed
237
    {
Shucai Xiao's avatar
Shucai Xiao committed
238
239
240
        auto ins          = mr.result;
        auto trans_inputs = ins->inputs();
        auto s            = trans_inputs.front()->get_shape();
Paul's avatar
Paul committed
241
        assert(s.transposed());
Shucai Xiao's avatar
Shucai Xiao committed
242
243
244
245
246
247
248
249
250
251
252
253
254
        auto op          = any_cast<op::concat>(ins->get_operator());
        auto permutation = find_permutation(s);

        // permutation should be the same for all inputs
        if(!std::all_of(trans_inputs.begin(), trans_inputs.end(), [&](auto in) {
               return (find_permutation(in->get_shape()) == permutation);
           }))
        {
            return;
        }

        // axis could be a negative value
        int64_t n_dim = static_cast<int64_t>(s.lens().size());
255
        op.axis       = tune_axis(n_dim, op.axis, op.name());
Shucai Xiao's avatar
Shucai Xiao committed
256

Paul's avatar
Paul committed
257
        auto ipermutation = invert_permutation(permutation);
Paul's avatar
Paul committed
258
        op.axis           = ipermutation[op.axis];
Paul's avatar
Paul committed
259
260
261

        std::vector<instruction_ref> inputs;
        std::transform(
Paul's avatar
Paul committed
262
            ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
263
                return p.insert_instruction(ins, make_op("transpose", {{"dims", permutation}}), i);
Paul's avatar
Paul committed
264
            });
Paul's avatar
Paul committed
265
        auto concat = p.insert_instruction(ins, op, inputs);
266
        auto t = p.insert_instruction(ins, make_op("transpose", {{"dims", ipermutation}}), concat);
Paul's avatar
Paul committed
267
        assert(ins->get_shape().lens() == t->get_shape().lens());
Paul's avatar
Paul committed
268
269
270
271
        p.replace_instruction(ins, t);
    }
};

Paul Fultz II's avatar
Paul Fultz II committed
272
273
274
275
276
277
278
279
280
281
282
283
284
struct find_nested_concat
{
    auto matcher() const
    {
        return match::name("concat")(match::any_of[match::inputs()](match::name("concat")));
    }

    static std::size_t get_axis(instruction_ref ins)
    {
        auto op = any_cast<op::concat>(ins->get_operator());
        return op.axis;
    }

285
    void apply(module& p, const match::matcher_result& mr) const
Paul Fultz II's avatar
Paul Fultz II committed
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
    {
        auto ins  = mr.result;
        auto axis = get_axis(ins);
        std::vector<instruction_ref> args;
        fix([&](auto self, auto&& inputs) {
            for(auto&& i : inputs)
            {
                if(i->name() == "concat" and get_axis(i) == axis and i->outputs().size() == 1)
                    self(i->inputs());
                else
                    args.push_back(i);
            }

        })(ins->inputs());
        p.replace_instruction(ins, ins->get_operator(), args);
    }
};

304
// Run every reshape-simplification matcher over the module, then remove dead
// instructions. This is repeated for a fixed number of passes; presumably a rewrite
// in one pass can expose new matches for the next.
void simplify_reshapes::apply(module& p) const
{
    constexpr int pass_count = 2;
    for(int pass_index = 0; pass_index != pass_count; ++pass_index)
    {
        match::find_matches(p,
                            find_nop_reshapes{},
                            find_reshaper{},
                            find_transpose{},
                            find_concat_transpose{},
                            find_nested_slice{},
                            find_nested_concat{});
        dead_code_elimination{}.apply(p);
    }
}

Paul's avatar
Paul committed
319
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
320
} // namespace migraphx