Commit 121b3130 authored by Aditya Atluri's avatar Aditya Atluri
Browse files

fixed batch norm inference for cpu

parent c9b32c9c
...@@ -103,6 +103,24 @@ struct not_computable ...@@ -103,6 +103,24 @@ struct not_computable
} }
}; };
struct batch_norm_inference
{
double epsilon = 1.0e-6;
std::string name() const { return "batch_norm_inference"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(5);
return inputs.front();
}
argument compute(context&, shape, std::vector<argument>) const
{
MIGRAPH_THROW("not computable");
}
};
struct convolution struct convolution
{ {
std::array<std::size_t, 2> padding = {{0, 0}}; std::array<std::size_t, 2> padding = {{0, 0}};
......
...@@ -20,46 +20,44 @@ T zero(const T&) ...@@ -20,46 +20,44 @@ T zero(const T&)
// cpu implemenataion of batch norm for inference // cpu implemenataion of batch norm for inference
// //
// inputs are: // inputs are:
// mini-batch mean, mini-batch variance, // args[0] -> input data buffer
// input data, output data, // args[1] -> mini batch mean
// total number of elements in input, // args[2] -> mini batch variance
// epsilon for denominator to normalize, // args[3] -> gamma
// gamma and beta // args[4] -> beta
//
// The equation to compute batch norm for inference is:
//
// output[i] = beta + gamma * (input[i] + mean) / sqrt(variance + epsilon)
// //
// the input data format should be nchw // the input data format should be nchw
// //
struct cpu_batch_norm_inference struct cpu_batch_norm_inference
{ {
batch_norm_inference op;
std::string name() const { return "cpu::batch_norm_inference"; } std::string name() const { return "cpu::batch_norm_inference"; }
argument compute(context&, std::vector<argument> args) const shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
{
// args[0] is output (y),
// args[1] is input (x),
// args[2] is number of input elements (m),
// args[3] is mini-batch mean (u),
// args[4] is mini-batch variance (s),
// args[5] is epsilon (e)
// args[6] is gamma (g)
// args[7] is beta (b)
auto output = args[0];
auto input = args[1];
auto num_input_elements = args[2];
auto mini_batch_mean = args[3];
auto mini_batch_variance = args[4];
auto epsilon = args[5];
auto gamma = args[6];
auto beta = args[7];
for(size_t i = 0; i < num_input_elements; i++) {
output[i] = gamma * \
((input[i] - mini_batch_mean) / (std::sqrt(mini_batch_variance + epsilon))) + \
beta;
}
return output; argument compute(context&, shape output_shape, std::vector<argument> args) const
{
argument output{output_shape};
double epsilon = op.epsilon;
auto input = args[0];
auto mini_batch_mean = args[1].at<float>();
auto mini_batch_variance = args[2].at<float>();
auto gamma = args[3].at<float>();
auto beta = args[4].at<float>();
visit_all(output, input) ([&](auto result, auto buffer) {
std::transform(buffer.begin(), buffer.end(), result.begin(), [&](auto x) {
return gamma * (x - mini_batch_mean) / std::sqrt(mini_batch_variance + epsilon) + beta;
});
});
return output;
} }
}; };
...@@ -517,6 +515,7 @@ struct cpu_apply ...@@ -517,6 +515,7 @@ struct cpu_apply
{ {
apply_map["convolution"] = extend_op<cpu_convolution, convolution>(); apply_map["convolution"] = extend_op<cpu_convolution, convolution>();
apply_map["gemm"] = extend_op<cpu_gemm, gemm>(); apply_map["gemm"] = extend_op<cpu_gemm, gemm>();
apply_map["batch_norm_inference"] = extend_op<cpu_batch_norm_inference, batch_norm_inference>();
apply_map["reshape"] = extend_op<cpu_reshape, reshape>(); apply_map["reshape"] = extend_op<cpu_reshape, reshape>();
apply_map["contiguous"] = extend_op<cpu_contiguous, contiguous>(); apply_map["contiguous"] = extend_op<cpu_contiguous, contiguous>();
apply_map["transpose"] = extend_op<cpu_transpose, transpose>(); apply_map["transpose"] = extend_op<cpu_transpose, transpose>();
......
...@@ -10,11 +10,18 @@ void batch_norm_inference_test() ...@@ -10,11 +10,18 @@ void batch_norm_inference_test()
{ {
migraph::program p; migraph::program p;
migraph::shape s{migraph::shape::float_type, {4}}; migraph::shape s{migraph::shape::float_type, {4}};
auto x = p.add_literal(migraph::literal{s, {1, 2, 3, 4}}; auto x = p.add_literal(migraph::literal{s, {1, 2, 3, 4}});
auto y = p.add_literal(migraph::literal{s, {0, 0, 0, 0}}; auto gamma = p.add_literal(migraph::literal{s, {1}});
p.add_instruction(migraph::cpu_batch_norm_inference, y, x, 4, 0, 0.5, 0.5, 1, 0); auto beta = p.add_literal(migraph::literal{s, {0}});
auto mean = p.add_literal(migraph::literal{s, {0}});
auto variance = p.add_literal(migraph::literal{s, {1}});
p.add_instruction(migraph::batch_norm_inference{}, x, mean, variance, gamma, beta);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> result_vector(4);
result.visit([&](auto output) {result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = { 1 / (1 + 1.0e-6), 2 / (1 + 1.0e-6), 3 / (1 + 1.0e-6), 4 / (1 + 1.0e-6)};
EXPECT(test::verify_range(result_vector, gold));
} }
void exp_test() void exp_test()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment