Commit 5ff89419 authored by wsttiger's avatar wsttiger
Browse files

Formatting

parent ae852218
...@@ -8,11 +8,13 @@ ...@@ -8,11 +8,13 @@
namespace migraph { namespace migraph {
void fwd_conv_batchnorm_rewrite::apply(program& p) const void fwd_conv_batchnorm_rewrite::apply(program& p) const
{ {
for (auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{
if(ins->op.name() == "batch_norm_inference")
{ {
if (ins->op.name() == "batch_norm_inference") {
auto ins_prev = ins->arguments[0]; auto ins_prev = ins->arguments[0];
if (ins_prev->op.name() == "convolution") { if(ins_prev->op.name() == "convolution")
{
// Get scale, bias, mean, variance from instruction_ref // Get scale, bias, mean, variance from instruction_ref
auto gamma = ins->arguments[1]->lit; auto gamma = ins->arguments[1]->lit;
auto bias = ins->arguments[2]->lit; auto bias = ins->arguments[2]->lit;
...@@ -32,19 +34,27 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const ...@@ -32,19 +34,27 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
argument new_weights{weights.get_shape()}; argument new_weights{weights.get_shape()};
argument new_bias{bias.get_shape()}; argument new_bias{bias.get_shape()};
visit_all(weights, gamma, bias, mean, variance, new_weights, new_bias)( visit_all(weights, gamma, bias, mean, variance, new_weights, new_bias)(
[&](auto weights2, auto gamma2, auto bias2, auto mean2, [&](auto weights2,
auto variance2, auto new_weights2, auto new_bias2) { auto gamma2,
auto bias2,
auto mean2,
auto variance2,
auto new_weights2,
auto new_bias2) {
dfor(out_channels, in_channels, height, width)( dfor(out_channels, in_channels, height, width)(
[&](std::size_t k, std::size_t c, std::size_t h, std::size_t w) { [&](std::size_t k, std::size_t c, std::size_t h, std::size_t w) {
new_weights2(k, c, h, w) = gamma2(k) / std::sqrt(variance2(k)+epsilon) * weights2(k, c, h, w); new_weights2(k, c, h, w) = gamma2(k) /
new_bias2(k, c, h, w) = bias2(k) - (mean2(k) / std::sqrt(variance2(k)+epsilon)); std::sqrt(variance2(k) + epsilon) *
weights2(k, c, h, w);
new_bias2(k, c, h, w) =
bias2(k) - (mean2(k) / std::sqrt(variance2(k) + epsilon));
}); });
}); });
// Replace convolution instruction with updated weights // Replace convolution instruction with updated weights
auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()}); auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()});
auto l_bias = p.add_literal({bias.get_shape(), new_bias.data()}); auto l_bias = p.add_literal({bias.get_shape(), new_bias.data()});
auto c = p.replace_instruction(ins_prev, conv_op, auto c =
{ins_prev->arguments[0], l_weights}); p.replace_instruction(ins_prev, conv_op, {ins_prev->arguments[0], l_weights});
p.replace_instruction(ins, add{}, {c, l_bias}); p.replace_instruction(ins, add{}, {c, l_bias});
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment