Unverified Commit 61cbe923 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Cpu batchnorm (#562)



* change the batchnorm cpu implementation to support multiple input dimensions

* clang format

* add unit tests for cpu batch_norm nd implementation

* clang format
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 742b4b82
......@@ -43,8 +43,7 @@ struct batch_norm_inference
{
check_shapes{inputs, *this}.has(5);
check_shapes{inputs.data(), inputs.data() + 1, *this}.same_ndims();
check_shapes{inputs.data() + 1, inputs.data() + inputs.size(), *this}.same_shape().elements(
inputs.front().lens()[1]);
check_shapes{inputs.data() + 1, inputs.data() + inputs.size(), *this}.same_shape();
return inputs.front();
}
};
......
......@@ -86,39 +86,35 @@ struct cpu_batch_norm_inference
auto mini_batch_mean = args[3];
auto mini_batch_variance = args[4];
auto num_batch = output_shape.lens()[0];
auto num_channels = output_shape.lens()[1];
auto image_height = output_shape.lens()[2];
auto image_width = output_shape.lens()[3];
if(op.bn_mode == op::batch_norm_inference::spatial)
{
visit_all(output, input, mini_batch_mean, mini_batch_variance, arg_gamma, arg_bias)(
[&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) {
par_dfor(num_batch, num_channels, image_height, image_width)(
[&](std::size_t n, std::size_t c, std::size_t h, std::size_t w) {
assert((variance[c] + epsilon) > 0);
result(n, c, h, w) = gamma[c] * (buffer(n, c, h, w) - mean[c]) /
std::sqrt(variance[c] + epsilon) +
bias[c];
});
par_for(output_shape.elements(), [&](auto i) {
auto idx = output_shape.multi(i);
auto c = idx[1];
assert((variance[c] + epsilon) > 0);
result[i] =
gamma[c] * (buffer[i] - mean[c]) / std::sqrt(variance[c] + epsilon) +
bias[c];
});
});
}
if(op.bn_mode == op::batch_norm_inference::per_activation)
{
visit_all(output, input, mini_batch_mean, mini_batch_mean, arg_gamma, arg_bias)(
visit_all(output, input, mini_batch_mean, mini_batch_variance, arg_gamma, arg_bias)(
[&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) {
par_dfor(num_batch, num_channels, image_height, image_width)(
[&](std::size_t n, std::size_t c, std::size_t h, std::size_t w) {
assert((variance(c, h, w) + epsilon) > 0);
result(n, c, h, w) = gamma(c, h, w) *
(buffer(n, c, h, w) - mean(c, h, w)) /
std::sqrt(variance(c, h, w) + epsilon) +
bias(c, h, w);
});
par_for(output_shape.elements(), [&](auto i) {
auto idx = output_shape.multi(i);
idx[0] = 0;
auto index = output_shape.index(idx);
assert((variance[index] + epsilon) > 0);
result[i] = gamma[index] * (buffer[i] - mean[index]) /
std::sqrt(variance[index] + epsilon) +
bias[index];
});
});
}
......
......@@ -416,7 +416,6 @@ TEST_CASE(avgpool_test)
p.add_instruction(op, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({}).back();
std::cout << "result = " << result << std::endl;
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1.6321,
......@@ -714,6 +713,109 @@ TEST_CASE(im2col_3x3_with_padding_test)
EXPECT(migraphx::verify_range(results_vector, correct));
}
TEST_CASE(batch_norm_1d_test)
{
    // Spatial-mode batch norm on a rank-3 input {N, C, W}: the scale, bias,
    // mean and variance tensors each hold one value per channel (C = 3).
    migraphx::program p;

    migraphx::shape input_shape{migraphx::shape::float_type, {2, 3, 4}};
    migraphx::shape channel_shape{migraphx::shape::float_type, {3}};

    // Fixed input and per-channel statistics; gold values below were computed
    // from these with epsilon = 1e-5.
    std::vector<float> input_data = {0.7253,  -0.6356, 0.4606,  -0.8689, -1.1932, 0.4538,
                                     -1.0018, -0.365,  -0.214,  -0.9553, -0.7672, 0.2331,
                                     -0.8416, -0.6142, 0.0814,  0.2498,  -0.6706, 1.4872,
                                     0.5112,  -1.5212, -0.9126, 0.0735,  1.085,   -0.3417};
    std::vector<float> gamma_data    = {1.1, 1.2, 1.3};
    std::vector<float> beta_data     = {0.1, 0.2, 0.3};
    std::vector<float> mean_data     = {-0.1804, -0.2875, -0.2249};
    std::vector<float> variance_data = {2.7914, 7.3424, 3.3287};

    auto input    = p.add_literal(migraphx::literal{input_shape, input_data});
    auto gamma    = p.add_literal(migraphx::literal{channel_shape, gamma_data});
    auto beta     = p.add_literal(migraphx::literal{channel_shape, beta_data});
    auto mean     = p.add_literal(migraphx::literal{channel_shape, mean_data});
    auto variance = p.add_literal(migraphx::literal{channel_shape, variance_data});

    p.add_instruction(migraphx::op::batch_norm_inference{1e-5}, input, gamma, beta, mean, variance);
    p.compile(migraphx::cpu::target{});
    auto result = p.eval({}).back();

    std::vector<float> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });

    std::vector<float> gold = {0.696301,  -0.199697, 0.522026,  -0.353299, -0.201094, 0.528289,
                               -0.116332, 0.165679,  0.307767,  -0.220435, -0.086407, 0.62634,
                               -0.335325, -0.185608, 0.272366,  0.383238,  0.0303421, 0.985936,
                               0.553709,  -0.346351, -0.190009, 0.51262,   1.23335,   0.216776};
    EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(batch_norm_1d_per_actv_test)
{
    // Per-activation batch norm on a rank-3 input {N, C, W}: statistics have
    // one entry per (channel, position) pair, i.e. shape {C, W} = {2, 4}.
    migraphx::program p;

    migraphx::shape input_shape{migraphx::shape::float_type, {2, 2, 4}};
    migraphx::shape stat_shape{migraphx::shape::float_type, {2, 4}};

    // Fixed input and per-activation statistics; gold values below were
    // computed from these with epsilon = 1e-6.
    std::vector<float> input_data = {0.3547,  0.477,   -1.8575, 0.663,   -0.1881, -0.5113,
                                     -0.1803, -0.5915, -0.1552, 0.9821,  1.827,   0.0558,
                                     -0.0417, -1.0693, 1.9948,  -0.7448};
    std::vector<float> gamma_data = {
        -0.3181, -0.3885, 1.655, 0.0704, -0.2565, -1.1761, -0.3751, 0.1057};
    std::vector<float> beta_data = {
        -1.2118, -2.1156, 0.0046, -0.1341, -0.2724, -1.0718, 0.5535, -0.889};
    std::vector<float> mean_data = {
        0.0997, 0.7295, -0.0153, 0.3594, -0.1149, -0.7903, 0.9073, -0.6681};
    std::vector<float> variance_data = {
        0.13, 0.1276, 6.7878, 0.1843, 0.0107, 0.1556, 2.3655, 0.0117};

    auto input    = p.add_literal(migraphx::literal{input_shape, input_data});
    auto gamma    = p.add_literal(migraphx::literal{stat_shape, gamma_data});
    auto beta     = p.add_literal(migraphx::literal{stat_shape, beta_data});
    auto mean     = p.add_literal(migraphx::literal{stat_shape, mean_data});
    auto variance = p.add_literal(migraphx::literal{stat_shape, variance_data});

    p.add_instruction(
        migraphx::op::batch_norm_inference{
            1e-6, 0.9, migraphx::op::batch_norm_inference::per_activation},
        input,
        gamma,
        beta,
        mean,
        variance);
    p.compile(migraphx::cpu::target{});
    auto result = p.eval({}).back();

    std::vector<float> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });

    std::vector<float> gold = {-1.43677,  -1.84098,  -1.16563, -0.0843136, -0.090896, -1.90364,
                               0.81875,   -0.81415,  -0.986915, -2.39032,  1.17489,   -0.183886,
                               -0.453904, -0.239955, 0.288275,  -0.963948};
    EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(batch_norm_inference_test)
{
migraphx::program p;
......@@ -760,6 +862,41 @@ TEST_CASE(batch_norm_inference_test)
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(batch_norm_3d_test)
{
    // Spatial-mode batch norm on a rank-5 input {N, C, D, H, W}: the scale,
    // bias, mean and variance tensors each hold one value per channel (C = 2).
    migraphx::program p;

    migraphx::shape input_shape{migraphx::shape::float_type, {2, 2, 2, 2, 2}};
    migraphx::shape channel_shape{migraphx::shape::float_type, {2}};

    // Fixed input and per-channel statistics; gold values below were computed
    // from these with the operator's default epsilon.
    std::vector<float> input_data = {-1.0833, 1.9681,  1.2075,  -0.723,  -0.4076, -0.8738, 0.5853,
                                     -0.5357, 1.734,   0.7904,  0.6953,  -0.468,  -0.425,  0.6895,
                                     0.0096,  0.4205,  -0.1749, 1.2821,  2.1453,  -0.8538, 1.0687,
                                     0.0906,  0.0714,  -1.3079, -0.6376, 1.3023,  0.945,   0.0927,
                                     -0.7421, -1.4341, -1.0309, 1.5153};
    std::vector<float> gamma_data    = {1.1, 1.3};
    std::vector<float> beta_data     = {0.1, 0.2};
    std::vector<float> mean_data     = {0.1537, 0.2161};
    std::vector<float> variance_data = {18.0805, 13.3906};

    auto input    = p.add_literal(migraphx::literal{input_shape, input_data});
    auto gamma    = p.add_literal(migraphx::literal{channel_shape, gamma_data});
    auto beta     = p.add_literal(migraphx::literal{channel_shape, beta_data});
    auto mean     = p.add_literal(migraphx::literal{channel_shape, mean_data});
    auto variance = p.add_literal(migraphx::literal{channel_shape, variance_data});

    p.add_instruction(migraphx::op::batch_norm_inference{}, input, gamma, beta, mean, variance);
    p.compile(migraphx::cpu::target{});
    auto result = p.eval({}).back();

    std::vector<float> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });

    std::vector<float> gold = {
        -0.220005, 0.569376,  0.372612,  -0.126798, -0.0452053, -0.165809, 0.211653,  -0.0783441,
        0.739245,  0.404024,  0.370239,  -0.0430317, -0.0277556, 0.368179, 0.126639,  0.272615,
        0.0149929, 0.391911,  0.615216,  -0.160635, 0.336706,   0.0836764, 0.0787094, -0.278108,
        -0.103283, 0.585881,  0.458947,  0.156161,  -0.140408,  -0.386246, -0.243006, 0.661551};
    EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(im2col_3x3_with_channels_identity_test)
{
std::size_t f[2] = {3, 3};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment