"tools/git@developer.sourcefind.cn:Wenxuan/LightX2V.git" did not exist on "63f233ad163cec493886bd306dc831637cf92c60"
simplify_reshapes.cpp 3.26 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <string>
#include <unordered_set>
#include <vector>

Paul's avatar
Paul committed
9
namespace migraphx {
Paul's avatar
Paul committed
10
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
11

Paul's avatar
Paul committed
12
// Returns true when `ins` is one of the operators this pass treats as part of
// a reshape chain: "reshape" or "contiguous".
bool is_reshaper(instruction_ref ins)
{
    const std::string op_name = ins->name();
    return op_name == "reshape" or op_name == "contiguous";
}

// Returns true when `ins` flows into a "transpose" along a path in which every
// instruction (starting with `ins`) has exactly one consumer, and every
// intermediate consumer is a "contiguous". Returns false as soon as an
// instruction on the path has zero or several consumers.
bool is_transpose_output(instruction_ref ins)
{
    auto current = ins;
    while(current->outputs().size() == 1)
    {
        auto consumer = current->outputs().front();
        if(consumer->name() != "contiguous")
            return consumer->name() == "transpose";
        // Look through the contiguous and keep following the chain.
        current = consumer;
    }
    return false;
}

// Walks from `ins` toward its inputs, looking through single-input
// "contiguous" instructions, and returns the first "transpose" producer found.
// If the walk stops first, it returns the instruction it stopped on. The walk
// stops at an instruction with other than exactly one input, or at one whose
// producer is neither "contiguous" nor "transpose". That instruction may be
// `ins` itself or an intermediate "contiguous".
instruction_ref find_transpose_input(instruction_ref ins)
{
    auto current = ins;
    while(current->inputs().size() == 1)
    {
        auto producer = current->inputs().front();
        if(producer->name() == "transpose")
            return producer;
        if(producer->name() != "contiguous")
            return current;
        current = producer;
    }
    return current;
}

// Simplifies chains of reshaping operators in `p`:
//  * A chain of reshape/contiguous instructions is short-circuited when an
//    instruction in it has the same shape as one further up the chain.
//  * A chain of transposes (optionally separated by contiguous) is removed
//    when the composition of their permutations is the identity.
// Finally every remaining "reshape" is lowered to an `as_shape` operator.
void simplify_reshapes::apply(program& p) const
{
    auto end = std::prev(p.end());
    for(auto ins : iterator_for(p))
    {
        // Skip dead instructions, but always consider the program's last one.
        if(ins->outputs().empty() and ins != end)
            continue;
        if(is_reshaper(ins))
        {
            // Only process the last reshaper of a chain; the whole chain is
            // gathered from there.
            if(std::any_of(ins->outputs().begin(), ins->outputs().end(), &is_reshaper))
                continue;
            // Gather the reshaper chain ending at `ins`, followed by the first
            // non-reshaper input.
            std::vector<instruction_ref> reshapes{ins};
            while(is_reshaper(reshapes.back()))
            {
                assert(!reshapes.back()->inputs().empty());
                assert(p.has_instruction(reshapes.back()->inputs().front()));
                auto input = reshapes.back()->inputs().front();
                reshapes.push_back(input);
            }

            // Find the element closest to `ins` whose shape also appears
            // further up the chain. Searching from the back picks the
            // furthest match, so everything between them is redundant.
            std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
            for(auto start : iterator_for(reshapes))
            {
                auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
                    return i->get_shape() == (*start)->get_shape() and i != (*start);
                });
                if(last != reshapes.rend())
                {
                    r = std::make_pair(*start, *last);
                    break;
                }
            }
            if(r.first != r.second)
            {
                p.replace_instruction(r.first, r.second);
            }
        }
        else if(ins->name() == "transpose")
        {
            // Only process the last transpose of a chain.
            if(is_transpose_output(ins))
                continue;
            // Walk up the transpose chain (through contiguous) and compose the
            // permutations. After the walk, axis i of `ins` is axis perm[i] of
            // the chain's root input.
            const auto rank = ins->get_shape().lens().size();
            std::vector<std::int64_t> identity(rank);
            std::iota(identity.begin(), identity.end(), 0);
            std::vector<std::int64_t> perm = identity;
            auto x = ins;
            auto t = ins;
            do
            {
                // Each iteration starts with `t` being a transpose.
                const auto& dims = any_cast<op::transpose>(t->get_operator()).dims;
                assert(dims.size() == perm.size());
                for(auto& axis : perm)
                    axis = dims[static_cast<std::size_t>(axis)];
                x = t;
                t = find_transpose_input(x);
            } while(x != t and t->name() == "transpose");
            if(t == ins or t->name() != "transpose")
                continue;
            // The transposes only cancel out when their composition is the
            // identity. Otherwise `ins` has a different axis order than the root
            // input and must not be replaced by it.
            if(perm != identity)
                continue;
            p.replace_instruction(ins, t->inputs().front());
        }
    }
    // Replace all reshapes with as_shape
    for(auto ins : iterator_for(p))
    {
        if(ins->name() != "reshape")
            continue;
        p.replace_instruction(ins, op::as_shape{ins->get_shape()}, ins->inputs());
    }
}

Paul's avatar
Paul committed
106
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
107
} // namespace migraphx