"tests/vscode:/vscode.git/clone" did not exist on "bdacf2917901ac559ce1fd8593953c67f828ac2c"
rewrite_batchnorm.cpp 2.29 KB
Newer Older
Paul's avatar
Paul committed
1
#include <migraphx/rewrite_batchnorm.hpp>
Paul's avatar
Paul committed
2
3
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
4
5
6
#include <migraphx/op/batch_norm.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/add.hpp>
Paul's avatar
Paul committed
7
#include <migraphx/op/mul.hpp>
Paul's avatar
Paul committed
8
#include <migraphx/iterator_for.hpp>
9
#include <migraphx/ranges.hpp>
Paul's avatar
Paul committed
10
#include <migraphx/dfor.hpp>
11

Paul's avatar
Paul committed
12
namespace migraphx {
Paul's avatar
Paul committed
13
inline namespace MIGRAPHX_INLINE_NS {
14

Paul's avatar
Paul committed
15
void rewrite_batchnorm::apply(program& p) const
16
{
wsttiger's avatar
wsttiger committed
17
    for(auto ins : iterator_for(p))
18
    {
Paul's avatar
Paul committed
19
        if(ins->name() != "batch_norm_inference")
20
            continue;
21
        // Get scale, bias, mean, variance from inputs
Paul's avatar
Paul committed
22
23
24
25
        auto gamma    = ins->inputs()[1]->eval();
        auto bias     = ins->inputs()[2]->eval();
        auto mean     = ins->inputs()[3]->eval();
        auto variance = ins->inputs()[4]->eval();
26
        if(any_of({gamma, bias, mean, variance}, [](auto arg) { return arg.empty(); }))
27
28
            continue;

Paul's avatar
Paul committed
29
        auto s = shape{ins->get_shape().type(), {ins->get_shape().lens()[1]}};
30
        // Get epsilon
31
        auto bn_op   = any_cast<op::batch_norm_inference>(ins->get_operator());
32
        auto epsilon = bn_op.epsilon;
Paul's avatar
Paul committed
33
34
35
36

        argument a{s};
        argument b{s};
        visit_all(gamma, bias, mean, variance, a, b)(
Paul's avatar
Paul committed
37
38
39
40
41
            [&](auto gamma2, auto bias2, auto mean2, auto variance2, auto a2, auto b2) {
                dfor(a.get_shape().elements())(
                    [&](std::size_t c) { a2[c] = gamma2[c] / std::sqrt(variance2[c] + epsilon); });
                dfor(b.get_shape().elements())([&](std::size_t c) {
                    b2[c] = bias2[c] - (gamma2[c] * mean2[c] / std::sqrt(variance2[c] + epsilon));
Paul's avatar
Paul committed
42
                });
43
            });
Paul's avatar
Paul committed
44

Paul's avatar
Paul committed
45
46
        auto broadcast   = op::broadcast{1, ins->get_shape().lens()};
        auto a_ins       = p.add_literal({a.get_shape(), a.data()});
Paul's avatar
Paul committed
47
        auto a_broadcast = p.insert_instruction(ins, broadcast, a_ins);
Paul's avatar
Paul committed
48
49
        auto mul         = p.insert_instruction(ins, op::mul{}, ins->inputs().front(), a_broadcast);
        auto b_ins       = p.add_literal({b.get_shape(), b.data()});
Paul's avatar
Paul committed
50
        auto b_broadcast = p.insert_instruction(ins, broadcast, b_ins);
Paul's avatar
Paul committed
51
        auto add         = p.insert_instruction(ins, op::add{}, mul, b_broadcast);
Paul's avatar
Paul committed
52
        p.replace_instruction(ins, add);
53
54
    }
}
55

Paul's avatar
Paul committed
56
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
57
} // namespace migraphx