Commit 58496149 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Add another unit test for incompatible shapes

Add the testcase and onnx file generated to handle the case of two output
shapes that vary in rank by one, with a trailing 1 but sub lengths are not equivalent
parent c3189eaa
......@@ -2420,6 +2420,59 @@ def if_then_else_incompatible_shape_test():
return ([node], [x, y], [res], [cond_tensor, xt_tensor, yt_tensor])
@onnx_test
def if_then_else_incompatible_shape_test2():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3, 1])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 2])
then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT, [2, 3, 1])
else_out = onnx.helper.make_tensor_value_info('else_out',
onnx.TensorProto.FLOAT,
[2, 2])
xt = np.ones((2, 3, 1)).astype(np.float)
xt_tensor = helper.make_tensor(name='xt',
data_type=TensorProto.FLOAT,
dims=xt.shape,
vals=xt.flatten().astype(np.float32))
yt = np.random.randn(2, 2).astype(np.float)
yt_tensor = helper.make_tensor(name='yt',
data_type=TensorProto.FLOAT,
dims=yt.shape,
vals=yt.flatten().astype(np.float32))
then_add_node = onnx.helper.make_node('Add',
inputs=['x', 'xt'],
outputs=['then_out'])
else_mul_node = onnx.helper.make_node('Mul',
inputs=['y', 'yt'],
outputs=['else_out'])
then_body = onnx.helper.make_graph([then_add_node], 'then_body', [],
[then_out])
else_body = onnx.helper.make_graph([else_mul_node], 'else_body', [],
[else_out])
cond = np.array([1]).astype(np.bool)
cond_tensor = helper.make_tensor(name="cond",
data_type=TensorProto.BOOL,
dims=cond.shape,
vals=cond.astype(bool))
res = onnx.helper.make_tensor_value_info('res', TensorProto.FLOAT, [])
node = onnx.helper.make_node('If',
inputs=['cond'],
outputs=['res'],
then_branch=then_body,
else_branch=else_body)
return ([node], [x, y], [res], [cond_tensor, xt_tensor, yt_tensor])
@onnx_test
def if_then_else_diff_type_test():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.INT64, [2])
......
......@@ -2655,6 +2655,12 @@ TEST_CASE(if_then_else_incompatible_shape)
test::throws([&] { migraphx::parse_onnx("if_then_else_incompatible_shape_test.onnx"); }));
}
TEST_CASE(if_then_else_incompatible_shape2)
{
EXPECT(
test::throws([&] { migraphx::parse_onnx("if_then_else_incompatible_shape_test2.onnx"); }));
}
TEST_CASE(if_else_then_diff_types)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("if_then_else_diff_type_test.onnx"); }));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment