propagate_constant.cpp 1.56 KB
Newer Older
Paul's avatar
Paul committed
1
#include <migraphx/propagate_constant.hpp>
Paul's avatar
Paul committed
2
3
4
#include <migraphx/program.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
Paul's avatar
Paul committed
5
#include <migraphx/functional.hpp>
6
#include <unordered_set>
Paul's avatar
Paul committed
7

Paul's avatar
Paul committed
8
namespace migraphx {
Paul's avatar
Paul committed
9
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
10

Paul's avatar
Paul committed
11
bool skip_propogate(instruction_ref ins)
Paul's avatar
Paul committed
12
{
Paul's avatar
Paul committed
13
    if(ins->name() == "contiguous")
Paul's avatar
Paul committed
14
        return skip_propogate(ins->inputs().front());
Paul's avatar
Paul committed
15
16
    auto&& s = ins->get_shape();
    if(s.broadcasted() and not s.scalar())
Paul's avatar
Paul committed
17
        return true;
Paul's avatar
Paul committed
18
    if(s.scalar() and s.elements() != 1)
Paul's avatar
Paul committed
19
20
21
        return true;
    return false;
}
Paul's avatar
Paul committed
22

23
void propagate_constant::apply(module& m) const
Paul's avatar
Paul committed
24
{
25
    for(auto i : iterator_for(m))
26
    {
Paul's avatar
Paul committed
27
        if(i->name() != "@literal")
28
            continue;
Paul's avatar
Paul committed
29
        if(i->outputs().empty())
30
31
            continue;
        fix([&](auto self, auto ins) {
Paul's avatar
Paul committed
32
33
            std::unordered_set<instruction_ref> children(ins->outputs().begin(),
                                                         ins->outputs().end());
34
            for(auto child : children)
Paul's avatar
Paul committed
35
            {
Paul's avatar
Paul committed
36
                if(child->name() == "@literal" or skip_propogate(child))
37
                {
Paul's avatar
Paul committed
38
39
                    self(child);
                    continue;
Paul's avatar
Paul committed
40
                }
Paul's avatar
Paul committed
41
42
                auto r = child->eval();
                if(not r.empty())
43
                {
Paul's avatar
Paul committed
44
                    assert(r.get_shape() == child->get_shape());
45
46
                    auto l = m.add_literal(r.get_shape(), r.data());
                    self(m.replace_instruction(child, l));
47
                }
Paul's avatar
Paul committed
48
            }
49
50
        })(i);
    }
Paul's avatar
Paul committed
51
}
Paul's avatar
Paul committed
52

Paul's avatar
Paul committed
53
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
54
} // namespace migraphx