// constant_propagate.cpp
#include <migraphx/constant_propagate.hpp>
#include <migraphx/program.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/functional.hpp>

#include <memory>
#include <unordered_set>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct match_const_add
{
    auto matcher() const
    {
        return match::name("add")(match::args(match::name("@literal"), match::name("@literal")));
    }

Paul's avatar
Paul committed
17
    void apply(program& p, const match::matcher_result& r) const
Paul's avatar
Paul committed
18
    {
Paul's avatar
Paul committed
19
        auto ins  = r.result;
Paul's avatar
Paul committed
20
21
22
23
24
25
26
27
        auto arg1 = ins->inputs().at(0)->get_literal();
        auto arg2 = ins->inputs().at(1)->get_literal();

        auto sum = p.add_literal(transform(arg1, arg2, [](auto x, auto y) { return x + y; }));
        p.replace_instruction(ins, sum);
    }
};

/// Folds compile-time-evaluable instructions into literals.
///
/// Walks the instruction graph backwards, starting from the program's last
/// instruction. An instruction is folded when all of these hold:
///   - it is not itself a `@literal`;
///   - its shape is not broadcasted;
///   - `eval()` returns a non-empty result.
/// Folding replaces the instruction with a new literal holding that result, and
/// the walk does not descend into its inputs. Otherwise the walk continues into
/// each of its inputs.
///
/// @param p  program to rewrite in place
void constant_propagate::apply(program& p) const
{
    // std::prev(p.end()) is undefined on an empty program.
    if(p.begin() == p.end())
        return;

    // Instructions already processed, keyed by address (list nodes do not move).
    // An instruction can feed several consumers. Without this set:
    //  - a shared input is re-walked once per path reaching it, which is
    //    exponential on chains of diamonds;
    //  - an input already folded while visiting a sibling subtree (but still
    //    held in a parent's copied input list) would be evaluated and folded
    //    again, adding a redundant unused literal.
    std::unordered_set<const void*> visited;

    // `fix` passes the lambda itself as `self` so it can recurse.
    fix([&](auto self, auto ins) {
        if(not visited.insert(std::addressof(*ins)).second)
            return;

        // NOTE(review): broadcasted results are presumably skipped to avoid
        // materializing a broadcast as a full-size literal -- confirm.
        if(not ins->get_shape().broadcasted() and ins->name() != "@literal")
        {
            auto r = ins->eval();
            // An empty result means the instruction cannot be computed at
            // compile time, so keep walking into its inputs instead.
            if(not r.empty())
            {
                auto l = p.add_literal(r.get_shape(), r.data());
                p.replace_instruction(ins, l);
                return;
            }
        }

        // Iterate over a copy: folding an input rewrites ins->inputs().
        auto children = ins->inputs();
        for(auto child : children)
            self(child);
    })(std::prev(p.end()));
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx