#include #include #include "test.hpp" TEST_CASE(load_and_run) { auto p = migraphx::parse_onnx("conv_relu_maxpool_test.onnx"); auto shapes_before = p.get_output_shapes(); p.compile(migraphx::target("cpu")); auto shapes_after = p.get_output_shapes(); CHECK(shapes_before.size() == 1); CHECK(shapes_before.size() == shapes_after.size()); CHECK(bool{shapes_before.front() == shapes_after.front()}); migraphx::program_parameters pp; auto param_shapes = p.get_parameter_shapes(); for(auto&& name : param_shapes.names()) { pp.add(name, migraphx::argument::generate(param_shapes[name])); } auto outputs = p.eval(pp); CHECK(shapes_before.size() == outputs.size()); CHECK(bool{shapes_before.front() == outputs.front().get_shape()}); } TEST_CASE(load_and_run_user_input_shape) { migraphx::onnx_options options; options.set_input_parameter_shape("0", {2, 3, 64, 64}); auto p = migraphx::parse_onnx("conv_relu_maxpool_test.onnx", options); auto shapes_before = p.get_output_shapes(); p.compile(migraphx::target("cpu")); auto shapes_after = p.get_output_shapes(); CHECK(shapes_before.size() == 1); CHECK(shapes_before.size() == shapes_after.size()); CHECK(bool{shapes_before.front() == shapes_after.front()}); migraphx::program_parameters pp; auto param_shapes = p.get_parameter_shapes(); for(auto&& name : param_shapes.names()) { pp.add(name, migraphx::argument::generate(param_shapes[name])); } auto outputs = p.eval(pp); CHECK(shapes_before.size() == outputs.size()); CHECK(bool{shapes_before.front() == outputs.front().get_shape()}); } TEST_CASE(zero_parameter) { auto p = migraphx::parse_onnx("constant_fill_test.onnx"); auto shapes_before = p.get_output_shapes(); p.compile(migraphx::target("cpu")); auto shapes_after = p.get_output_shapes(); CHECK(shapes_before.size() == 1); CHECK(shapes_before.size() == shapes_after.size()); CHECK(bool{shapes_before.front() == shapes_after.front()}); migraphx::program_parameters pp; auto param_shapes = p.get_parameter_shapes(); for(auto&& name : param_shapes.names()) { pp.add(name, migraphx::argument::generate(param_shapes[name])); } auto outputs = p.eval(pp); CHECK(shapes_before.size() == outputs.size()); CHECK(bool{shapes_before.front() == outputs.front().get_shape()}); } TEST_CASE(set_scalar_parameter) { auto p1 = migraphx::parse_onnx("add_bcast_test.onnx"); migraphx::shape s1(migraphx_shape_float_type, {3, 4}); auto param_shapes = p1.get_parameter_shapes(); auto s1_orig = param_shapes["1"]; CHECK(bool{s1 == s1_orig}); migraphx::onnx_options option; option.set_input_parameter_shape("1", {}); auto p2 = migraphx::parse_onnx("add_bcast_test.onnx", option); migraphx::shape s_scalar(migraphx_shape_float_type); auto param_shapes_1 = p2.get_parameter_shapes(); auto s_scalar_after = param_shapes_1["1"]; CHECK(bool{s_scalar == s_scalar_after}); } int main(int argc, const char* argv[]) { test::run(argc, argv); }