"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "52ed1fc3c9b7f5138c5eabb2d04bd1aa5365fd7f"
Commit e384b83f authored by wsttiger's avatar wsttiger
Browse files

Added test but not working

parent 8ce572f8
......@@ -215,6 +215,82 @@ struct test_transpose
}
};
struct test_batchnorm_inference
{
const size_t width = 3;
const size_t height = 3;
const size_t channels = 3;
const size_t batches = 4;
migraph::program create_program() const
{
migraph::program p;
migraph::shape s{migraph::shape::float_type, {batches, channels, height, width}};
migraph::shape vars{migraph::shape::float_type, {channels}};
auto x = p.add_parameter("x", s);
auto mean = p.add_parameter("mean", vars);
auto variance = p.add_parameter("variance", vars);
auto scale = p.add_parameter("scale", vars);
auto bias = p.add_parameter("bias", vars);
p.add_instruction(migraph::batch_norm_inference{}, x, mean, variance, scale, bias);
return p;
}
migraph::program::parameter_map create_params() const
{
migraph::program::parameter_map m;
migraph::shape s{migraph::shape::float_type, {batches, channels, height, width}};
migraph::shape vars{migraph::shape::float_type, {channels}};
m["x"] = migraph::generate_argument(s);
m["mean"] = migraph::generate_argument(vars);
m["variance"] = migraph::generate_argument(vars);
m["scale"] = migraph::generate_argument(vars);
m["bias"] = migraph::generate_argument(vars);
return m;
}
};
void batch_norm_inference_test()
{
migraph::program p;
const size_t width = 2, height = 2, channels = 4, batches = 2;
const float x_val = 8.0f, mean_val = 2.0f, variance_val = 4.0f, scale_val = 2.0f,
bias_val = 1.0f;
const float output_val = scale_val * (x_val - mean_val) / (std::sqrt(variance_val)) + bias_val;
migraph::shape s{migraph::shape::float_type, {batches, channels, height, width}};
migraph::shape vars{migraph::shape::float_type, {channels}};
std::vector<float> x_data(width * height * channels * batches);
std::vector<float> scale_data(channels);
std::vector<float> bias_data(channels);
std::vector<float> mean_data(channels);
std::vector<float> variance_data(channels);
std::fill(x_data.begin(), x_data.end(), x_val);
std::fill(mean_data.begin(), mean_data.end(), mean_val);
std::fill(variance_data.begin(), variance_data.end(), variance_val);
std::fill(scale_data.begin(), scale_data.end(), scale_val);
std::fill(bias_data.begin(), bias_data.end(), bias_val);
auto x = p.add_literal(migraph::literal{s, x_data});
auto scale = p.add_literal(migraph::literal{vars, scale_data});
auto bias = p.add_literal(migraph::literal{vars, bias_data});
auto mean = p.add_literal(migraph::literal{vars, mean_data});
auto variance = p.add_literal(migraph::literal{vars, variance_data});
p.add_instruction(migraph::batch_norm_inference{}, x, mean, variance, scale, bias);
p.compile(migraph::gpu::target{});
auto result = p.eval({});
std::vector<float> result_vector(width * height * channels * batches);
std::vector<float> gold(width * height * channels * batches);
std::fill(gold.begin(), gold.end(), output_val);
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
EXPECT(test::verify_range(result_vector, gold));
}
int main()
{
verify_program<test_add>();
......@@ -224,4 +300,6 @@ int main()
verify_program<test_gemm>();
verify_program<test_contiguous>();
verify_program<test_transpose>();
// verify_program<test_batchnorm_inference>();
// batch_norm_inference_test();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment