Unverified commit 6bb6b72e, authored by Paul Fultz II, committed by GitHub
Browse files

Merge pull request #27 from ROCmSoftwarePlatform/per-activation-cpu-bn-infer

Added per-activation batch norm inference.
parents 44513aca ddfd8ad3
...@@ -24,6 +24,14 @@ struct batch_norm_inference ...@@ -24,6 +24,14 @@ struct batch_norm_inference
std::string name() const { return "batch_norm_inference"; } std::string name() const { return "batch_norm_inference"; }
// Selects how the batch-norm parameters (gamma, bias, mean, variance) are
// indexed when the operator is evaluated.
enum bn_infer_mode_t
{
// The CPU implementation indexes gamma, bias, mean and variance by (c, h, w),
// i.e. one set of parameters per element of a single image.
per_activation,
// The CPU implementation indexes gamma, bias, mean and variance by (c) only,
// i.e. one set of parameters per channel.
spatial,
};
// Mode used by this operator instance; defaults to spatial.
bn_infer_mode_t bn_mode = spatial;
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(5); check_shapes{inputs, *this}.has(5);
......
...@@ -56,16 +56,34 @@ struct cpu_batch_norm_inference ...@@ -56,16 +56,34 @@ struct cpu_batch_norm_inference
auto image_height = output_shape.lens()[2]; auto image_height = output_shape.lens()[2];
auto image_width = output_shape.lens()[3]; auto image_width = output_shape.lens()[3];
visit_all(output, input, mini_batch_mean, mini_batch_variance, arg_gamma, arg_bias)( if(op.bn_mode == batch_norm_inference::spatial)
[&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) { {
visit_all(output, input, mini_batch_mean, mini_batch_variance, arg_gamma, arg_bias)(
dfor(num_batch, num_channels, image_height, image_width)( [&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) {
[&](std::size_t n, std::size_t c, std::size_t h, std::size_t w) {
result(n, c, h, w) = gamma(c) * (buffer(n, c, h, w) - mean(c)) / dfor(num_batch, num_channels, image_height, image_width)(
std::sqrt(variance(c) + epsilon) + [&](std::size_t n, std::size_t c, std::size_t h, std::size_t w) {
bias(c); result(n, c, h, w) = gamma(c) * (buffer(n, c, h, w) - mean(c)) /
}); std::sqrt(variance(c) + epsilon) +
}); bias(c);
});
});
}
// Per-activation mode: gamma, bias, mean and variance are indexed by
// (c, h, w), so every element of an image is normalized with its own
// statistics and scale/shift, shared only across the batch dimension n.
// NOTE(review): this presumes the parameter tensors are rank-3 (C, H, W) in
// this mode -- confirm against compute_shape, which is not visible here.
if(op.bn_mode == batch_norm_inference::per_activation)
{
    // Bug fix: the fourth tensor must be mini_batch_variance. The previous code
    // passed mini_batch_mean twice, so `variance` aliased the mean and the
    // result was divided by sqrt(mean + epsilon).
    visit_all(output, input, mini_batch_mean, mini_batch_variance, arg_gamma, arg_bias)(
        [&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) {
            dfor(num_batch, num_channels, image_height, image_width)(
                [&](std::size_t n, std::size_t c, std::size_t h, std::size_t w) {
                    // y = gamma * (x - mean) / sqrt(variance + epsilon) + bias
                    result(n, c, h, w) = gamma(c, h, w) *
                                             (buffer(n, c, h, w) - mean(c, h, w)) /
                                             std::sqrt(variance(c, h, w) + epsilon) +
                                         bias(c, h, w);
                });
        });
}
return output; return output;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment