simplify_reshapes.cpp 6.81 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
4
#include <migraphx/op/as_shape.hpp>
5
#include <migraphx/op/transpose.hpp>
Paul's avatar
Paul committed
6
#include <migraphx/op/concat.hpp>
Paul's avatar
Paul committed
7
8
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
Paul's avatar
Paul committed
9
#include <migraphx/matcher.hpp>
Paul's avatar
Paul committed
10
11
#include <unordered_set>

Paul's avatar
Paul committed
12
namespace migraphx {
Paul's avatar
Paul committed
13
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
14

Paul's avatar
Paul committed
15
// Names of the operators this pass treats as "reshapers".
// The set is built once (function-local static) and shared by every caller.
const auto& reshaper_names()
{
    static const std::unordered_set<std::string> reshapers{
        "reshape", "contiguous", "squeeze", "unsqueeze"};
    return reshapers;
}

Paul's avatar
Paul committed
28
// True when `ins` is one of the operators listed by reshaper_names().
bool is_reshaper(instruction_ref ins) { return reshaper_names().count(ins->name()) != 0; }
Paul's avatar
Paul committed
29
30
31

// Looks upstream from `ins` along single-input edges, stepping over any number of
// "contiguous" instructions, for a "transpose".
// Returns that transpose if one is reached. Otherwise returns the last instruction
// visited: `ins` itself, or the last contiguous of the chain.
instruction_ref find_transpose_input(instruction_ref ins)
{
    auto current = ins;
    while(current->inputs().size() == 1)
    {
        auto input = current->inputs().front();
        if(input->name() == "transpose")
            return input;
        if(input->name() != "contiguous")
            break;
        current = input;
    }
    return current;
}

41
42
43
44
45
46
47
48
49
// Returns (by value) the permutation stored in the transpose operator of `ins`.
// `ins` must hold an op::transpose; otherwise any_cast fails.
auto get_transpose_dims(instruction_ref ins)
{
    auto dims = any_cast<const op::transpose&>(ins->get_operator()).dims;
    return dims;
}

// Gathers `dims` through `permutation`: result[i] = dims[permutation[i]].
// Used to compose transpose permutations; both inputs must have the same length,
// and every entry of `permutation` must be a valid index into `dims`.
// Parameters are taken by const reference: the original by-value signature copied
// both vectors on every call (e.g. on each step of the find_transpose loop).
std::vector<int64_t> reorder_dims(const std::vector<int64_t>& dims,
                                  const std::vector<int64_t>& permutation)
{
    assert(dims.size() == permutation.size());
    std::vector<int64_t> result;
    result.reserve(permutation.size());
    for(auto axis : permutation)
        result.push_back(dims[axis]);
    return result;
}

// True when `dims` is the identity permutation (dims[i] == i for every i),
// i.e. a transpose with these dims would not move any axis. Empty counts as identity.
bool is_no_transpose(const std::vector<int64_t>& dims)
{
    for(std::size_t axis = 0; axis < dims.size(); ++axis)
    {
        if(dims[axis] != static_cast<int64_t>(axis))
            return false;
    }
    return true;
}

Paul's avatar
Paul committed
67
// Argsort: returns the indices 0..n-1 of `data`, reordered so that
// data[result[0]], data[result[1]], ... are ordered according to `op`.
// Uses std::stable_sort so that tied elements keep their original index order.
// std::sort leaves the relative order of ties unspecified, so callers such as
// find_permutation (which sorts strides, where ties from length-1 axes are common)
// would otherwise get implementation-dependent permutations.
template <class Vector, class Op>
std::vector<int64_t> sort_permutation(const Vector& data, Op op)
{
    std::vector<std::int64_t> result(data.size());
    std::iota(result.begin(), result.end(), 0);
    std::stable_sort(
        result.begin(), result.end(), [&](auto x, auto y) { return op(data[x], data[y]); });
    return result;
}

Paul's avatar
Paul committed
76
77
78
79
80
// Inverse of a permutation: the argsort of a permutation is its inverse, so
// result[permutation[i]] == i for a valid permutation.
std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation)
{
    auto ascending = [](int64_t a, int64_t b) { return a < b; };
    return sort_permutation(permutation, ascending);
}

Paul's avatar
Paul committed
81
82
83
84
85
// Axis order of `s` from the largest stride to the smallest (argsort of the
// strides, descending).
std::vector<int64_t> find_permutation(const shape& s)
{
    const auto& strides = s.strides();
    return sort_permutation(strides, std::greater<>{});
}

Paul's avatar
Paul committed
86
struct find_reshaper
Paul's avatar
Paul committed
87
{
Paul's avatar
Paul committed
88
    auto matcher() const
Paul's avatar
Paul committed
89
    {
Paul's avatar
Paul committed
90
91
        return match::name(reshaper_names())(
            match::any_of[match::outputs()](match::name(reshaper_names())));
Paul's avatar
Paul committed
92
93
    }

Paul's avatar
Paul committed
94
    void apply(program& p, const match::matcher_result& mr) const
Paul's avatar
Paul committed
95
96
97
98
    {
        auto ins = mr.result;
        std::vector<instruction_ref> reshapes{ins};
        while(is_reshaper(reshapes.back()))
Paul's avatar
Paul committed
99
        {
Paul's avatar
Paul committed
100
101
102
103
104
            assert(!reshapes.back()->inputs().empty());
            assert(p.has_instruction(reshapes.back()->inputs().front()));
            auto input = reshapes.back()->inputs().front();
            reshapes.push_back(input);
        }
Paul's avatar
Paul committed
105

Paul's avatar
Paul committed
106
107
108
109
110
111
112
        std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
        for(auto start : iterator_for(reshapes))
        {
            auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
                return i->get_shape() == (*start)->get_shape() and i != (*start);
            });
            if(last != reshapes.rend())
Paul's avatar
Paul committed
113
            {
Paul's avatar
Paul committed
114
115
                r = std::make_pair(*start, *last);
                break;
Paul's avatar
Paul committed
116
117
            }
        }
Paul's avatar
Paul committed
118
        if(r.first != r.second)
Paul's avatar
Paul committed
119
        {
Paul's avatar
Paul committed
120
            p.replace_instruction(r.first, r.second);
Paul's avatar
Paul committed
121
        }
Paul's avatar
Paul committed
122
123
124
    }
};

Paul's avatar
Paul committed
125
126
127
128
129
130
131
struct find_nop_reshapes
{
    auto matcher() const
    {
        auto reshapes = reshaper_names();
        reshapes.insert("transpose");
        reshapes.insert("slice");
Paul's avatar
Paul committed
132
        return match::name(reshapes)(match::same_shape(match::arg(0)));
Paul's avatar
Paul committed
133
134
    }

Paul's avatar
Paul committed
135
    void apply(program& p, const match::matcher_result& mr) const
Paul's avatar
Paul committed
136
137
138
139
140
141
    {
        auto ins = mr.result;
        p.replace_instruction(ins, ins->inputs().front());
    }
};

Paul's avatar
Paul committed
142
143
144
145
struct find_transpose
{
    auto matcher() const
    {
Paul's avatar
Paul committed
146
147
        return match::name("transpose")(match::none_of(
            match::skip_output(match::name("contiguous"))(match::name("transpose"))));
Paul's avatar
Paul committed
148
149
    }

Paul's avatar
Paul committed
150
    void apply(program& p, const match::matcher_result& mr) const
Paul's avatar
Paul committed
151
152
    {
        auto ins = mr.result;
Paul's avatar
Paul committed
153
154
        auto x   = ins;
        auto t   = ins;
Paul's avatar
Paul committed
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
        std::vector<std::int64_t> dims(ins->get_shape().lens().size());
        std::iota(dims.begin(), dims.end(), 0);
        do
        {
            dims = reorder_dims(get_transpose_dims(t), dims);
            x    = t;
            t    = find_transpose_input(x);
        } while(x != t and t->name() == "transpose");
        if(t == ins or t->name() != "transpose")
            return;
        if(is_no_transpose(dims))
        {
            p.replace_instruction(ins, t->inputs().front());
        }
        else
Paul's avatar
Paul committed
170
        {
Paul's avatar
Paul committed
171
            p.replace_instruction(ins, op::transpose{{dims}}, t->inputs().front());
Paul's avatar
Paul committed
172
        }
Paul's avatar
Paul committed
173
    }
Paul's avatar
Paul committed
174
175
176
177
178
179
};

struct find_concat_transpose
{
    auto matcher() const
    {
Paul's avatar
Paul committed
180
        return match::name("concat")(match::same_input_shapes(),
Paul's avatar
Paul committed
181
                                     match::all_of[match::inputs()](match::transpose_shape()));
Paul's avatar
Paul committed
182
183
    }

Paul's avatar
Paul committed
184
    void apply(program& p, const match::matcher_result& mr) const
Paul's avatar
Paul committed
185
186
    {
        auto ins = mr.result;
Paul's avatar
Paul committed
187
        auto s   = ins->inputs().front()->get_shape();
Paul's avatar
Paul committed
188
        assert(s.transposed());
Paul's avatar
Paul committed
189
190
        auto op           = any_cast<op::concat>(ins->get_operator());
        auto permutation  = find_permutation(s);
Paul's avatar
Paul committed
191
        auto ipermutation = invert_permutation(permutation);
Paul's avatar
Paul committed
192
        op.axis           = ipermutation[op.axis];
Paul's avatar
Paul committed
193
194
195

        std::vector<instruction_ref> inputs;
        std::transform(
Paul's avatar
Paul committed
196
            ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
Paul's avatar
Paul committed
197
                if(i->name() == "transpose" and i->inputs().front()->get_shape().standard())
Paul's avatar
Paul committed
198
                    return i->inputs().front();
Paul's avatar
Paul committed
199
200
                return p.insert_instruction(ins, op::transpose{permutation}, i);
            });
Paul's avatar
Paul committed
201
        auto concat = p.insert_instruction(ins, op, inputs);
Paul's avatar
Paul committed
202
203
        auto t      = p.insert_instruction(ins, op::transpose{ipermutation}, concat);
        assert(ins->get_shape().lens() == t->get_shape().lens());
Paul's avatar
Paul committed
204
205
206
207
208
209
210
211
212
213
214
215
216
217
        p.replace_instruction(ins, t);
    }
};

// Runs the reshape/transpose simplification matchers over every instruction of `p`.
void simplify_reshapes::apply(program& p) const
{
    // An empty program has nothing to simplify. Returning here also avoids
    // std::prev(p.end()), which is undefined on an empty range.
    if(p.begin() == p.end())
        return;
    auto end = std::prev(p.end());
    for(auto ins : iterator_for(p))
    {
        // A trailing contiguous is left untouched; presumably it fixes the layout of
        // the program's final result. TODO confirm.
        if(ins == end and ins->name() == "contiguous")
            continue;
        // Skip possible dead instructions
        if(ins->outputs().empty() and ins != end)
            continue;
        match::find_matches(p,
                            ins,
                            find_nop_reshapes{},
                            find_reshaper{},
                            find_transpose{},
                            find_concat_transpose{});
    }
}

Paul's avatar
Paul committed
227
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
228
} // namespace migraphx