Commit cb7db7a9 authored by wsttiger's avatar wsttiger
Browse files

Batch norm tests for MIOpen and CPU added and working

parent e384b83f
...@@ -99,6 +99,11 @@ struct instruction ...@@ -99,6 +99,11 @@ struct instruction
}); });
} }
shape get_shape() const
{
return result;
}
friend bool operator==(instruction_ref ref, const instruction& i) { return i == ref; } friend bool operator==(instruction_ref ref, const instruction& i) { return i == ref; }
friend bool operator!=(const instruction& i, instruction_ref ref) { return !(i == ref); } friend bool operator!=(const instruction& i, instruction_ref ref) { return !(i == ref); }
......
...@@ -388,13 +388,23 @@ struct miopen_apply ...@@ -388,13 +388,23 @@ struct miopen_apply
{ {
auto&& op = any_cast<batch_norm_inference>(ins->op); auto&& op = any_cast<batch_norm_inference>(ins->op);
auto output = insert_allocation(ins, ins->result); auto output = insert_allocation(ins, ins->result);
prog->replace_instruction(ins, shape old_shape = ins->arguments.at(1)->get_shape();
std::vector<int64_t> new_shape{1,static_cast<int64_t>(old_shape.elements()),1,1};
auto arg1 = prog->insert_instruction(ins, migraph::reshape{new_shape},
ins->arguments.at(1));
auto arg2 = prog->insert_instruction(ins, migraph::reshape{new_shape},
ins->arguments.at(2));
auto arg3 = prog->insert_instruction(ins, migraph::reshape{new_shape},
ins->arguments.at(3));
auto arg4 = prog->insert_instruction(ins, migraph::reshape{new_shape},
ins->arguments.at(4));
prog->replace_instruction(ins,
miopen_batch_norm_inference{op}, miopen_batch_norm_inference{op},
ins->arguments.at(0), ins->arguments.at(0),
ins->arguments.at(1), arg1,
ins->arguments.at(2), arg2,
ins->arguments.at(3), arg3,
ins->arguments.at(4), arg4,
output); output);
} }
}; };
......
...@@ -227,7 +227,7 @@ struct test_batchnorm_inference ...@@ -227,7 +227,7 @@ struct test_batchnorm_inference
migraph::program p; migraph::program p;
migraph::shape s{migraph::shape::float_type, {batches, channels, height, width}}; migraph::shape s{migraph::shape::float_type, {batches, channels, height, width}};
migraph::shape vars{migraph::shape::float_type, {channels}}; migraph::shape vars{migraph::shape::float_type, {1,channels,1,1}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto mean = p.add_parameter("mean", vars); auto mean = p.add_parameter("mean", vars);
auto variance = p.add_parameter("variance", vars); auto variance = p.add_parameter("variance", vars);
...@@ -260,7 +260,7 @@ void batch_norm_inference_test() ...@@ -260,7 +260,7 @@ void batch_norm_inference_test()
const float output_val = scale_val * (x_val - mean_val) / (std::sqrt(variance_val)) + bias_val; const float output_val = scale_val * (x_val - mean_val) / (std::sqrt(variance_val)) + bias_val;
migraph::shape s{migraph::shape::float_type, {batches, channels, height, width}}; migraph::shape s{migraph::shape::float_type, {batches, channels, height, width}};
migraph::shape vars{migraph::shape::float_type, {channels}}; migraph::shape vars{migraph::shape::float_type, {1,channels,1,1}};
std::vector<float> x_data(width * height * channels * batches); std::vector<float> x_data(width * height * channels * batches);
std::vector<float> scale_data(channels); std::vector<float> scale_data(channels);
std::vector<float> bias_data(channels); std::vector<float> bias_data(channels);
...@@ -281,7 +281,10 @@ void batch_norm_inference_test() ...@@ -281,7 +281,10 @@ void batch_norm_inference_test()
p.add_instruction(migraph::batch_norm_inference{}, x, mean, variance, scale, bias); p.add_instruction(migraph::batch_norm_inference{}, x, mean, variance, scale, bias);
p.compile(migraph::gpu::target{}); p.compile(migraph::gpu::target{});
auto result = p.eval({});
migraph::program::parameter_map m;
m["output"] = migraph::gpu::to_gpu(migraph::generate_argument(p.get_parameter_shape("output")));
auto result = migraph::gpu::from_gpu(p.eval(m));
std::vector<float> result_vector(width * height * channels * batches); std::vector<float> result_vector(width * height * channels * batches);
std::vector<float> gold(width * height * channels * batches); std::vector<float> gold(width * height * channels * batches);
...@@ -300,6 +303,6 @@ int main() ...@@ -300,6 +303,6 @@ int main()
verify_program<test_gemm>(); verify_program<test_gemm>();
verify_program<test_contiguous>(); verify_program<test_contiguous>();
verify_program<test_transpose>(); verify_program<test_transpose>();
// verify_program<test_batchnorm_inference>(); verify_program<test_batchnorm_inference>();
// batch_norm_inference_test(); batch_norm_inference_test();
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment