Commit 48a85620 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Add additional test for different output shapes for each branch

This test should error out as, having different output shapes for each
branch with one non empty is invalid.

Changes to gen_onnx.py as well as generated onnx file provided
parent 58496149
...@@ -2473,6 +2473,68 @@ def if_then_else_incompatible_shape_test2(): ...@@ -2473,6 +2473,68 @@ def if_then_else_incompatible_shape_test2():
return ([node], [x, y], [res], [cond_tensor, xt_tensor, yt_tensor]) return ([node], [x, y], [res], [cond_tensor, xt_tensor, yt_tensor])
@onnx_test
def if_then_else_incompatible_output_shapes_test():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3, 1])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 2])
then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT, [2, 3, 1])
else_out = onnx.helper.make_tensor_value_info('else_out',
onnx.TensorProto.FLOAT,
[2, 2])
else_out2 = onnx.helper.make_tensor_value_info('else_out2',
onnx.TensorProto.FLOAT,
[2, 2])
xt = np.ones((2, 3, 1)).astype(np.float)
xt_tensor = helper.make_tensor(name='xt',
data_type=TensorProto.FLOAT,
dims=xt.shape,
vals=xt.flatten().astype(np.float32))
yt = np.random.randn(2, 2).astype(np.float)
yt_tensor = helper.make_tensor(name='yt',
data_type=TensorProto.FLOAT,
dims=yt.shape,
vals=yt.flatten().astype(np.float32))
then_add_node = onnx.helper.make_node('Add',
inputs=['x', 'xt'],
outputs=['then_out'])
else_mul_node = onnx.helper.make_node('Mul',
inputs=['y', 'yt'],
outputs=['else_out'])
else_sub_node = onnx.helper.make_node('Sub',
inputs=['y', 'yt'],
outputs=['else_out2'])
then_body = onnx.helper.make_graph([then_add_node], 'then_body', [],
[then_out])
else_body = onnx.helper.make_graph([else_mul_node, else_sub_node], 'else_body', [],
[else_out, else_out2])
cond = np.array([1]).astype(np.bool)
cond_tensor = helper.make_tensor(name="cond",
data_type=TensorProto.BOOL,
dims=cond.shape,
vals=cond.astype(bool))
res = onnx.helper.make_tensor_value_info('res', TensorProto.FLOAT, [])
node = onnx.helper.make_node('If',
inputs=['cond'],
outputs=['res'],
then_branch=then_body,
else_branch=else_body)
return ([node], [x, y], [res], [cond_tensor, xt_tensor, yt_tensor])
@onnx_test @onnx_test
def if_then_else_diff_type_test(): def if_then_else_diff_type_test():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.INT64, [2]) x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.INT64, [2])
......
...@@ -2671,6 +2671,12 @@ TEST_CASE(if_else_then_both_empty) ...@@ -2671,6 +2671,12 @@ TEST_CASE(if_else_then_both_empty)
EXPECT(test::throws([&] { migraphx::parse_onnx("if_then_else_both_empty_test.onnx"); })); EXPECT(test::throws([&] { migraphx::parse_onnx("if_then_else_both_empty_test.onnx"); }));
} }
TEST_CASE(if_else_then_incompatible_output_shape)
{
EXPECT(
test::throws([&] { migraphx::parse_onnx("if_then_else_incompatible_output_shape.onnx"); }));
}
TEST_CASE(if_then_test) TEST_CASE(if_then_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment