Unverified commit 6bb6b72e, authored by Paul Fultz II and committed by GitHub
Browse files

Merge pull request #27 from ROCmSoftwarePlatform/per-activation-cpu-bn-infer

added per activation batch norm inference
parents 44513aca ddfd8ad3
...@@ -24,6 +24,14 @@ struct batch_norm_inference ...@@ -24,6 +24,14 @@ struct batch_norm_inference
std::string name() const { return "batch_norm_inference"; } std::string name() const { return "batch_norm_inference"; }
// Selects how the batch-norm inference statistics and scale/shift parameters
// are indexed when the operator is evaluated.
enum bn_infer_mode_t
{
// Parameters are indexed per (channel, height, width) element; the CPU
// implementation reads mean/variance/gamma/bias as (c, h, w).
per_activation,
// Parameters are indexed by channel only in the CPU implementation
// (e.g. bias(c)) -- presumably one value shared across the spatial
// extent of each channel; confirm against the full implementation.
spatial,
};
// Mode used by this operator instance; defaults to spatial.
bn_infer_mode_t bn_mode = spatial;
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(5); check_shapes{inputs, *this}.has(5);
......
...@@ -56,6 +56,8 @@ struct cpu_batch_norm_inference ...@@ -56,6 +56,8 @@ struct cpu_batch_norm_inference
auto image_height = output_shape.lens()[2]; auto image_height = output_shape.lens()[2];
auto image_width = output_shape.lens()[3]; auto image_width = output_shape.lens()[3];
if(op.bn_mode == batch_norm_inference::spatial)
{
visit_all(output, input, mini_batch_mean, mini_batch_variance, arg_gamma, arg_bias)( visit_all(output, input, mini_batch_mean, mini_batch_variance, arg_gamma, arg_bias)(
[&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) { [&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) {
...@@ -66,6 +68,22 @@ struct cpu_batch_norm_inference ...@@ -66,6 +68,22 @@ struct cpu_batch_norm_inference
bias(c); bias(c);
}); });
}); });
}
// Per-activation mode: every (c, h, w) position has its own mean, variance,
// gamma and bias, which are applied to that position across the whole batch:
//   result = gamma * (x - mean) / sqrt(variance + epsilon) + bias
if(op.bn_mode == batch_norm_inference::per_activation)
{
    // The fourth argument must be the variance tensor; passing the mean twice
    // would normalize by sqrt(mean + epsilon) and produce wrong (possibly NaN)
    // results.
    visit_all(output, input, mini_batch_mean, mini_batch_variance, arg_gamma, arg_bias)(
        [&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) {
            dfor(num_batch, num_channels, image_height, image_width)(
                [&](std::size_t n, std::size_t c, std::size_t h, std::size_t w) {
                    result(n, c, h, w) = gamma(c, h, w) *
                                             (buffer(n, c, h, w) - mean(c, h, w)) /
                                             std::sqrt(variance(c, h, w) + epsilon) +
                                         bias(c, h, w);
                });
        });
}
return output; return output;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment