simplify_algebra.cpp 3.25 KB
Newer Older
Paul's avatar
Paul committed
1
2
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/program.hpp>
3
#include <migraphx/op/add.hpp>
Paul's avatar
Paul committed
4
5
#include <migraphx/op/mul.hpp>
#include <migraphx/op/broadcast.hpp>
Paul's avatar
Paul committed
6
7
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
Paul's avatar
Paul committed
8

Paul's avatar
Paul committed
9
namespace migraphx {
Paul's avatar
Paul committed
10
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
11

Paul's avatar
Paul committed
12
// Matcher accepting an instruction that is either constant (match::is_constant)
// or a `broadcast` instruction.
auto lit_broadcast()
{
    auto constant_operand  = match::is_constant();
    auto broadcast_operand = match::name("broadcast");
    return match::any_of(constant_operand, broadcast_operand);
}
Paul's avatar
Paul committed
13
// Matcher accepting an instruction that is neither constant (match::is_constant)
// nor a `broadcast` instruction; the complement of lit_broadcast().
auto not_lit_broadcast()
{
    auto constant_operand  = match::is_constant();
    auto broadcast_operand = match::name("broadcast");
    return match::none_of(constant_operand, broadcast_operand);
}
Paul's avatar
Paul committed
14
15
16
auto op_lit_broadcast(std::string op, std::string x, std::string y)
{
    return match::name(op)(match::either_arg(0, 1)(lit_broadcast().bind(std::move(x)),
Paul's avatar
Paul committed
17
                                                   not_lit_broadcast().bind(std::move(y))));
Paul's avatar
Paul committed
18
19
20
21
22
}

struct find_mul_conv
{
    auto matcher() const
Paul's avatar
Paul committed
23
    {
Paul's avatar
Paul committed
24
        return match::name("mul")(match::either_arg(0, 1)(
Paul's avatar
Paul committed
25
            match::name("convolution")(match::used_once(),
Paul's avatar
Paul committed
26
                                       match::args(match::any(), match::is_constant().bind("w")))
Paul's avatar
Paul committed
27
28
                .bind("conv"),
            match::name("broadcast").bind("a")));
Paul's avatar
Paul committed
29
    }
Paul's avatar
Paul committed
30
31

    void apply(program& p, match::matcher_result r) const
Paul's avatar
Paul committed
32
    {
Paul's avatar
Paul committed
33
        auto ins      = r.result;
Paul's avatar
Paul committed
34
        auto conv_ins = r.instructions["conv"];
Paul's avatar
Paul committed
35
36
37
        auto a_ins    = r.instructions["a"];
        auto w_ins    = r.instructions["w"];

Paul's avatar
Paul committed
38
        auto broadcast_op = any_cast<op::broadcast>(a_ins->get_operator());
Paul's avatar
Paul committed
39
        if(broadcast_op.axis != 1)
Paul's avatar
Paul committed
40
41
            return;

Paul's avatar
Paul committed
42
43
44
45
46
        auto new_a = p.insert_instruction(
            ins, op::broadcast{0, w_ins->get_shape().lens()}, a_ins->inputs().front());
        auto new_mul  = p.insert_instruction(ins, op::mul{}, new_a, w_ins);
        auto new_conv = p.insert_instruction(
            ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_mul);
Paul's avatar
Paul committed
47
        p.replace_instruction(ins, new_conv);
Paul's avatar
Paul committed
48
    }
Paul's avatar
Paul committed
49
50
51
52
};

struct find_add_lit_broadcast
{
Paul's avatar
Paul committed
53
54
    auto matcher() const
    {
Paul's avatar
Paul committed
55
        return match::name("add")(
Paul's avatar
Paul committed
56
            match::args(op_lit_broadcast("add", "a", "x"), op_lit_broadcast("add", "b", "y")));
Paul's avatar
Paul committed
57
58
59
60
    }

    void apply(program& p, match::matcher_result r) const
    {
Paul's avatar
Paul committed
61
62
63
64
65
        auto ins   = r.result;
        auto x_ins = r.instructions["x"];
        auto y_ins = r.instructions["y"];
        auto a_ins = r.instructions["a"];
        auto b_ins = r.instructions["b"];
Paul's avatar
Paul committed
66
67
68
69
70
71
72
73
74

        if(a_ins->name() != b_ins->name())
            return;
        instruction_ref sumab;

        if(a_ins->name() == "broadcast")
        {
            if(a_ins->inputs().at(0)->get_shape() != b_ins->inputs().at(0)->get_shape())
                return;
Paul's avatar
Paul committed
75
76
77
78
            auto op = a_ins->get_operator();
            auto presum =
                p.insert_instruction(ins, op::add{}, a_ins->inputs().at(0), b_ins->inputs().at(0));
            sumab = p.insert_instruction(ins, op, presum);
Paul's avatar
Paul committed
79
80
81
82
83
84
85
86
87
88
89
        }
        else
        {
            sumab = p.insert_instruction(ins, op::add{}, a_ins, b_ins);
        }

        auto sumxy = p.insert_instruction(ins, op::add{}, x_ins, y_ins);
        p.replace_instruction(ins, op::add{}, sumxy, sumab);
    }
};

Paul's avatar
Paul committed
90
91
// Runs this pass's rewrites (find_add_lit_broadcast, find_mul_conv) over `p`.
void simplify_algebra::apply(program& p) const
{
    // Run simplifications twice
    constexpr int num_passes = 2;
    for(int pass = 0; pass != num_passes; ++pass)
    {
        match::find_matches(p, find_add_lit_broadcast{}, find_mul_conv{});
    }
}
Paul's avatar
Paul committed
96

Paul's avatar
Paul committed
97
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
98
} // namespace migraphx