"vscode:/vscode.git/clone" did not exist on "e6812b97164aaedd8a33c051201b12efc1b4d0e3"
simplify_reshapes.cpp 9.21 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
4
#include <migraphx/op/as_shape.hpp>
5
#include <migraphx/op/transpose.hpp>
Paul's avatar
Paul committed
6
#include <migraphx/op/concat.hpp>
7
#include <migraphx/op/slice.hpp>
Paul's avatar
Paul committed
8
9
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
Paul's avatar
Paul committed
10
#include <migraphx/matcher.hpp>
11
#include <migraphx/permutation.hpp>
12
#include <migraphx/dead_code_elimination.hpp>
Paul's avatar
Paul committed
13
#include <unordered_set>
14
#include <map>
Paul's avatar
Paul committed
15

Paul's avatar
Paul committed
16
namespace migraphx {
Paul's avatar
Paul committed
17
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
18

Paul's avatar
Paul committed
19
// Names of the operators this pass treats as "reshapers".
//
// Returns a reference to a function-local static set that is built once, on
// the first call, and holds exactly: "flatten", "reshape", "contiguous",
// "squeeze" and "unsqueeze". Callers that need a modifiable set must copy it.
const auto& reshaper_names()
{
    // clang-format off
    static const std::unordered_set<std::string> names = {
        "flatten",
        "reshape",
        "contiguous",
        "squeeze",
        "unsqueeze"
    };
    // clang-format on
    return names;
}

Paul's avatar
Paul committed
33
// True when the operator name of `ins` is one of reshaper_names().
bool is_reshaper(instruction_ref ins)
{
    const auto& known_names = reshaper_names();
    return contains(known_names, ins->name());
}
Paul's avatar
Paul committed
34
35
36

// Looks for a "transpose" feeding `ins`, stepping through any number of
// intermediate "contiguous" instructions.
//
// Returns the "transpose" instruction when one is found. Otherwise returns the
// instruction at which the walk stopped: `ins` itself, or the last
// "contiguous" that was stepped onto. The walk stops as soon as the current
// instruction does not have exactly one input.
instruction_ref find_transpose_input(instruction_ref ins)
{
    auto current = ins;
    for(;;)
    {
        if(current->inputs().size() != 1)
            return current;
        auto input = current->inputs().front();
        if(input->name() == "contiguous")
        {
            // Keep walking from the contiguous; it becomes the result if
            // nothing further upstream qualifies.
            current = input;
            continue;
        }
        if(input->name() == "transpose")
            return input;
        return current;
    }
}

46
47
48
49
50
51
52
// Returns a copy of the `dims` member of the op::transpose held by `ins`.
// The operator of `ins` is cast to op::transpose, so `ins` is expected to be
// a transpose instruction.
auto get_transpose_dims(instruction_ref ins)
{
    const auto& transpose_op = any_cast<const op::transpose&>(ins->get_operator());
    return transpose_op.dims;
}

// True when `dims` is the identity permutation, i.e. dims[i] == i for every
// index. An empty vector counts as the identity.
//
// Each element is compared against its own index rather than subtracting
// neighbouring elements, so no signed arithmetic on the values is needed.
bool is_no_transpose(const std::vector<int64_t>& dims)
{
    for(std::size_t i = 0; i < dims.size(); i++)
    {
        if(dims[i] != static_cast<int64_t>(i))
            return false;
    }
    return true;
}

Paul's avatar
Paul committed
61
struct find_reshaper
Paul's avatar
Paul committed
62
{
Paul's avatar
Paul committed
63
    auto matcher() const
Paul's avatar
Paul committed
64
    {
Paul's avatar
Paul committed
65
66
        return match::name(reshaper_names())(
            match::any_of[match::outputs()](match::name(reshaper_names())));
Paul's avatar
Paul committed
67
68
    }

Paul's avatar
Paul committed
69
    void apply(program& p, const match::matcher_result& mr) const
Paul's avatar
Paul committed
70
71
72
73
    {
        auto ins = mr.result;
        std::vector<instruction_ref> reshapes{ins};
        while(is_reshaper(reshapes.back()))
Paul's avatar
Paul committed
74
        {
Paul's avatar
Paul committed
75
76
77
78
79
            assert(!reshapes.back()->inputs().empty());
            assert(p.has_instruction(reshapes.back()->inputs().front()));
            auto input = reshapes.back()->inputs().front();
            reshapes.push_back(input);
        }
Paul's avatar
Paul committed
80

Paul's avatar
Paul committed
81
82
83
84
85
86
87
        std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
        for(auto start : iterator_for(reshapes))
        {
            auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
                return i->get_shape() == (*start)->get_shape() and i != (*start);
            });
            if(last != reshapes.rend())
Paul's avatar
Paul committed
88
            {
Paul's avatar
Paul committed
89
90
                r = std::make_pair(*start, *last);
                break;
Paul's avatar
Paul committed
91
92
            }
        }
Paul's avatar
Paul committed
93
        if(r.first != r.second)
Paul's avatar
Paul committed
94
        {
Paul's avatar
Paul committed
95
            p.replace_instruction(r.first, r.second);
Paul's avatar
Paul committed
96
        }
Paul's avatar
Paul committed
97
98
99
    }
};

Paul's avatar
Paul committed
100
101
102
103
104
struct find_nop_reshapes
{
    auto matcher() const
    {
        auto reshapes = reshaper_names();
105
106
107
108
109
        reshapes.insert("as_shape");
        reshapes.insert("broadcast");
        reshapes.insert("concat");
        reshapes.insert("multibroadcast");
        reshapes.insert("pad");
Paul's avatar
Paul committed
110
        reshapes.insert("slice");
111
        reshapes.insert("transpose");
Paul's avatar
Paul committed
112
        return match::name(reshapes)(match::same_shape(match::arg(0)));
Paul's avatar
Paul committed
113
114
    }

Paul's avatar
Paul committed
115
    void apply(program& p, const match::matcher_result& mr) const
Paul's avatar
Paul committed
116
117
118
119
120
121
    {
        auto ins = mr.result;
        p.replace_instruction(ins, ins->inputs().front());
    }
};

Paul's avatar
Paul committed
122
123
124
125
struct find_transpose
{
    auto matcher() const
    {
Paul's avatar
Paul committed
126
127
        return match::name("transpose")(match::none_of(
            match::skip_output(match::name("contiguous"))(match::name("transpose"))));
Paul's avatar
Paul committed
128
129
    }

Paul's avatar
Paul committed
130
    void apply(program& p, const match::matcher_result& mr) const
Paul's avatar
Paul committed
131
132
    {
        auto ins = mr.result;
Paul's avatar
Paul committed
133
134
        auto x   = ins;
        auto t   = ins;
Paul's avatar
Paul committed
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
        std::vector<std::int64_t> dims(ins->get_shape().lens().size());
        std::iota(dims.begin(), dims.end(), 0);
        do
        {
            dims = reorder_dims(get_transpose_dims(t), dims);
            x    = t;
            t    = find_transpose_input(x);
        } while(x != t and t->name() == "transpose");
        if(t == ins or t->name() != "transpose")
            return;
        if(is_no_transpose(dims))
        {
            p.replace_instruction(ins, t->inputs().front());
        }
        else
Paul's avatar
Paul committed
150
        {
Paul's avatar
Paul committed
151
            p.replace_instruction(ins, op::transpose{{dims}}, t->inputs().front());
Paul's avatar
Paul committed
152
        }
Paul's avatar
Paul committed
153
    }
Paul's avatar
Paul committed
154
155
};

156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
struct find_nested_slice
{
    auto matcher() const { return match::name("slice")(match::arg(0)(match::name("slice"))); }

    using axes_map = std::map<std::size_t, std::pair<std::size_t, std::size_t>>;

    static axes_map get_axes(instruction_ref ins)
    {
        axes_map result;
        auto op = any_cast<op::slice>(ins->get_operator());
        for(std::size_t i = 0; i < op.axes.size(); i++)
        {
            result[op.axes[i]] = std::make_pair(op.starts[i], op.ends[i]);
        }
        return result;
    }

    static axes_map merge(const axes_map& m1, const axes_map& m2)
    {
        axes_map result;
        // Non overlapping
        for(auto&& p : m1)
        {
            if(contains(m2, p.first))
                continue;
            result[p.first] = p.second;
        }
        for(auto&& p : m2)
        {
            if(contains(m1, p.first))
                continue;
            result[p.first] = p.second;
        }
        // Overlapping
        for(auto&& p1 : m1)
        {
            if(not contains(m2, p1.first))
                continue;
            auto&& v1        = p1.second;
            auto&& v2        = m2.at(p1.first);
            auto start       = v1.first + v2.first;
            auto end         = start + (v2.second - v2.first);
            result[p1.first] = std::make_pair(start, end);
        }
        return result;
    }

    void apply(program& p, const match::matcher_result& mr) const
    {
        auto ins   = mr.result;
        auto slice = ins->inputs().front();
        auto input = slice->inputs().front();

        auto a1 = get_axes(ins);
        auto a2 = get_axes(slice);

        auto axes = merge(a2, a1);

        auto op = op::slice{};
        for(auto&& pp : axes)
        {
            op.axes.push_back(pp.first);
            op.starts.push_back(pp.second.first);
            op.ends.push_back(pp.second.second);
        }
        p.replace_instruction(ins, op, input);
    }
};

Paul's avatar
Paul committed
225
226
227
228
struct find_concat_transpose
{
    auto matcher() const
    {
229
        return match::name("concat")(match::all_of[match::inputs()](match::transpose_shape()));
Paul's avatar
Paul committed
230
231
    }

Paul's avatar
Paul committed
232
    void apply(program& p, const match::matcher_result& mr) const
Paul's avatar
Paul committed
233
    {
Shucai Xiao's avatar
Shucai Xiao committed
234
235
236
        auto ins          = mr.result;
        auto trans_inputs = ins->inputs();
        auto s            = trans_inputs.front()->get_shape();
Paul's avatar
Paul committed
237
        assert(s.transposed());
Shucai Xiao's avatar
Shucai Xiao committed
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
        auto op          = any_cast<op::concat>(ins->get_operator());
        auto permutation = find_permutation(s);

        // permutation should be the same for all inputs
        if(!std::all_of(trans_inputs.begin(), trans_inputs.end(), [&](auto in) {
               return (find_permutation(in->get_shape()) == permutation);
           }))
        {
            return;
        }

        // axis could be a negative value
        int64_t n_dim = static_cast<int64_t>(s.lens().size());
        op.axis       = (op.axis < 0) ? (op.axis + n_dim) : op.axis;

Paul's avatar
Paul committed
253
        auto ipermutation = invert_permutation(permutation);
Paul's avatar
Paul committed
254
        op.axis           = ipermutation[op.axis];
Paul's avatar
Paul committed
255
256
257

        std::vector<instruction_ref> inputs;
        std::transform(
Paul's avatar
Paul committed
258
259
260
            ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
                return p.insert_instruction(ins, op::transpose{permutation}, i);
            });
Paul's avatar
Paul committed
261
        auto concat = p.insert_instruction(ins, op, inputs);
Paul's avatar
Paul committed
262
263
        auto t      = p.insert_instruction(ins, op::transpose{ipermutation}, concat);
        assert(ins->get_shape().lens() == t->get_shape().lens());
Paul's avatar
Paul committed
264
265
266
267
        p.replace_instruction(ins, t);
    }
};

Paul Fultz II's avatar
Paul Fultz II committed
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
struct find_nested_concat
{
    auto matcher() const
    {
        return match::name("concat")(match::any_of[match::inputs()](match::name("concat")));
    }

    static std::size_t get_axis(instruction_ref ins)
    {
        auto op = any_cast<op::concat>(ins->get_operator());
        return op.axis;
    }

    void apply(program& p, const match::matcher_result& mr) const
    {
        auto ins  = mr.result;
        auto axis = get_axis(ins);
        std::vector<instruction_ref> args;
        fix([&](auto self, auto&& inputs) {
            for(auto&& i : inputs)
            {
                if(i->name() == "concat" and get_axis(i) == axis and i->outputs().size() == 1)
                    self(i->inputs());
                else
                    args.push_back(i);
            }

        })(ins->inputs());
        p.replace_instruction(ins, ins->get_operator(), args);
    }
};

Paul's avatar
Paul committed
300
301
// Runs the simplification matchers defined in this file over `p`, followed by
// dead code elimination, and repeats that sequence a second time.
void simplify_reshapes::apply(program& p) const
{
    for(int i = 0; i < 2; i++)
    {
        match::find_matches(p,
                            find_nop_reshapes{},
                            find_reshaper{},
                            find_transpose{},
                            find_concat_transpose{},
                            find_nested_slice{},
                            find_nested_concat{});
        // Drop the instructions the rewrites above left unused.
        dead_code_elimination{}.apply(p);
    }
}

Paul's avatar
Paul committed
315
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
316
} // namespace migraphx