simplify_algebra.cpp 1.99 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
#include <migraph/simplify_algebra.hpp>
#include <migraph/program.hpp>
#include <migraph/operators.hpp>
#include <migraph/matcher.hpp>
#include <migraph/literal.hpp>

namespace migraph {

struct find_add_lit_broadcast
{
    auto lit_broadcast() const
    {
        return match::any_of(match::name("@literal"), match::name("broadcast"));
    }
    auto not_lit_broadcast() const
    {
        return match::none_of(match::name("@literal"), match::name("broadcast"));
    }
    auto add_lit_broadcast(std::string x, std::string y) const
    {
Paul's avatar
Paul committed
21
22
        return match::name("add")(match::either_arg(0, 1)(lit_broadcast().bind(std::move(x)),
                                                          not_lit_broadcast().bind(std::move(y))));
Paul's avatar
Paul committed
23
24
25
    }
    auto matcher() const
    {
Paul's avatar
Paul committed
26
27
        return match::name("add")(
            match::args(add_lit_broadcast("a", "x"), add_lit_broadcast("b", "y")));
Paul's avatar
Paul committed
28
29
30
31
    }

    void apply(program& p, match::matcher_result r) const
    {
Paul's avatar
Paul committed
32
33
34
35
36
        auto ins   = r.result;
        auto x_ins = r.instructions["x"];
        auto y_ins = r.instructions["y"];
        auto a_ins = r.instructions["a"];
        auto b_ins = r.instructions["b"];
Paul's avatar
Paul committed
37
38
39
40
41
42
43
44
45

        if(a_ins->name() != b_ins->name())
            return;
        instruction_ref sumab;

        if(a_ins->name() == "broadcast")
        {
            if(a_ins->inputs().at(0)->get_shape() != b_ins->inputs().at(0)->get_shape())
                return;
Paul's avatar
Paul committed
46
47
48
49
            auto op = a_ins->get_operator();
            auto presum =
                p.insert_instruction(ins, op::add{}, a_ins->inputs().at(0), b_ins->inputs().at(0));
            sumab = p.insert_instruction(ins, op, presum);
Paul's avatar
Paul committed
50
51
52
53
54
55
56
57
58
59
60
61
62
63
        }
        else
        {
            sumab = p.insert_instruction(ins, op::add{}, a_ins, b_ins);
        }

        auto sumxy = p.insert_instruction(ins, op::add{}, x_ins, y_ins);
        p.replace_instruction(ins, op::add{}, sumxy, sumab);
    }
};

/// Entry point of the simplify_algebra pass: runs the registered algebraic
/// rewrite (currently only find_add_lit_broadcast) over every match in `p`.
void simplify_algebra::apply(program& p) const
{
    match::find_matches(p, find_add_lit_broadcast{});
}

} // namespace migraph