"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "2b88574d53f22477954ecbfcd5f0b770a5b5291c"
simplify_algebra.cpp 5.9 KB
Newer Older
Paul's avatar
Paul committed
1
#include <migraphx/simplify_algebra.hpp>
Paul's avatar
Paul committed
2
#include <migraphx/dead_code_elimination.hpp>
Paul's avatar
Paul committed
3
#include <migraphx/program.hpp>
4
#include <migraphx/op/add.hpp>
Paul's avatar
Paul committed
5
6
#include <migraphx/op/mul.hpp>
#include <migraphx/op/broadcast.hpp>
Paul's avatar
Paul committed
7
8
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
Paul's avatar
Paul committed
9

Paul's avatar
Paul committed
10
namespace migraphx {
Paul's avatar
Paul committed
11
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
12

Paul's avatar
Paul committed
13
// Matcher accepting an instruction that is a constant or a "broadcast".
auto lit_broadcast()
{
    auto constant_or_broadcast = match::any_of(match::is_constant(), match::name("broadcast"));
    return constant_or_broadcast;
}
Paul's avatar
Paul committed
14
// Matcher accepting an instruction that is neither a constant nor a "broadcast"
// (the complement of lit_broadcast()).
auto not_lit_broadcast()
{
    auto neither = match::none_of(match::is_constant(), match::name("broadcast"));
    return neither;
}
Paul's avatar
Paul committed
15
16
auto op_lit_broadcast(std::string op, std::string x, std::string y)
{
Paul's avatar
Paul committed
17
18
    return match::name(std::move(op))(match::either_arg(0, 1)(
        lit_broadcast().bind(std::move(x)), not_lit_broadcast().bind(std::move(y))));
Paul's avatar
Paul committed
19
20
}

Paul's avatar
Paul committed
21
22
auto conv_const_weights()
{
Paul's avatar
Paul committed
23
    return match::name("convolution")(match::used_once(),
Paul's avatar
Paul committed
24
                                      match::args(match::any(), match::is_constant().bind("w")));
Paul's avatar
Paul committed
25
26
}

Paul's avatar
Paul committed
27
28
29
struct find_mul_conv
{
    auto matcher() const
Paul's avatar
Paul committed
30
    {
Paul's avatar
Paul committed
31
32
        return match::name("mul")(match::either_arg(0, 1)(conv_const_weights().bind("conv"),
                                                          match::name("broadcast").bind("a")));
Paul's avatar
Paul committed
33
    }
Paul's avatar
Paul committed
34
35

    void apply(program& p, match::matcher_result r) const
Paul's avatar
Paul committed
36
    {
Paul's avatar
Paul committed
37
        auto ins      = r.result;
Paul's avatar
Paul committed
38
        auto conv_ins = r.instructions["conv"];
Paul's avatar
Paul committed
39
40
41
        auto a_ins    = r.instructions["a"];
        auto w_ins    = r.instructions["w"];

Paul's avatar
Paul committed
42
        auto broadcast_op = any_cast<op::broadcast>(a_ins->get_operator());
Paul's avatar
Paul committed
43
        if(broadcast_op.axis != 1)
Paul's avatar
Paul committed
44
45
            return;

Paul's avatar
Paul committed
46
47
48
49
50
        auto new_a = p.insert_instruction(
            ins, op::broadcast{0, w_ins->get_shape().lens()}, a_ins->inputs().front());
        auto new_mul  = p.insert_instruction(ins, op::mul{}, new_a, w_ins);
        auto new_conv = p.insert_instruction(
            ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_mul);
Paul's avatar
Paul committed
51
        p.replace_instruction(ins, new_conv);
Paul's avatar
Paul committed
52
    }
Paul's avatar
Paul committed
53
54
};

Paul's avatar
Paul committed
55
// a * (x + b) => a * x + a * b
Paul's avatar
Paul committed
56
57
58
59
60
struct find_mul_add
{
    auto matcher() const
    {
        return match::name("mul")(match::either_arg(0, 1)(
Paul's avatar
Paul committed
61
62
63
            match::name("add")(
                match::either_arg(0, 1)(
                    match::any().bind("x"),
Paul's avatar
Paul committed
64
                    match::any_of(conv_const_weights(), match::is_constant()).bind("b")),
Paul's avatar
Paul committed
65
                match::none_of(match::args(match::is_constant(), match::is_constant())),
Paul's avatar
Paul committed
66
                match::used_once()),
Paul's avatar
Paul committed
67
            match::is_constant().bind("a")));
Paul's avatar
Paul committed
68
69
70
71
    }

    void apply(program& p, match::matcher_result r) const
    {
Paul's avatar
Paul committed
72
        auto ins   = r.result;
Paul's avatar
Paul committed
73
        auto a_ins = r.instructions["a"];
Paul's avatar
Paul committed
74
        auto b_ins = r.instructions["b"];
Paul's avatar
Paul committed
75
        auto x_ins = r.instructions["x"];
Paul's avatar
Paul committed
76
        assert(x_ins != b_ins);
Paul's avatar
Paul committed
77

Paul's avatar
Paul committed
78
79
80
        auto ax_ins = p.insert_instruction(ins, op::mul{}, a_ins, x_ins);
        auto ab_ins = p.insert_instruction(ins, op::mul{}, a_ins, b_ins);
        p.replace_instruction(ins, op::add{}, ax_ins, ab_ins);
Paul's avatar
Paul committed
81
82
83
    }
};

Paul's avatar
Paul committed
84
struct find_add_lit_broadcast
Paul's avatar
Paul committed
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
{
    auto matcher() const
    {
        return match::name("add")(
            match::either_arg(0, 1)(op_lit_broadcast("add", "a", "x"), lit_broadcast().bind("b")));
    }

    void apply(program& p, match::matcher_result r) const
    {
        auto ins   = r.result;
        auto x_ins = r.instructions["x"];
        auto a_ins = r.instructions["a"];
        auto b_ins = r.instructions["b"];

        auto sumab = p.insert_instruction(ins, op::add{}, a_ins, b_ins);
        p.replace_instruction(ins, op::add{}, x_ins, sumab);
    }
};

struct find_double_add_lit_broadcast
Paul's avatar
Paul committed
105
{
Paul's avatar
Paul committed
106
107
    auto matcher() const
    {
Paul's avatar
Paul committed
108
        return match::name("add")(
Paul's avatar
Paul committed
109
            match::args(op_lit_broadcast("add", "a", "x"), op_lit_broadcast("add", "b", "y")));
Paul's avatar
Paul committed
110
111
112
113
    }

    void apply(program& p, match::matcher_result r) const
    {
Paul's avatar
Paul committed
114
115
116
117
118
        auto ins   = r.result;
        auto x_ins = r.instructions["x"];
        auto y_ins = r.instructions["y"];
        auto a_ins = r.instructions["a"];
        auto b_ins = r.instructions["b"];
Paul's avatar
Paul committed
119
120
121

        instruction_ref sumab;

Paul's avatar
Paul committed
122
        if(a_ins->name() == "broadcast" and b_ins->name() == "broadcast")
Paul's avatar
Paul committed
123
124
125
        {
            if(a_ins->inputs().at(0)->get_shape() != b_ins->inputs().at(0)->get_shape())
                return;
Paul's avatar
Paul committed
126
127
128
129
            auto op = a_ins->get_operator();
            auto presum =
                p.insert_instruction(ins, op::add{}, a_ins->inputs().at(0), b_ins->inputs().at(0));
            sumab = p.insert_instruction(ins, op, presum);
Paul's avatar
Paul committed
130
131
132
133
134
135
136
137
138
139
140
        }
        else
        {
            sumab = p.insert_instruction(ins, op::add{}, a_ins, b_ins);
        }

        auto sumxy = p.insert_instruction(ins, op::add{}, x_ins, y_ins);
        p.replace_instruction(ins, op::add{}, sumxy, sumab);
    }
};

Paul's avatar
Paul committed
141
142
143
144
struct find_inner_broadcast
{
    auto matcher() const
    {
Paul's avatar
Paul committed
145
146
        return match::name("mul", "add")(
            match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y")));
Paul's avatar
Paul committed
147
148
149
150
151
152
153
154
155
156
157
    }

    void apply(program& p, match::matcher_result r) const
    {
        auto ins   = r.result;
        auto x_ins = r.instructions["x"];
        auto y_ins = r.instructions["y"];

        auto xbroadcast = any_cast<op::broadcast>(x_ins->get_operator());
        auto ybroadcast = any_cast<op::broadcast>(y_ins->get_operator());

Paul's avatar
Paul committed
158
        if(xbroadcast.axis != ybroadcast.axis)
Paul's avatar
Paul committed
159
160
            return;

Paul's avatar
Paul committed
161
162
        auto op = p.insert_instruction(
            ins, ins->get_operator(), x_ins->inputs().front(), y_ins->inputs().front());
Paul's avatar
Paul committed
163
164
165
166
        p.replace_instruction(ins, xbroadcast, op);
    }
};

Paul's avatar
Paul committed
167
168
// Applies the algebraic rewrites to `p` for a fixed number of rounds, removing
// dead instructions after each round. Repeating presumably lets one rewrite
// expose opportunities for another.
void simplify_algebra::apply(program& p) const
{
    const int rounds = 4;
    for(int round = 0; round != rounds; ++round)
    {
        match::find_matches(p,
                            find_inner_broadcast{},
                            find_double_add_lit_broadcast{},
                            find_add_lit_broadcast{},
                            find_mul_conv{},
                            find_mul_add{});
        dead_code_elimination{}.apply(p);
    }
}
Paul's avatar
Paul committed
181

Paul's avatar
Paul committed
182
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
183
} // namespace migraphx