Commit b889693c authored by Paul
Browse files

Convert batchnorm to multiply and add

parent d5ade1e7
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <migraphx/op/batch_norm.hpp> #include <migraphx/op/batch_norm.hpp>
#include <migraphx/op/broadcast.hpp> #include <migraphx/op/broadcast.hpp>
#include <migraphx/op/add.hpp> #include <migraphx/op/add.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
...@@ -25,46 +26,39 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const ...@@ -25,46 +26,39 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
if(any_of({gamma, bias, mean, variance}, [](auto arg) { return arg.empty(); })) if(any_of({gamma, bias, mean, variance}, [](auto arg) { return arg.empty(); }))
continue; continue;
auto conv_ins = ins->inputs()[0]; auto s = shape{ins->get_shape().type(), {ins->get_shape().lens()[1]}};
if(conv_ins->name() != "convolution")
continue;
// Get convolution weights
auto weights = conv_ins->inputs()[1]->eval();
if(weights.empty())
continue;
// Get epsilon // Get epsilon
auto bn_op = any_cast<op::batch_norm_inference>(ins->get_operator()); auto bn_op = any_cast<op::batch_norm_inference>(ins->get_operator());
auto epsilon = bn_op.epsilon; auto epsilon = bn_op.epsilon;
// Get convolution op
auto conv_op = conv_ins->get_operator(); argument a{s};
auto weights_lens = weights.get_shape().lens(); argument b{s};
auto conv_lens = conv_ins->get_shape().lens(); visit_all(gamma, bias, mean, variance, a, b)(
argument new_weights{weights.get_shape()}; [&](auto gamma2,
argument new_bias{{bias.get_shape().type(), {bias.get_shape().elements()}}};
visit_all(weights, gamma, bias, mean, variance, new_weights, new_bias)(
[&](auto weights2,
auto gamma2,
auto bias2, auto bias2,
auto mean2, auto mean2,
auto variance2, auto variance2,
auto new_weights2, auto a2,
auto new_bias2) { auto b2) {
dfor(weights_lens[0], weights_lens[1], weights_lens[2], weights_lens[3])( dfor(a.get_shape().elements())(
[&](std::size_t k, std::size_t c, std::size_t h, std::size_t w) { [&](std::size_t c) {
new_weights2(k, c, h, w) = a2[c] =
gamma2[k] / std::sqrt(variance2[k] + epsilon) * weights2(k, c, h, w); gamma2[c] / std::sqrt(variance2[c] + epsilon);
}); });
dfor(new_bias.get_shape().elements())([&](std::size_t c) { dfor(b.get_shape().elements())([&](std::size_t c) {
new_bias2[c] = b2[c] =
bias2[c] - (gamma2[c] * mean2[c] / std::sqrt(variance2[c] + epsilon)); bias2[c] - (gamma2[c] * mean2[c] / std::sqrt(variance2[c] + epsilon));
}); });
}); });
// Replace convolution instruction with updated weights
auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()}); auto broadcast = op::broadcast{1, ins->get_shape().lens()};
auto l_bias = p.add_literal({new_bias.get_shape(), new_bias.data()}); auto a_ins = p.add_literal({a.get_shape(), a.data()});
auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->inputs()[0], l_weights}); auto a_broadcast = p.insert_instruction(ins, broadcast, a_ins);
auto b = p.insert_instruction(ins, op::broadcast{1, c->get_shape().lens()}, l_bias); auto mul = p.insert_instruction(ins, op::mul{}, ins->inputs().front(), a_broadcast);
p.replace_instruction(ins, op::add{}, {c, b}); auto b_ins = p.add_literal({b.get_shape(), b.data()});
auto b_broadcast = p.insert_instruction(ins, broadcast, b_ins);
auto add = p.insert_instruction(ins, op::add{}, mul, b_broadcast);
p.replace_instruction(ins, add);
} }
} }
......
...@@ -96,7 +96,7 @@ TEST_CASE(non_literal) ...@@ -96,7 +96,7 @@ TEST_CASE(non_literal)
migraphx::fwd_conv_batchnorm_rewrite opt; migraphx::fwd_conv_batchnorm_rewrite opt;
opt.apply(p2); opt.apply(p2);
EXPECT(any_of(p1, &is_batch_norm)); EXPECT(any_of(p1, &is_batch_norm));
EXPECT(any_of(p2, &is_batch_norm)); EXPECT(none_of(p2, &is_batch_norm));
} }
TEST_CASE(as_literal) TEST_CASE(as_literal)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment