rewrite_batchnorm.cpp 2.37 KB
Newer Older
Paul's avatar
Paul committed
1
#include <migraphx/rewrite_batchnorm.hpp>
Paul's avatar
Paul committed
2
3
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
4
#include <migraphx/op/batch_norm_inference.hpp>
5
6
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/add.hpp>
Paul's avatar
Paul committed
7
#include <migraphx/op/mul.hpp>
Paul's avatar
Paul committed
8
#include <migraphx/iterator_for.hpp>
9
#include <migraphx/ranges.hpp>
10
11
#include <migraphx/make_op.hpp>

Paul's avatar
Paul committed
12
#include <migraphx/dfor.hpp>
13

Paul's avatar
Paul committed
14
namespace migraphx {
Paul's avatar
Paul committed
15
inline namespace MIGRAPHX_INLINE_NS {
16

17
void rewrite_batchnorm::apply(module& m) const
{
    // Rewrites each batch_norm_inference whose parameter inputs (1..4) can be evaluated
    // into an elementwise affine transform of its data input (input 0):
    //
    //     y     = x * scale + shift
    //     scale = gamma / sqrt(variance + epsilon)
    //     shift = bias - gamma * mean / sqrt(variance + epsilon)
    //
    // scale and shift are precomputed into literals and broadcast to the output dims
    // starting at axis 1 (presumably the channel axis -- confirm for non-NCHW layouts).
    // Instructions with any non-evaluable parameter are skipped unchanged.
    for(auto ins : iterator_for(m))
    {
        if(ins->name() != "batch_norm_inference")
            continue;

        // Parameters in input order: gamma (scale), bias, mean, variance.
        // An empty argument from eval() means the value could not be evaluated here.
        auto gamma    = ins->inputs()[1]->eval();
        auto bias     = ins->inputs()[2]->eval();
        auto mean     = ins->inputs()[3]->eval();
        auto variance = ins->inputs()[4]->eval();
        const bool has_unknown =
            any_of({gamma, bias, mean, variance}, [](auto arg) { return arg.empty(); });
        if(has_unknown)
            continue;

        // Folded tensors take gamma's dims but the output's element type.
        std::vector<std::size_t> param_lens = ins->inputs()[1]->get_shape().lens();
        shape param_shape{ins->get_shape().type(), param_lens};

        auto epsilon = any_cast<op::batch_norm_inference>(ins->get_operator()).epsilon;

        argument scale{param_shape};
        argument shift{param_shape};
        const std::size_t count = param_shape.elements();
        visit_all(gamma, bias, mean, variance, scale, shift)(
            [&](auto g, auto beta, auto mu, auto var, auto scale_view, auto shift_view) {
                for(std::size_t c = 0; c < count; ++c)
                {
                    // Same denominator expression is shared by both outputs.
                    const auto denom = std::sqrt(var[c] + epsilon);
                    scale_view[c]    = g[c] / denom;
                    shift_view[c]    = beta[c] - (g[c] * mu[c] / denom);
                }
            });

        // Build x * broadcast(scale) + broadcast(shift) at ins, then redirect users of
        // ins to the final add. Mutation order is significant for instruction layout.
        auto to_output = op::broadcast{1, ins->get_shape().lens()};
        auto scale_lit = m.add_literal({scale.get_shape(), scale.data()});
        auto scale_bc  = m.insert_instruction(ins, to_output, scale_lit);
        auto scaled = m.insert_instruction(ins, make_op("mul"), ins->inputs().front(), scale_bc);
        auto shift_lit = m.add_literal({shift.get_shape(), shift.data()});
        auto shift_bc  = m.insert_instruction(ins, to_output, shift_lit);
        auto result    = m.insert_instruction(ins, make_op("add"), scaled, shift_bc);
        m.replace_instruction(ins, result);
    }
}
58

Paul's avatar
Paul committed
59
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
60
} // namespace migraphx