Commit bdf6c835 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Add unit tests for empty constants as input to if branches

- gen_onnx.py changes for onnx output of empty const input branches (seen in resnext50)
- updated onnx_test.cpp to validate parsing of input.
- new onnx files generated from onnx tests
parent d8ee02b9
......@@ -2314,6 +2314,114 @@ def if_then_trailing_one_shape_test():
return ([node], [x, y], [res], [cond_tensor, xt_tensor, yt_tensor])
@onnx_test
def if_else_empty_constant_test():
empty_val = np.array([]).astype(np.int64)
empty_ts = helper.make_tensor(name='empty_tensor',
data_type=TensorProto.INT64,
dims=empty_val.shape,
vals=empty_val.flatten().astype(int))
shape_const = helper.make_node(
'Constant',
inputs=[],
outputs=['shape_const'],
value=empty_ts,
)
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.INT64, [2, 1])
else_out = helper.make_tensor_value_info('shape_const', TensorProto.INT64, [])
then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.INT64,
[2, 1])
yt = np.random.randn(2, 1).astype(np.int64)
yt_tensor = helper.make_tensor(name='yt',
data_type=TensorProto.INT64,
dims=yt.shape,
vals=yt.flatten().astype(np.int64))
then_mul_node = onnx.helper.make_node('Mul',
inputs=['y', 'yt'],
outputs=['then_out'])
else_body = onnx.helper.make_graph([shape_const], 'else_body', [],
[else_out])
then_body = onnx.helper.make_graph([then_mul_node], 'then_body', [],
[then_out])
cond = np.array([0]).astype(np.bool)
cond_tensor = helper.make_tensor(name="cond",
data_type=TensorProto.BOOL,
dims=cond.shape,
vals=cond.astype(bool))
res = onnx.helper.make_tensor_value_info('res', TensorProto.INT64, [2, 1])
node = onnx.helper.make_node('If',
inputs=['cond'],
outputs=['res'],
then_branch=then_body,
else_branch=else_body)
return ([node], [y], [res], [cond_tensor, empty_ts, yt_tensor])
@onnx_test
def if_then_empty_constant_test():
empty_val = np.array([]).astype(np.int64)
empty_ts = helper.make_tensor(name='empty_tensor',
data_type=TensorProto.INT64,
dims=empty_val.shape,
vals=empty_val.flatten().astype(int))
shape_const = helper.make_node(
'Constant',
inputs=[],
outputs=['shape_const'],
value=empty_ts,
)
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.INT64, [2, 1])
then_out = helper.make_tensor_value_info('shape_const', TensorProto.INT64, [])
else_out = onnx.helper.make_tensor_value_info('else_out',
onnx.TensorProto.INT64,
[2, 1])
yt = np.random.randn(2, 1).astype(np.int64)
yt_tensor = helper.make_tensor(name='yt',
data_type=TensorProto.INT64,
dims=yt.shape,
vals=yt.flatten().astype(np.int64))
else_mul_node = onnx.helper.make_node('Mul',
inputs=['y', 'yt'],
outputs=['else_out'])
then_body = onnx.helper.make_graph([shape_const], 'then_body', [],
[then_out])
else_body = onnx.helper.make_graph([else_mul_node], 'else_body', [],
[else_out])
cond = np.array([1]).astype(np.bool)
cond_tensor = helper.make_tensor(name="cond",
data_type=TensorProto.BOOL,
dims=cond.shape,
vals=cond.astype(bool))
res = onnx.helper.make_tensor_value_info('res', TensorProto.INT64, [2, 1])
node = onnx.helper.make_node('If',
inputs=['cond'],
outputs=['res'],
then_branch=then_body,
else_branch=else_body)
return ([node], [y], [res], [cond_tensor, empty_ts, yt_tensor])
@onnx_test
def if_literal_test():
then_out = onnx.helper.make_tensor_value_info('then_out',
......
......@@ -2586,6 +2586,64 @@ TEST_CASE(if_then_trailing_one_shape_test)
EXPECT(p == prog);
}
TEST_CASE(if_then_empty_constant_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}};
auto cond = mm->add_literal(migraphx::literal(sc, {1}));
migraphx::shape s{migraphx::shape::int64_type, {2, 1}};
std::vector<int> rand = {-1, 0};
mm->add_literal(migraphx::shape::int64_type);
auto l2 = mm->add_literal(s, rand);
auto y = mm->add_parameter("y", s);
auto* then_mod = p.create_module("If_4_if");
then_mod->add_literal(migraphx::shape::int64_type);
auto outline = then_mod->add_outline(s);
then_mod->add_return({outline});
auto* else_mod = p.create_module("If_4_else");
auto re = else_mod->add_instruction(migraphx::make_op("mul"), y, l2);
else_mod->add_return({re});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
auto prog = migraphx::parse_onnx("if_then_empty_constant_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(if_else_empty_constant_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}};
auto cond = mm->add_literal(migraphx::literal(sc, {0}));
migraphx::shape s{migraphx::shape::int64_type, {2, 1}};
std::vector<int> rand = {1, -2};
mm->add_literal(migraphx::shape::int64_type);
auto l2 = mm->add_literal(s, rand);
auto y = mm->add_parameter("y", s);
auto* then_mod = p.create_module("If_4_if");
auto rt = then_mod->add_instruction(migraphx::make_op("mul"), y, l2);
then_mod->add_return({rt});
auto* else_mod = p.create_module("If_4_else");
else_mod->add_literal(migraphx::shape::int64_type);
auto outline = else_mod->add_outline(s);
else_mod->add_return({outline});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
auto prog = migraphx::parse_onnx("if_else_empty_constant_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(if_then_test)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment