Commit ea87f25e authored by Paul's avatar Paul
Browse files

Add a test for a deeper network

parent 946dc37a
......@@ -133,7 +133,7 @@ struct convolution
struct pooling
{
std::string mode;
std::string mode = "average";
std::array<std::size_t, 2> padding = {{0, 0}};
std::array<std::size_t, 2> stride = {{1, 1}};
std::array<std::size_t, 2> lengths = {{1, 1}};
......
......@@ -41,17 +41,22 @@ int main(int argc, char const* argv[])
if(argc > 1)
{
std::string file = argv[1];
auto p = migraph::parse_onnx(file);
std::cout << p << std::endl;
auto x = run_cpu(file);
auto y = run_gpu(file);
visit_all(x, y)([](auto cpu, auto gpu) {
if(migraph::verify_range(cpu, gpu))
if(migraph::verify_range(cpu, gpu, 1))
{
std::cout << "Passed" << std::endl;
}
else
{
std::cout << "Not equal" << std::endl;
std::cout << "cpu:" << std::endl;
std::cout << cpu << std::endl;
std::cout << "gpu:" << std::endl;
std::cout << gpu << std::endl;
}
......
......@@ -43,6 +43,25 @@ std::future<typename std::result_of<Function()>::type> detach_async(Function&& f
struct auto_print
{
static void set_terminate_handler(const std::string& name)
{
static std::string pname;
pname = name;
std::set_terminate(+[] {
std::cout << "FAILED: " << pname << std::endl;
try
{
std::rethrow_exception(std::current_exception());
}
catch(const std::exception& e)
{
std::cout << " what(): " << e.what() << std::endl;
}
std::cout << std::endl;
for(auto&& handle : auto_print::handlers)
handle();
});
}
static std::array<std::function<void()>, 2> handlers;
int index;
template <class T>
......@@ -103,30 +122,13 @@ migraph::argument run_gpu()
return migraph::gpu::from_gpu(p.eval(m));
}
template <class V>
void verify_program()
void verify_args(const std::string& name, const migraph::argument& cpu_arg, const migraph::argument& gpu_arg)
{
std::set_terminate(+[] {
std::cout << "FAILED: " << migraph::get_type_name<V>() << std::endl;
try
{
std::rethrow_exception(std::current_exception());
}
catch(const std::exception& e)
{
std::cout << " what(): " << e.what() << std::endl;
}
std::cout << std::endl;
for(auto&& handle : auto_print::handlers)
handle();
});
auto cpu_arg_f = detach_async([] { return run_cpu<V>(); });
auto gpu_arg = run_gpu<V>();
visit_all(cpu_arg_f.get(), gpu_arg)([](auto cpu, auto gpu) {
visit_all(cpu_arg, gpu_arg)([&](auto cpu, auto gpu) {
if(not migraph::verify_range(cpu, gpu))
{
// TODO: Check for nans
std::cout << "FAILED: " << migraph::get_type_name<V>() << std::endl;
std::cout << "FAILED: " << name << std::endl;
// std::cout << cpu << std::endl;
// std::cout << gpu << std::endl;
if(migraph::range_zero(cpu))
......@@ -152,6 +154,15 @@ void verify_program()
<< gpu[gpu_nan_idx] << std::endl;
}
});
}
template <class V>
void verify_program()
{
auto_print::set_terminate_handler(migraph::get_type_name<V>());
auto cpu_arg_f = detach_async([] { return run_cpu<V>(); });
auto gpu_arg = run_gpu<V>();
verify_args(migraph::get_type_name<V>(), cpu_arg_f.get(), gpu_arg);
std::set_terminate(nullptr);
}
......@@ -364,6 +375,29 @@ struct test_batchnorm_inference
}
};
struct test_conv_bn_relu_pooling
{
migraph::program create_program() const
{
migraph::program p;
migraph::shape xs{migraph::shape::float_type, {1, 3, 224, 224}};
migraph::shape ws{migraph::shape::float_type, {64, 3, 7, 7}};
auto x = p.add_parameter("x", xs);
auto w = p.add_parameter("w", ws);
auto conv = p.add_instruction(migraph::convolution{{3, 3}, {2, 2}, {1, 1}}, x, w);
migraph::shape vars{migraph::shape::float_type, {64}};
auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 0)));
auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 1)));
auto scale = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2)));
auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3)));
auto bn = p.add_instruction(migraph::batch_norm_inference{}, conv, mean, variance, scale, bias);
auto relu = p.add_instruction(migraph::activation{"relu"}, bn);
p.add_instruction(migraph::pooling{"average", {1, 1}, {2, 2}, {3, 3}}, relu);
return p;
}
};
int main()
{
verify_program<test_add>();
......@@ -379,4 +413,5 @@ int main()
verify_program<test_transpose>();
verify_program<test_batchnorm_inference>();
verify_program<test_batchnorm_inference_2>();
verify_program<test_conv_bn_relu_pooling>();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment