"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "4e960940e3925556d033965d6287efef5139109f"
simplify_reshapes.cpp 2.44 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
Paul's avatar
Paul committed
7
8
#include <algorithm>
#include <cassert>
#include <iterator>
#include <unordered_set>
#include <vector>

Paul's avatar
Paul committed
9
namespace migraphx {
Paul's avatar
Paul committed
10
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
11

Paul's avatar
Paul committed
12
13
14
15
16
// Reshapers that can't handle nonstandard input shapes.
//
// Returns true when `ins` is a "reshape" or "contiguous" instruction whose
// first input is itself a "contiguous" instruction. is_reshaper() rejects any
// instruction for which this returns true, so the chain walk in
// simplify_reshapes::apply() stops at such an instruction instead of
// collapsing through it.
bool is_nonstandard_reshaper(instruction_ref ins)
{
    // clang-format off
    static const std::unordered_set<std::string> names = {
        "reshape",
        "contiguous"
    };
    // clang-format on
    // Check the name first: only listed instructions have their input inspected.
    if(not contains(names, ins->name()))
        return false;
    const auto& producer = ins->inputs().front();
    return producer->name() == "contiguous";
}

// Returns true when `ins` is one of the reshape-like operators listed below
// and is not excluded by is_nonstandard_reshaper(). Only instructions for
// which this holds are walked when simplify_reshapes::apply() gathers a chain
// of reshapes.
bool is_reshaper(instruction_ref ins)
{
    // clang-format off
    static const std::unordered_set<std::string> names = {
        "reshape",
        "transpose",
        // "broadcast",
        "contiguous"
    };
    // clang-format on
    // Name test comes first so is_nonstandard_reshaper() (which reads the
    // instruction's first input) is only consulted for listed operators.
    const bool listed = contains(names, ins->name());
    if(not listed)
        return false;
    return not is_nonstandard_reshaper(ins);
}

// Simplifies chains of reshape-like instructions in `p` in two passes:
//
//  1. For every reshaper that ends a chain (it has exactly one user, and that
//     user is not itself a reshaper), walk back through first inputs while
//     they are reshapers, collecting the chain plus the first non-reshaper
//     input. If an instruction on that chain has the same shape as a deeper
//     one, the outermost such instruction is replaced with the deepest
//     instruction of that shape, bypassing the reshapes in between.
//  2. Every remaining "reshape" is replaced by an op::as_shape with the same
//     output shape and the same inputs.
void simplify_reshapes::apply(program& p) const
{
    for(auto ins : iterator_for(p))
    {
        if(not is_reshaper(ins))
            continue;
        if(ins->outputs().size() != 1)
            continue;
        // Only start from the end of a chain; inner links are reached from there.
        if(is_reshaper(ins->outputs().front()))
            continue;

        // Gather the chain: `ins`, each reshaper feeding it through its first
        // input, and finally the first non-reshaper input.
        std::vector<instruction_ref> reshapes{ins};
        while(is_reshaper(reshapes.back()))
        {
            assert(!reshapes.back()->inputs().empty());
            assert(p.has_instruction(reshapes.back()->inputs().front()));
            auto input = reshapes.back()->inputs().front();
            reshapes.push_back(input);
        }

        // Find the outermost entry whose shape reappears deeper in the chain.
        // Only entries after `start` need to be searched: a match with an
        // earlier entry would already have been found on that entry's own
        // iteration. The chain is a walk through a DAG, so it holds no
        // duplicates and every match is a distinct instruction. Searching
        // from the back selects the deepest match, so the longest run of
        // redundant reshapes is skipped.
        for(auto start = reshapes.begin(); start != reshapes.end(); ++start)
        {
            const auto stop = std::make_reverse_iterator(std::next(start));
            auto last = std::find_if(reshapes.rbegin(), stop, [&](auto&& i) {
                return i->get_shape() == (*start)->get_shape();
            });
            if(last != stop)
            {
                p.replace_instruction(*start, *last);
                break;
            }
        }
    }

    // Replace all reshapes with as_shape
    for(auto ins : iterator_for(p))
    {
        if(ins->name() != "reshape")
            continue;
        p.replace_instruction(ins, op::as_shape{ins->get_shape()}, ins->inputs());
    }
}

Paul's avatar
Paul committed
83
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
84
} // namespace migraphx