Commit 0df035a0 authored by Aditya Atluri's avatar Aditya Atluri
Browse files

formatted source for previous commit

parent c8b86e03
...@@ -51,26 +51,33 @@ struct cpu_batch_norm_inference ...@@ -51,26 +51,33 @@ struct cpu_batch_norm_inference
auto gamma = args[3]; auto gamma = args[3];
auto bias = args[4]; auto bias = args[4];
auto num_batch = output_shape.lens()[0]; auto num_batch = output_shape.lens()[0];
auto num_channels = output_shape.lens()[1]; auto num_channels = output_shape.lens()[1];
auto image_height = output_shape.lens()[2]; auto image_height = output_shape.lens()[2];
auto image_width = output_shape.lens()[3]; auto image_width = output_shape.lens()[3];
visit_all(output, input, mini_batch_mean, mini_batch_variance, gamma, bias)([&](auto result, auto buffer, auto _mean, auto _variance, auto _gamma, auto _bias) { visit_all(output, input, mini_batch_mean, mini_batch_variance, gamma, bias)(
for(size_t n = 0; n < num_batch; n++) { [&](auto result, auto buffer, auto _mean, auto _variance, auto _gamma, auto _bias) {
size_t stride_n = n * num_channels * image_height * image_width; for(size_t n = 0; n < num_batch; n++)
for(size_t c = 0; c < num_channels; c++) { {
size_t stride_c = c * image_height * image_width; size_t stride_n = n * num_channels * image_height * image_width;
for(size_t h = 0; h < image_height; h++) { for(size_t c = 0; c < num_channels; c++)
size_t stride_h = h * image_width; {
for(size_t w = 0; w < image_width; w++) { size_t stride_c = c * image_height * image_width;
size_t index = w + stride_h + stride_c + stride_n; for(size_t h = 0; h < image_height; h++)
result[index] = _gamma[c] * (buffer[index] - _mean[c]) / std::sqrt(_variance[c] + epsilon) + _bias[c]; {
size_t stride_h = h * image_width;
for(size_t w = 0; w < image_width; w++)
{
size_t index = w + stride_h + stride_c + stride_n;
result[index] = _gamma[c] * (buffer[index] - _mean[c]) /
std::sqrt(_variance[c] + epsilon) +
_bias[c];
}
} }
} }
} }
} });
});
return output; return output;
} }
......
...@@ -10,7 +10,8 @@ void batch_norm_inference_test() ...@@ -10,7 +10,8 @@ void batch_norm_inference_test()
{ {
migraph::program p; migraph::program p;
const size_t width = 2, height = 2, channels = 4, batches = 2; const size_t width = 2, height = 2, channels = 4, batches = 2;
const float x_val = 8.0f, mean_val = 2.0f, variance_val = 4.0f, scale_val = 2.0f, bias_val = 1.0f; const float x_val = 8.0f, mean_val = 2.0f, variance_val = 4.0f, scale_val = 2.0f,
bias_val = 1.0f;
const float output_val = scale_val * (x_val - mean_val) / (std::sqrt(variance_val)) + bias_val; const float output_val = scale_val * (x_val - mean_val) / (std::sqrt(variance_val)) + bias_val;
migraph::shape s{migraph::shape::float_type, {batches, channels, height, width}}; migraph::shape s{migraph::shape::float_type, {batches, channels, height, width}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment