simplify_reshapes.cpp 6.99 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
4
#include <migraphx/op/as_shape.hpp>
5
#include <migraphx/op/transpose.hpp>
Paul's avatar
Paul committed
6
#include <migraphx/op/concat.hpp>
Paul's avatar
Paul committed
7
8
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
Paul's avatar
Paul committed
9
#include <migraphx/matcher.hpp>
10
#include <migraphx/permutation.hpp>
Paul's avatar
Paul committed
11
12
#include <unordered_set>

Paul's avatar
Paul committed
13
namespace migraphx {
Paul's avatar
Paul committed
14
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
15

Paul's avatar
Paul committed
16
const auto& reshaper_names()
{
    // Names of the operators this pass treats as "reshapers" (see is_reshaper).
    // The set is built once on first use; every call returns the same object.
    static const std::unordered_set<std::string> names{
        "flatten", "reshape", "contiguous", "squeeze", "unsqueeze"};
    return names;
}

Paul's avatar
Paul committed
30
// True when `ins` is one of the operators listed by reshaper_names().
bool is_reshaper(instruction_ref ins)
{
    const auto& names = reshaper_names();
    return names.find(ins->name()) != names.end();
}
Paul's avatar
Paul committed
31
32
33

// Look below `ins` for a transpose, stepping through any single-input chain of
// contiguous ops. Returns the transpose if one is reached; otherwise returns the
// last instruction visited (which is `ins` itself when `ins` does not have exactly
// one input, or whose single input is neither contiguous nor transpose).
instruction_ref find_transpose_input(instruction_ref ins)
{
    auto current = ins;
    while(current->inputs().size() == 1)
    {
        auto producer = current->inputs().front();
        if(producer->name() == "transpose")
            return producer;
        if(producer->name() != "contiguous")
            break;
        // Look through the contiguous and keep searching below it.
        current = producer;
    }
    return current;
}

43
44
45
46
47
48
49
// Return a copy of the permutation stored in the transpose operator of `ins`.
// `ins` must hold an op::transpose (any_cast is used on its operator).
auto get_transpose_dims(instruction_ref ins)
{
    auto permutation = any_cast<const op::transpose&>(ins->get_operator()).dims;
    return permutation;
}

// True when `dims` is the identity permutation (0, 1, 2, ...), i.e. a transpose
// with these dims would leave axis order unchanged. An empty vector counts as identity.
bool is_no_transpose(const std::vector<int64_t>& dims)
{
    for(std::size_t axis = 0; axis < dims.size(); ++axis)
    {
        if(dims[axis] != static_cast<int64_t>(axis))
            return false;
    }
    return true;
}

Paul's avatar
Paul committed
58
struct find_reshaper
Paul's avatar
Paul committed
59
{
Paul's avatar
Paul committed
60
    auto matcher() const
Paul's avatar
Paul committed
61
    {
Paul's avatar
Paul committed
62
63
        return match::name(reshaper_names())(
            match::any_of[match::outputs()](match::name(reshaper_names())));
Paul's avatar
Paul committed
64
65
    }

Paul's avatar
Paul committed
66
    void apply(program& p, const match::matcher_result& mr) const
Paul's avatar
Paul committed
67
68
69
70
    {
        auto ins = mr.result;
        std::vector<instruction_ref> reshapes{ins};
        while(is_reshaper(reshapes.back()))
Paul's avatar
Paul committed
71
        {
Paul's avatar
Paul committed
72
73
74
75
76
            assert(!reshapes.back()->inputs().empty());
            assert(p.has_instruction(reshapes.back()->inputs().front()));
            auto input = reshapes.back()->inputs().front();
            reshapes.push_back(input);
        }
Paul's avatar
Paul committed
77

Paul's avatar
Paul committed
78
79
80
81
82
83
84
        std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
        for(auto start : iterator_for(reshapes))
        {
            auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
                return i->get_shape() == (*start)->get_shape() and i != (*start);
            });
            if(last != reshapes.rend())
Paul's avatar
Paul committed
85
            {
Paul's avatar
Paul committed
86
87
                r = std::make_pair(*start, *last);
                break;
Paul's avatar
Paul committed
88
89
            }
        }
Paul's avatar
Paul committed
90
        if(r.first != r.second)
Paul's avatar
Paul committed
91
        {
Paul's avatar
Paul committed
92
            p.replace_instruction(r.first, r.second);
Paul's avatar
Paul committed
93
        }
Paul's avatar
Paul committed
94
95
96
    }
};

Paul's avatar
Paul committed
97
98
99
100
101
struct find_nop_reshapes
{
    auto matcher() const
    {
        auto reshapes = reshaper_names();
102
103
104
105
106
        reshapes.insert("as_shape");
        reshapes.insert("broadcast");
        reshapes.insert("concat");
        reshapes.insert("multibroadcast");
        reshapes.insert("pad");
Paul's avatar
Paul committed
107
        reshapes.insert("slice");
108
        reshapes.insert("transpose");
Paul's avatar
Paul committed
109
        return match::name(reshapes)(match::same_shape(match::arg(0)));
Paul's avatar
Paul committed
110
111
    }

Paul's avatar
Paul committed
112
    void apply(program& p, const match::matcher_result& mr) const
Paul's avatar
Paul committed
113
114
115
116
117
118
    {
        auto ins = mr.result;
        p.replace_instruction(ins, ins->inputs().front());
    }
};

Paul's avatar
Paul committed
119
120
121
122
struct find_transpose
{
    auto matcher() const
    {
Paul's avatar
Paul committed
123
124
        return match::name("transpose")(match::none_of(
            match::skip_output(match::name("contiguous"))(match::name("transpose"))));
Paul's avatar
Paul committed
125
126
    }

Paul's avatar
Paul committed
127
    void apply(program& p, const match::matcher_result& mr) const
Paul's avatar
Paul committed
128
129
    {
        auto ins = mr.result;
Paul's avatar
Paul committed
130
131
        auto x   = ins;
        auto t   = ins;
Paul's avatar
Paul committed
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
        std::vector<std::int64_t> dims(ins->get_shape().lens().size());
        std::iota(dims.begin(), dims.end(), 0);
        do
        {
            dims = reorder_dims(get_transpose_dims(t), dims);
            x    = t;
            t    = find_transpose_input(x);
        } while(x != t and t->name() == "transpose");
        if(t == ins or t->name() != "transpose")
            return;
        if(is_no_transpose(dims))
        {
            p.replace_instruction(ins, t->inputs().front());
        }
        else
Paul's avatar
Paul committed
147
        {
Paul's avatar
Paul committed
148
            p.replace_instruction(ins, op::transpose{{dims}}, t->inputs().front());
Paul's avatar
Paul committed
149
        }
Paul's avatar
Paul committed
150
    }
Paul's avatar
Paul committed
151
152
153
154
155
156
};

struct find_concat_transpose
{
    auto matcher() const
    {
157
        return match::name("concat")(match::all_of[match::inputs()](match::transpose_shape()));
Paul's avatar
Paul committed
158
159
    }

Paul's avatar
Paul committed
160
    void apply(program& p, const match::matcher_result& mr) const
Paul's avatar
Paul committed
161
162
    {
        auto ins = mr.result;
Paul's avatar
Paul committed
163
        auto s   = ins->inputs().front()->get_shape();
Paul's avatar
Paul committed
164
        assert(s.transposed());
Paul's avatar
Paul committed
165
166
        auto op           = any_cast<op::concat>(ins->get_operator());
        auto permutation  = find_permutation(s);
Paul's avatar
Paul committed
167
        auto ipermutation = invert_permutation(permutation);
Paul's avatar
Paul committed
168
        op.axis           = ipermutation[op.axis];
Paul's avatar
Paul committed
169
170
171

        std::vector<instruction_ref> inputs;
        std::transform(
Paul's avatar
Paul committed
172
173
174
            ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
                return p.insert_instruction(ins, op::transpose{permutation}, i);
            });
Paul's avatar
Paul committed
175
        auto concat = p.insert_instruction(ins, op, inputs);
Paul's avatar
Paul committed
176
177
        auto t      = p.insert_instruction(ins, op::transpose{ipermutation}, concat);
        assert(ins->get_shape().lens() == t->get_shape().lens());
Paul's avatar
Paul committed
178
179
180
181
        p.replace_instruction(ins, t);
    }
};

Paul Fultz II's avatar
Paul Fultz II committed
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
struct find_nested_concat
{
    auto matcher() const
    {
        return match::name("concat")(match::any_of[match::inputs()](match::name("concat")));
    }

    static std::size_t get_axis(instruction_ref ins)
    {
        auto op = any_cast<op::concat>(ins->get_operator());
        return op.axis;
    }

    void apply(program& p, const match::matcher_result& mr) const
    {
        auto ins  = mr.result;
        auto axis = get_axis(ins);
        std::vector<instruction_ref> args;
        fix([&](auto self, auto&& inputs) {
            for(auto&& i : inputs)
            {
                if(i->name() == "concat" and get_axis(i) == axis and i->outputs().size() == 1)
                    self(i->inputs());
                else
                    args.push_back(i);
            }

        })(ins->inputs());
        p.replace_instruction(ins, ins->get_operator(), args);
    }
};

Paul's avatar
Paul committed
214
215
// Run the reshape/transpose/concat simplification matchers over every instruction
// of `p`, in two passes (presumably so patterns exposed by rewrites in the first
// pass are picked up by the second).
void simplify_reshapes::apply(program& p) const
{
    // An empty program has nothing to simplify, and std::prev(p.end()) below
    // would be undefined behavior on it.
    if(p.begin() == p.end())
        return;
    for(int pass = 0; pass < 2; pass++)
    {
        // Final instruction of the program, recomputed at the start of each pass.
        auto last = std::prev(p.end());
        for(auto ins : iterator_for(p))
        {
            // Leave a trailing contiguous alone (presumably it fixes the layout of
            // the program's result).
            if(ins == last and ins->name() == "contiguous")
                continue;
            // Skip possible dead instructions
            if(ins->outputs().empty() and ins != last)
                continue;
            match::find_matches(p,
                                ins,
                                find_nop_reshapes{},
                                find_reshaper{},
                                find_transpose{},
                                find_concat_transpose{},
                                find_nested_concat{});
        }
    }
}

Paul's avatar
Paul committed
237
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
238
} // namespace migraphx