Commit 16d016a9 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

Added another check to a verify_onnx test. Dynamic and static input versions...

Added another check to a verify_onnx test.  Dynamic and static input versions with same attributes and inputs go through different code paths but should give same result.
parent 88a4a3ef
...@@ -6626,6 +6626,28 @@ def resize_downsample_f_dyn_test(): ...@@ -6626,6 +6626,28 @@ def resize_downsample_f_dyn_test():
return ([node], [X], [Y], [scale_tensor]) return ([node], [X], [Y], [scale_tensor])
@onnx_test()
def resize_downsample_f_ref_test():
# Same as resize_downsample_f_dyn_test but with static input
scales = np.array([1.0, 1.0, 0.601, 0.601], dtype=np.float32)
scale_tensor = helper.make_tensor(name='scales',
data_type=TensorProto.FLOAT,
dims=scales.shape,
vals=scales.flatten().astype(np.float32))
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [2, 1, 5, 9])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [])
node = onnx.helper.make_node('Resize',
inputs=['X', '', 'scales'],
outputs=['Y'],
coordinate_transformation_mode='asymmetric',
mode='nearest',
nearest_mode='floor')
return ([node], [X], [Y], [scale_tensor])
@onnx_test() @onnx_test()
def resize_upsample_f_dyn_test(): def resize_upsample_f_dyn_test():
scales = np.array([1.0, 1.0, 1.601, 1.601], dtype=np.float32) scales = np.array([1.0, 1.0, 1.601, 1.601], dtype=np.float32)
......
...@@ -1796,6 +1796,11 @@ TEST_CASE(resize_downsample_f_dyn_test) ...@@ -1796,6 +1796,11 @@ TEST_CASE(resize_downsample_f_dyn_test)
auto p = migraphx::parse_onnx("resize_downsample_f_dyn_test.onnx", options); auto p = migraphx::parse_onnx("resize_downsample_f_dyn_test.onnx", options);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
// A Resize op. with static input shape goes through a different code path
// but should give same result
auto reference_p = migraphx::parse_onnx("resize_downsample_f_ref_test.onnx", options);
reference_p.compile(migraphx::make_target("ref"));
migraphx::shape sx{migraphx::shape::float_type, {2, 1, 5, 9}}; migraphx::shape sx{migraphx::shape::float_type, {2, 1, 5, 9}};
std::vector<float> dx(sx.elements()); std::vector<float> dx(sx.elements());
std::iota(dx.begin(), dx.end(), 0.1f); std::iota(dx.begin(), dx.end(), 0.1f);
...@@ -1819,6 +1824,12 @@ TEST_CASE(resize_downsample_f_dyn_test) ...@@ -1819,6 +1824,12 @@ TEST_CASE(resize_downsample_f_dyn_test)
EXPECT(migraphx::verify::verify_range_with_tolerance(result_vector, EXPECT(migraphx::verify::verify_range_with_tolerance(result_vector,
migraphx::verify::expected{gold})); migraphx::verify::expected{gold}));
auto reference_result = reference_p.eval(pp).back();
std::vector<float> reference_vector;
reference_result.visit([&](auto output) { reference_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range_with_tolerance(result_vector,
migraphx::verify::expected{reference_vector}));
} }
TEST_CASE(resize_upsample_f_dyn_test) TEST_CASE(resize_upsample_f_dyn_test)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment