fwd_conv_batchnorm_rewrite.cpp 3.01 KB
Newer Older
1
2
3
4
5
6
7
8
#include <migraph/fwd_conv_batchnorm_rewrite.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/dfor.hpp>

namespace migraph {
9
inline namespace MIGRAPH_INLINE_NS {
10

11
12
void fwd_conv_batchnorm_rewrite::apply(program& p) const
{
wsttiger's avatar
wsttiger committed
13
    for(auto ins : iterator_for(p))
14
    {
Paul's avatar
Paul committed
15
        if(ins->name() != "batch_norm_inference")
16
            continue;
Paul's avatar
Paul committed
17
        if(not std::all_of(ins->inputs().begin() + 1, ins->inputs().end(), [](auto arg) {
Paul's avatar
Paul committed
18
               return arg->name() == "@literal";
Paul's avatar
Paul committed
19
           }))
20
21
            continue;

Paul's avatar
Paul committed
22
        auto conv_ins = ins->inputs()[0];
Paul's avatar
Paul committed
23
        if(conv_ins->name() != "convolution")
24
            continue;
Paul's avatar
Paul committed
25
        if(conv_ins->inputs()[1]->name() != "@literal")
26
27
            continue;
        // Get scale, bias, mean, variance from instruction_ref
Paul's avatar
Paul committed
28
29
30
31
        const auto& gamma    = ins->inputs()[1]->get_literal();
        const auto& bias     = ins->inputs()[2]->get_literal();
        const auto& mean     = ins->inputs()[3]->get_literal();
        const auto& variance = ins->inputs()[4]->get_literal();
32
        // Get epsilon
33
        auto bn_op   = any_cast<op::batch_norm_inference>(ins->get_operator());
34
35
        auto epsilon = bn_op.epsilon;
        // Get convolution weights
Paul's avatar
Paul committed
36
        const auto& weights = conv_ins->inputs()[1]->get_literal();
37
        // Get convolution op
38
        auto conv_op      = conv_ins->get_operator();
Paul's avatar
Paul committed
39
        auto weights_lens = weights.get_shape().lens();
Paul's avatar
Paul committed
40
        auto conv_lens    = conv_ins->get_shape().lens();
41
        argument new_weights{weights.get_shape()};
Paul's avatar
Paul committed
42
        argument new_bias{bias.get_shape()};
43
44
45
46
47
48
49
50
        visit_all(weights, gamma, bias, mean, variance, new_weights, new_bias)(
            [&](auto weights2,
                auto gamma2,
                auto bias2,
                auto mean2,
                auto variance2,
                auto new_weights2,
                auto new_bias2) {
Paul's avatar
Paul committed
51
                dfor(weights_lens[0], weights_lens[1], weights_lens[2], weights_lens[3])(
52
                    [&](std::size_t k, std::size_t c, std::size_t h, std::size_t w) {
Paul's avatar
Paul committed
53
54
                        new_weights2(k, c, h, w) =
                            gamma2(k) / std::sqrt(variance2(k) + epsilon) * weights2(k, c, h, w);
Paul's avatar
Paul committed
55
                    });
Paul's avatar
Paul committed
56
                dfor(new_bias.get_shape().elements())([&](std::size_t c) {
Scott Thornton's avatar
Scott Thornton committed
57
58
                    new_bias2(c) =
                        bias2(c) - (gamma2(c) * mean2(c) / std::sqrt(variance2(c) + epsilon));
Paul's avatar
Paul committed
59
                });
60
61
62
            });
        // Replace convolution instruction with updated weights
        auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()});
Paul's avatar
Paul committed
63
        auto l_bias    = p.add_literal({new_bias.get_shape(), new_bias.data()});
Paul's avatar
Paul committed
64
        auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->inputs()[0], l_weights});
Scott Thornton's avatar
Scott Thornton committed
65
        auto b = p.insert_instruction(ins, op::broadcast{1, c->get_shape()}, l_bias);
66
        p.replace_instruction(ins, op::add{}, {c, b});
67
68
    }
}
69

Shucai Xiao's avatar
Shucai Xiao committed
70
} // namespace MIGRAPH_INLINE_NS
71
} // namespace migraph