Commit 6a2d72d7 authored by Paul's avatar Paul
Browse files

Formatting

parent e17c0bec
......@@ -35,7 +35,7 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
// Get convolution op
auto conv_op = conv_ins->op;
auto weights_lens = weights.get_shape().lens();
auto conv_lens = conv_ins->get_shape().lens();
auto conv_lens = conv_ins->get_shape().lens();
argument new_weights{weights.get_shape()};
argument new_bias{bias.get_shape()};
visit_all(weights, gamma, bias, mean, variance, new_weights, new_bias)(
......@@ -51,11 +51,9 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
new_weights2(k, c, h, w) =
gamma2(k) / std::sqrt(variance2(k) + epsilon) * weights2(k, c, h, w);
});
dfor(new_bias.get_shape().elements())(
[&](std::size_t c) {
new_bias2(c) =
bias2(c) - (mean2(c) / std::sqrt(variance2(c) + epsilon));
});
dfor(new_bias.get_shape().elements())([&](std::size_t c) {
new_bias2(c) = bias2(c) - (mean2(c) / std::sqrt(variance2(c) + epsilon));
});
});
// Replace convolution instruction with updated weights
auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment