Commit 6a2d72d7 authored by Paul's avatar Paul
Browse files

Formatting

parent e17c0bec
...@@ -35,7 +35,7 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const ...@@ -35,7 +35,7 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
// Get convolution op // Get convolution op
auto conv_op = conv_ins->op; auto conv_op = conv_ins->op;
auto weights_lens = weights.get_shape().lens(); auto weights_lens = weights.get_shape().lens();
auto conv_lens = conv_ins->get_shape().lens(); auto conv_lens = conv_ins->get_shape().lens();
argument new_weights{weights.get_shape()}; argument new_weights{weights.get_shape()};
argument new_bias{bias.get_shape()}; argument new_bias{bias.get_shape()};
visit_all(weights, gamma, bias, mean, variance, new_weights, new_bias)( visit_all(weights, gamma, bias, mean, variance, new_weights, new_bias)(
...@@ -51,11 +51,9 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const ...@@ -51,11 +51,9 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
new_weights2(k, c, h, w) = new_weights2(k, c, h, w) =
gamma2(k) / std::sqrt(variance2(k) + epsilon) * weights2(k, c, h, w); gamma2(k) / std::sqrt(variance2(k) + epsilon) * weights2(k, c, h, w);
}); });
dfor(new_bias.get_shape().elements())( dfor(new_bias.get_shape().elements())([&](std::size_t c) {
[&](std::size_t c) { new_bias2(c) = bias2(c) - (mean2(c) / std::sqrt(variance2(c) + epsilon));
new_bias2(c) = });
bias2(c) - (mean2(c) / std::sqrt(variance2(c) + epsilon));
});
}); });
// Replace convolution instruction with updated weights // Replace convolution instruction with updated weights
auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()}); auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment