"vscode:/vscode.git/clone" did not exist on "bbd073e417bca3d7ee08fbba3feeab6c686d9f19"
Commit c8b86e03 authored by Aditya Atluri's avatar Aditya Atluri
Browse files

fixed cpu test

parent 8ae3ffea
...@@ -24,11 +24,11 @@ T zero(const T&) ...@@ -24,11 +24,11 @@ T zero(const T&)
// args[1] -> mini batch mean // args[1] -> mini batch mean
// args[2] -> mini batch variance // args[2] -> mini batch variance
// args[3] -> gamma // args[3] -> gamma
// args[4] -> beta // args[4] -> bias
// //
// The equation to compute batch norm for inference is: // The equation to compute batch norm for inference is:
// //
// output[i] = beta + gamma * (input[i] + mean) / sqrt(variance + epsilon) // output[i] = bias + gamma * (input[i] + mean) / sqrt(variance + epsilon)
// //
// the input data format should be nchw // the input data format should be nchw
// //
...@@ -46,16 +46,30 @@ struct cpu_batch_norm_inference ...@@ -46,16 +46,30 @@ struct cpu_batch_norm_inference
double epsilon = op.epsilon; double epsilon = op.epsilon;
auto input = args[0]; auto input = args[0];
auto mini_batch_mean = args[1].at<float>(); auto mini_batch_mean = args[1];
auto mini_batch_variance = args[2].at<float>(); auto mini_batch_variance = args[2];
auto gamma = args[3].at<float>(); auto gamma = args[3];
auto beta = args[4].at<float>(); auto bias = args[4];
visit_all(output, input)([&](auto result, auto buffer) { auto num_batch = output_shape.lens()[0];
std::transform(buffer.begin(), buffer.end(), result.begin(), [&](auto x) { auto num_channels = output_shape.lens()[1];
return gamma * (x - mini_batch_mean) / std::sqrt(mini_batch_variance + epsilon) + auto image_height = output_shape.lens()[2];
beta; auto image_width = output_shape.lens()[3];
});
visit_all(output, input, mini_batch_mean, mini_batch_variance, gamma, bias)([&](auto result, auto buffer, auto _mean, auto _variance, auto _gamma, auto _bias) {
for(size_t n = 0; n < num_batch; n++) {
size_t stride_n = n * num_channels * image_height * image_width;
for(size_t c = 0; c < num_channels; c++) {
size_t stride_c = c * image_height * image_width;
for(size_t h = 0; h < image_height; h++) {
size_t stride_h = h * image_width;
for(size_t w = 0; w < image_width; w++) {
size_t index = w + stride_h + stride_c + stride_n;
result[index] = _gamma[c] * (buffer[index] - _mean[c]) / std::sqrt(_variance[c] + epsilon) + _bias[c];
}
}
}
}
}); });
return output; return output;
......
...@@ -9,19 +9,39 @@ ...@@ -9,19 +9,39 @@
void batch_norm_inference_test() void batch_norm_inference_test()
{ {
migraph::program p; migraph::program p;
migraph::shape s{migraph::shape::float_type, {4}}; const size_t width = 2, height = 2, channels = 4, batches = 2;
auto x = p.add_literal(migraph::literal{s, {1, 2, 3, 4}}); const float x_val = 8.0f, mean_val = 2.0f, variance_val = 4.0f, scale_val = 2.0f, bias_val = 1.0f;
auto gamma = p.add_literal(migraph::literal{s, {1}}); const float output_val = scale_val * (x_val - mean_val) / (std::sqrt(variance_val)) + bias_val;
auto beta = p.add_literal(migraph::literal{s, {0}});
auto mean = p.add_literal(migraph::literal{s, {0}}); migraph::shape s{migraph::shape::float_type, {batches, channels, height, width}};
auto variance = p.add_literal(migraph::literal{s, {1}}); migraph::shape vars{migraph::shape::float_type, {channels}};
p.add_instruction(migraph::batch_norm_inference{}, x, mean, variance, gamma, beta); std::vector<float> x_data(width * height * channels * batches);
std::vector<float> scale_data(channels);
std::vector<float> bias_data(channels);
std::vector<float> mean_data(channels);
std::vector<float> variance_data(channels);
std::fill(x_data.begin(), x_data.end(), x_val);
std::fill(mean_data.begin(), mean_data.end(), mean_val);
std::fill(variance_data.begin(), variance_data.end(), variance_val);
std::fill(scale_data.begin(), scale_data.end(), scale_val);
std::fill(bias_data.begin(), bias_data.end(), bias_val);
auto x = p.add_literal(migraph::literal{s, x_data});
auto scale = p.add_literal(migraph::literal{vars, scale_data});
auto bias = p.add_literal(migraph::literal{vars, bias_data});
auto mean = p.add_literal(migraph::literal{vars, mean_data});
auto variance = p.add_literal(migraph::literal{vars, variance_data});
p.add_instruction(migraph::batch_norm_inference{}, x, mean, variance, scale, bias);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> result_vector(4);
std::vector<float> result_vector(width * height * channels * batches);
std::vector<float> gold(width * height * channels * batches);
std::fill(gold.begin(), gold.end(), output_val);
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {
1 / (1 + 1.0e-6), 2 / (1 + 1.0e-6), 3 / (1 + 1.0e-6), 4 / (1 + 1.0e-6)};
EXPECT(test::verify_range(result_vector, gold)); EXPECT(test::verify_range(result_vector, gold));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment