"vscode:/vscode.git/clone" did not exist on "aa1844fc7f416a9ec31f33020a68cb7010e50e91"
Commit c3ec7238 authored by Aditya Atluri's avatar Aditya Atluri
Browse files

added per activation batch norm inference

parent 44513aca
......@@ -24,6 +24,14 @@ struct batch_norm_inference
std::string name() const { return "batch_norm_inference"; }
enum bn_infer_mode_t
{
per_activation,
spatial,
};
bn_infer_mode_t bn_mode = spatial;
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(5);
......
......@@ -56,6 +56,7 @@ struct cpu_batch_norm_inference
auto image_height = output_shape.lens()[2];
auto image_width = output_shape.lens()[3];
if(op.bn_mode == batch_norm_inference::spatial) {
visit_all(output, input, mini_batch_mean, mini_batch_variance, arg_gamma, arg_bias)(
[&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) {
......@@ -66,6 +67,18 @@ struct cpu_batch_norm_inference
bias(c);
});
});
}
if(op.bn_mode == batch_norm_inference::per_activation) {
visit_all(output, input, mini_batch_mean, mini_batch_mean, arg_gamma, arg_bias)(
[&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) {
dfor(num_batch, num_channels, image_height, image_width) (
[&](std::size_t n, std::size_t c, std::size_t h, std::size_t w) {
result(n, c, h, w) = gamma(c, h, w) * (buffer(n, c, h, w) - mean(c, h, w)) / std::sqrt(variance(c, h, w) + epsilon) + bias(c, h, w);
});
});
}
return output;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment