"docs/source/tutorials/pruning_bert_glue.py.md5" did not exist on "3eca23d51936e7d81b3b0c5c6770ec38c538c072"
Commit e05b63f0 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Fix verify_onnx tests for updated if op trailing cases

Hand verified results before running test using random data generated from
the new protobuf files for this.

Everything is correctly functionally within tolerances.

Had to readjust input sizes to that of the .onnx file as one input is deliberately
made smaller (without the trailing 1) for theses tests to test the case correctly
parent 46383398
...@@ -506,10 +506,11 @@ TEST_CASE(if_then_trailing_one_shape_test) ...@@ -506,10 +506,11 @@ TEST_CASE(if_then_trailing_one_shape_test)
migraphx::program p = migraphx::parse_onnx("if_then_trailing_one_shape_test.onnx"); migraphx::program p = migraphx::parse_onnx("if_then_trailing_one_shape_test.onnx");
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
migraphx::shape s_data{migraphx::shape::float_type, {2, 1}}; migraphx::shape s_data{migraphx::shape::float_type, {2, 1}};
migraphx::shape s_data_x{migraphx::shape::float_type, {2}};
std::vector<float> data = {0.0625, 0.75}; std::vector<float> data = {0.0625, 0.75};
migraphx::parameter_map pp; migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_data, data.data()); pp["x"] = migraphx::argument(s_data_x, data.data());
pp["y"] = migraphx::argument(s_data, data.data()); pp["y"] = migraphx::argument(s_data, data.data());
auto result = p.eval(pp).back(); auto result = p.eval(pp).back();
...@@ -571,17 +572,18 @@ TEST_CASE(if_else_trailing_one_shape_test) ...@@ -571,17 +572,18 @@ TEST_CASE(if_else_trailing_one_shape_test)
migraphx::program p = migraphx::parse_onnx("if_else_trailing_one_shape_test.onnx"); migraphx::program p = migraphx::parse_onnx("if_else_trailing_one_shape_test.onnx");
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
migraphx::shape s_data{migraphx::shape::float_type, {2, 1}}; migraphx::shape s_data{migraphx::shape::float_type, {2, 1}};
migraphx::shape s_data_y{migraphx::shape::float_type, {2}};
std::vector<float> data = {0.0625, 0.75}; std::vector<float> data = {0.0625, 0.75};
migraphx::parameter_map pp; migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_data, data.data()); pp["x"] = migraphx::argument(s_data, data.data());
pp["y"] = migraphx::argument(s_data, data.data()); pp["y"] = migraphx::argument(s_data_y, data.data());
auto result = p.eval(pp).back(); auto result = p.eval(pp).back();
std::vector<float> result_vector; std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-0.0364609435, 0.475317657}; std::vector<float> gold = {0.002918556, 0.29198325};
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment