rewrite_batchnorm.cpp 2.33 KB
Newer Older
Paul's avatar
Paul committed
1
#include <migraphx/rewrite_batchnorm.hpp>
Paul's avatar
Paul committed
2
3
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
4
5
6
#include <migraphx/op/batch_norm.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/add.hpp>
Paul's avatar
Paul committed
7
#include <migraphx/op/mul.hpp>
Paul's avatar
Paul committed
8
#include <migraphx/iterator_for.hpp>
9
#include <migraphx/ranges.hpp>
Paul's avatar
Paul committed
10
#include <migraphx/dfor.hpp>
11

Paul's avatar
Paul committed
12
namespace migraphx {
Paul's avatar
Paul committed
13
inline namespace MIGRAPHX_INLINE_NS {
14

Paul's avatar
Paul committed
15
/**
 * Rewrites each `batch_norm_inference` whose parameters are compile-time
 * constants into an elementwise multiply followed by an elementwise add:
 *
 *     y = x * a + b
 *     a = gamma / sqrt(variance + epsilon)
 *     b = bias - gamma * mean / sqrt(variance + epsilon)
 *
 * `gamma`, `bias`, `mean` and `variance` are inputs 1..4 of the instruction;
 * input 0 is `x`. The coefficients `a` and `b` are computed here and added to
 * the program as literals. They take the output's element type and the shape
 * of input 1. Both are broadcast with `op::broadcast` on axis 1 to the output
 * lens.
 * NOTE(review): axis 1 is presumably the channel axis; confirm for other
 * layouts.
 *
 * An instruction is skipped when evaluating any of its four parameter inputs
 * yields an empty argument, i.e. it is not constant-foldable.
 *
 * @param p program modified in place. The new instructions are inserted
 *          before each rewritten batch_norm_inference, which is then replaced
 *          by the final add via replace_instruction.
 */
void rewrite_batchnorm::apply(program& p) const
{
    for(auto bn_ins : iterator_for(p))
    {
        if(bn_ins->name() != "batch_norm_inference")
            continue;

        const auto& bn_inputs = bn_ins->inputs();

        // Constant-fold the four per-channel parameters; leave the
        // instruction alone if any of them is unknown at compile time.
        auto scale        = bn_inputs[1]->eval();
        auto shift        = bn_inputs[2]->eval();
        auto running_mean = bn_inputs[3]->eval();
        auto running_var  = bn_inputs[4]->eval();
        if(any_of({scale, shift, running_mean, running_var},
                  [](const auto& param) { return param.empty(); }))
            continue;

        // Coefficients use the output's element type with the parameter lens.
        const shape coeff_shape{bn_ins->get_shape().type(), bn_inputs[1]->get_shape().lens()};

        auto norm_op       = any_cast<op::batch_norm_inference>(bn_ins->get_operator());
        const auto epsilon = norm_op.epsilon;

        argument mul_coeff{coeff_shape};
        argument add_coeff{coeff_shape};
        const std::size_t count = coeff_shape.elements();
        visit_all(scale, shift, running_mean, running_var, mul_coeff, add_coeff)(
            [&](auto gamma_v, auto beta_v, auto mean_v, auto var_v, auto mul_v, auto add_v) {
                for(std::size_t c = 0; c < count; c++)
                {
                    // Denominator shared by both coefficients.
                    const auto stddev = std::sqrt(var_v[c] + epsilon);
                    mul_v[c]          = gamma_v[c] / stddev;
                    add_v[c]          = beta_v[c] - (gamma_v[c] * mean_v[c] / stddev);
                }
            });

        // Emit x * a + b. Each coefficient literal is broadcast onto the
        // output lens right before the op that consumes it.
        const auto bcast = op::broadcast{1, bn_ins->get_shape().lens()};

        auto mul_lit   = p.add_literal({mul_coeff.get_shape(), mul_coeff.data()});
        auto mul_bcast = p.insert_instruction(bn_ins, bcast, mul_lit);
        auto scaled    = p.insert_instruction(bn_ins, op::mul{}, bn_inputs.front(), mul_bcast);

        auto add_lit   = p.add_literal({add_coeff.get_shape(), add_coeff.data()});
        auto add_bcast = p.insert_instruction(bn_ins, bcast, add_lit);
        auto shifted   = p.insert_instruction(bn_ins, op::add{}, scaled, add_bcast);

        p.replace_instruction(bn_ins, shifted);
    }
}
56

Paul's avatar
Paul committed
57
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
58
} // namespace migraphx