Commit 1be2883d authored by Paul's avatar Paul
Browse files

Formatting

parent aa0b6230
...@@ -12,7 +12,9 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const ...@@ -12,7 +12,9 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
{ {
if(ins->op.name() != "batch_norm_inference") if(ins->op.name() != "batch_norm_inference")
continue; continue;
if(not std::all_of(ins->arguments.begin()+1, ins->arguments.end(), [](auto arg) { return arg->op.name() == "@literal"; })) if(not std::all_of(ins->arguments.begin() + 1, ins->arguments.end(), [](auto arg) {
return arg->op.name() == "@literal";
}))
continue; continue;
auto conv_ins = ins->arguments[0]; auto conv_ins = ins->arguments[0];
...@@ -48,9 +50,8 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const ...@@ -48,9 +50,8 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
auto new_bias2) { auto new_bias2) {
dfor(out_channels, in_channels, height, width)( dfor(out_channels, in_channels, height, width)(
[&](std::size_t k, std::size_t c, std::size_t h, std::size_t w) { [&](std::size_t k, std::size_t c, std::size_t h, std::size_t w) {
new_weights2(k, c, h, w) = gamma2(k) / new_weights2(k, c, h, w) =
std::sqrt(variance2(k) + epsilon) * gamma2(k) / std::sqrt(variance2(k) + epsilon) * weights2(k, c, h, w);
weights2(k, c, h, w);
new_bias2(k, c, h, w) = new_bias2(k, c, h, w) =
bias2(k) - (mean2(k) / std::sqrt(variance2(k) + epsilon)); bias2(k) - (mean2(k) / std::sqrt(variance2(k) + epsilon));
}); });
...@@ -58,8 +59,7 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const ...@@ -58,8 +59,7 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
// Replace convolution instruction with updated weights // Replace convolution instruction with updated weights
auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()}); auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()});
auto l_bias = p.add_literal({bias.get_shape(), new_bias.data()}); auto l_bias = p.add_literal({bias.get_shape(), new_bias.data()});
auto c = auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->arguments[0], l_weights});
p.replace_instruction(conv_ins, conv_op, {conv_ins->arguments[0], l_weights});
p.replace_instruction(ins, add{}, {c, l_bias}); p.replace_instruction(ins, add{}, {c, l_bias});
} }
} }
......
...@@ -115,10 +115,10 @@ struct instruction ...@@ -115,10 +115,10 @@ struct instruction
} }
shape get_shape() const { return result; } shape get_shape() const { return result; }
const literal& get_literal() const const literal& get_literal() const
{ {
assert(op.name() == "@literal"); assert(op.name() == "@literal");
return lit; return lit;
} }
friend bool operator==(instruction_ref ref, const instruction& i) { return i == ref; } friend bool operator==(instruction_ref ref, const instruction& i) { return i == ref; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment