Commit 3987066f authored by Paul's avatar Paul
Browse files

Formatting

parent b889693c
......@@ -34,30 +34,21 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
argument a{s};
argument b{s};
visit_all(gamma, bias, mean, variance, a, b)(
[&](auto gamma2,
auto bias2,
auto mean2,
auto variance2,
auto a2,
auto b2) {
dfor(a.get_shape().elements())(
[&](std::size_t c) {
a2[c] =
gamma2[c] / std::sqrt(variance2[c] + epsilon);
[&](auto gamma2, auto bias2, auto mean2, auto variance2, auto a2, auto b2) {
dfor(a.get_shape().elements())(
[&](std::size_t c) { a2[c] = gamma2[c] / std::sqrt(variance2[c] + epsilon); });
dfor(b.get_shape().elements())([&](std::size_t c) {
b2[c] = bias2[c] - (gamma2[c] * mean2[c] / std::sqrt(variance2[c] + epsilon));
});
dfor(b.get_shape().elements())([&](std::size_t c) {
b2[c] =
bias2[c] - (gamma2[c] * mean2[c] / std::sqrt(variance2[c] + epsilon));
});
});
auto broadcast = op::broadcast{1, ins->get_shape().lens()};
auto a_ins = p.add_literal({a.get_shape(), a.data()});
auto broadcast = op::broadcast{1, ins->get_shape().lens()};
auto a_ins = p.add_literal({a.get_shape(), a.data()});
auto a_broadcast = p.insert_instruction(ins, broadcast, a_ins);
auto mul = p.insert_instruction(ins, op::mul{}, ins->inputs().front(), a_broadcast);
auto b_ins = p.add_literal({b.get_shape(), b.data()});
auto mul = p.insert_instruction(ins, op::mul{}, ins->inputs().front(), a_broadcast);
auto b_ins = p.add_literal({b.get_shape(), b.data()});
auto b_broadcast = p.insert_instruction(ins, broadcast, b_ins);
auto add = p.insert_instruction(ins, op::add{}, mul, b_broadcast);
auto add = p.insert_instruction(ins, op::add{}, mul, b_broadcast);
p.replace_instruction(ins, add);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment