Commit 5b709a1c authored by Paul's avatar Paul
Browse files

Fix bias shape

parent 1be2883d
......@@ -34,12 +34,10 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
const auto& weights = conv_ins->arguments[1]->get_literal();
// Get convolution op
auto conv_op = conv_ins->op;
auto out_channels = weights.get_shape().lens()[0];
auto in_channels = weights.get_shape().lens()[1];
auto height = weights.get_shape().lens()[2];
auto width = weights.get_shape().lens()[3];
auto weights_lens = weights.get_shape().lens();
auto conv_lens = conv_ins->get_shape().lens();
argument new_weights{weights.get_shape()};
argument new_bias{bias.get_shape()};
argument new_bias{conv_ins->get_shape()};
visit_all(weights, gamma, bias, mean, variance, new_weights, new_bias)(
[&](auto weights2,
auto gamma2,
......@@ -48,17 +46,20 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
auto variance2,
auto new_weights2,
auto new_bias2) {
dfor(out_channels, in_channels, height, width)(
dfor(weights_lens[0], weights_lens[1], weights_lens[2], weights_lens[3])(
[&](std::size_t k, std::size_t c, std::size_t h, std::size_t w) {
new_weights2(k, c, h, w) =
gamma2(k) / std::sqrt(variance2(k) + epsilon) * weights2(k, c, h, w);
new_bias2(k, c, h, w) =
bias2(k) - (mean2(k) / std::sqrt(variance2(k) + epsilon));
});
dfor(conv_lens[0], conv_lens[1], conv_lens[2], conv_lens[3])(
[&](std::size_t n, std::size_t c, std::size_t h, std::size_t w) {
new_bias2(n, c, h, w) =
bias2(c) - (mean2(c) / std::sqrt(variance2(c) + epsilon));
});
});
// Replace convolution instruction with updated weights
auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()});
auto l_bias = p.add_literal({bias.get_shape(), new_bias.data()});
auto l_bias = p.add_literal({new_bias.get_shape(), new_bias.data()});
auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->arguments[0], l_weights});
p.replace_instruction(ins, add{}, {c, l_bias});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment