"...src/static/style/experimentManagement/experiment.scss" did not exist on "c444e8624c8eea6948ef902de0aa9a7b8079841f"
fwd_conv_batchnorm_rewrite.cpp 2.99 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
#include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
#include <cmath>
#include <cstddef>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/batch_norm.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/dfor.hpp>

namespace migraphx {
Paul's avatar
Paul committed
12
inline namespace MIGRAPHX_INLINE_NS {
13

14
15
void fwd_conv_batchnorm_rewrite::apply(program& p) const
{
wsttiger's avatar
wsttiger committed
16
    for(auto ins : iterator_for(p))
17
    {
Paul's avatar
Paul committed
18
        if(ins->name() != "batch_norm_inference")
19
            continue;
20
        // Get scale, bias, mean, variance from inputs
Paul's avatar
Paul committed
21
22
23
24
        auto gamma    = ins->inputs()[1]->eval();
        auto bias     = ins->inputs()[2]->eval();
        auto mean     = ins->inputs()[3]->eval();
        auto variance = ins->inputs()[4]->eval();
25
        if(any_of({gamma, bias, mean, variance}, [](auto arg) { return arg.empty(); }))
26
27
            continue;

Paul's avatar
Paul committed
28
        auto conv_ins = ins->inputs()[0];
Paul's avatar
Paul committed
29
        if(conv_ins->name() != "convolution")
30
            continue;
31
        // Get convolution weights
Paul's avatar
Paul committed
32
        auto weights = conv_ins->inputs()[1]->eval();
33
        if(weights.empty())
34
35
            continue;
        // Get epsilon
36
        auto bn_op   = any_cast<op::batch_norm_inference>(ins->get_operator());
37
38
        auto epsilon = bn_op.epsilon;
        // Get convolution op
39
        auto conv_op      = conv_ins->get_operator();
Paul's avatar
Paul committed
40
        auto weights_lens = weights.get_shape().lens();
Paul's avatar
Paul committed
41
        auto conv_lens    = conv_ins->get_shape().lens();
42
        argument new_weights{weights.get_shape()};
43
        argument new_bias{{bias.get_shape().type(), {bias.get_shape().elements()}}};
44
45
46
47
48
49
50
51
        visit_all(weights, gamma, bias, mean, variance, new_weights, new_bias)(
            [&](auto weights2,
                auto gamma2,
                auto bias2,
                auto mean2,
                auto variance2,
                auto new_weights2,
                auto new_bias2) {
Paul's avatar
Paul committed
52
                dfor(weights_lens[0], weights_lens[1], weights_lens[2], weights_lens[3])(
53
                    [&](std::size_t k, std::size_t c, std::size_t h, std::size_t w) {
Paul's avatar
Paul committed
54
                        new_weights2(k, c, h, w) =
55
                            gamma2[k] / std::sqrt(variance2[k] + epsilon) * weights2(k, c, h, w);
Paul's avatar
Paul committed
56
                    });
Paul's avatar
Paul committed
57
                dfor(new_bias.get_shape().elements())([&](std::size_t c) {
58
59
                    new_bias2[c] =
                        bias2[c] - (gamma2[c] * mean2[c] / std::sqrt(variance2[c] + epsilon));
Paul's avatar
Paul committed
60
                });
61
62
63
            });
        // Replace convolution instruction with updated weights
        auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()});
Paul's avatar
Paul committed
64
        auto l_bias    = p.add_literal({new_bias.get_shape(), new_bias.data()});
Paul's avatar
Paul committed
65
        auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->inputs()[0], l_weights});
Scott Thornton's avatar
Scott Thornton committed
66
        auto b = p.insert_instruction(ins, op::broadcast{1, c->get_shape()}, l_bias);
67
        p.replace_instruction(ins, op::add{}, {c, b});
68
69
    }
}
70

Paul's avatar
Paul committed
71
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
72
} // namespace migraphx