// simplify_reshapes.cpp
#include <migraph/simplify_reshapes.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/ranges.hpp>
#include <unordered_set>

namespace migraph {

// Returns true when `name` is one of the operator names this pass treats as a
// "reshaper": "reshape", "transpose" or "contiguous".
//
// "broadcast" is listed in the table but commented out, so it is not matched.
// The comparison is an exact, case-sensitive string match.
bool is_reshaper(const std::string& name)
{
    // clang-format off
    static const std::unordered_set<std::string> names = {
        "reshape",
        "transpose",
        // "broadcast",
        "contiguous"
    };
    // clang-format on
    return names.find(name) != names.end();
}

// Shortens chains of reshaper instructions (see is_reshaper) in program `p`.
//
// For every reshaper instruction that has exactly one output, and whose output
// is not itself a reshaper, the chain of first arguments is followed backwards
// until a non-reshaper instruction is reached. If two distinct instructions in
// that chain have equal `result` values (presumably the output shape -- confirm
// against the instruction definition), the later one is replaced by the earlier
// one through program::replace_instruction.
void simplify_reshapes::apply(program& p) const
{
    for(auto ins : iterator_for(p))
    {
        // Start only from the tail of a reshaper chain: a reshaper with a
        // single consumer that is not a reshaper itself.
        if(not is_reshaper(ins->op.name()))
            continue;
        if(ins->output.size() != 1)
            continue;
        if(is_reshaper(ins->output.front()->op.name()))
            continue;
        // Gather reshapes: reshapes[0] is `ins`, each following entry is the
        // first argument of the previous one. The loop stops after pushing the
        // first non-reshaper, so that instruction is the final element.
        std::vector<instruction_ref> reshapes{ins};
        while(is_reshaper(reshapes.back()->op.name()))
        {
            assert(!reshapes.back()->arguments.empty());
            assert(p.has_instruction(reshapes.back()->arguments.front()));
            reshapes.push_back(reshapes.back()->arguments.front());
        }

        // (end, end) is the "nothing to replace" sentinel.
        std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
        // `start` walks the chain from `ins` towards the source.
        for(auto start : iterator_for(reshapes))
        {
            // Search from the source end of the chain so that the match, if
            // any, is the entry furthest from `*start` with an equal `result`.
            auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
                return i->result == (*start)->result and i != (*start);
            });
            if(last != reshapes.rend())
            {
                r = std::make_pair(*start, *last);
                break;
            }
        }
        // Only true when a pair was found above; the sentinel compares equal.
        if(r.first != r.second)
        {
            p.replace_instruction(r.first, r.second);
        }
    }
}

} // namespace migraph