"mmdet3d/datasets/vscode:/vscode.git/clone" did not exist on "6aab10da2f753f64fbac229e50134702b2808cce"
Commit 2ae4c715 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Fix parse_if to handle multi output constant branches

- Initial fix to handle scalars on input for empty constant values
- Using scalar, multibroadcast, contiguous
- Fixed appropriate unit tests for simple single output constants
- Added unit tests for multi if outputs.

- TODO: extend multibroadcast to handle scalar inputs directly so the
  explicit scalar op becomes unnecessary
parent a27808b3
......@@ -127,10 +127,21 @@ struct parse_if : op_parser<parse_if>
throw_shapes();
}
auto handle_empty_branch = [](module_ref& mdl, const shape& out_shape) {
shape gen_shape(out_shape.type(), out_shape.lens(), out_shape.strides());
auto literal_ins = mdl->add_literal(gen_shape, gen_shape.lens());
mdl->replace_return({literal_ins});
auto handle_empty_branch = [](module_ref& mdl, int index, const shape& out_shape) {
shape gen_shape(shape(out_shape.type(), {1}, {0}));
auto literal_ins =
mdl->insert_literal(std::prev(mdl->end()), literal(gen_shape, {0}));
auto unsqueeze_ins = mdl->insert_instruction(
std::prev(mdl->end()),
make_op("scalar", {{"scalar_bcst_dims", out_shape.lens()}}),
literal_ins);
auto broad_ins = mdl->insert_instruction(
std::prev(mdl->end()),
make_op("multibroadcast", {{"out_lens", out_shape.lens()}}),
unsqueeze_ins);
auto contig_out = mdl->insert_instruction(
std::prev(mdl->end()), make_op("contiguous"), broad_ins);
mdl->replace_instruction(std::prev(mdl->end())->inputs().at(index), contig_out);
return out_shape.lens();
};
......@@ -138,11 +149,11 @@ struct parse_if : op_parser<parse_if>
// need to update the then_shape before we do further checks
if(then_lens.empty())
{
then_lens = handle_empty_branch(then_mdl, else_out_shape);
then_lens = handle_empty_branch(then_mdl, i, else_out_shape);
}
else if(else_lens.empty())
{
else_lens = handle_empty_branch(else_mdl, then_out_shape);
else_lens = handle_empty_branch(else_mdl, i, then_out_shape);
}
// check equivalent length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn)
......
......@@ -2844,6 +2844,77 @@ def if_else_empty_constant_test():
return ([node], [y], [res], [cond_tensor, empty_ts, yt_tensor])
@onnx_test
def if_else_empty_constant_multi_output_test():
    # Builds an If node whose else-branch returns two empty Constant outputs
    # while the then-branch returns two [2, 1] tensors. Exercises parse_if's
    # handling of multi-output empty-constant branches.
    empty_val = np.array([]).astype(np.int64)
    empty_ts = helper.make_tensor(name='empty_tensor',
                                  data_type=TensorProto.INT64,
                                  dims=empty_val.shape,
                                  vals=empty_val.flatten().astype(np.int64))
    shape_const = helper.make_node(
        'Constant',
        inputs=[],
        outputs=['shape_const'],
        value=empty_ts,
    )
    shape_const2 = helper.make_node(
        'Constant',
        inputs=[],
        outputs=['shape_const2'],
        value=empty_ts,
    )
    y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.INT64, [2, 1])
    # Else outputs are rank-0/empty; parse_if must broadcast them up to the
    # then-branch output shapes.
    else_out = helper.make_tensor_value_info('shape_const', TensorProto.INT64,
                                             [])
    else_out2 = helper.make_tensor_value_info('shape_const2',
                                              TensorProto.INT64, [])
    then_out = onnx.helper.make_tensor_value_info('then_out',
                                                  onnx.TensorProto.INT64,
                                                  [2, 1])
    then_out2 = onnx.helper.make_tensor_value_info('then_out2',
                                                   onnx.TensorProto.INT64,
                                                   [2, 1])
    yt = np.random.randn(2, 1).astype(np.int64)
    yt_tensor = helper.make_tensor(name='yt',
                                   data_type=TensorProto.INT64,
                                   dims=yt.shape,
                                   vals=yt.flatten().astype(np.int64))
    then_mul_node = onnx.helper.make_node('Mul',
                                          inputs=['y', 'yt'],
                                          outputs=['then_out'])
    then_sub_node = onnx.helper.make_node('Sub',
                                          inputs=['y', 'yt'],
                                          outputs=['then_out2'])
    else_body = onnx.helper.make_graph([shape_const, shape_const2],
                                       'else_body', [], [else_out, else_out2])
    then_body = onnx.helper.make_graph([then_mul_node, then_sub_node],
                                       'then_body', [], [then_out, then_out2])
    # np.bool was deprecated in NumPy 1.20 and removed in 1.24; the builtin
    # bool is the documented replacement.
    cond = np.array([0]).astype(bool)
    cond_tensor = helper.make_tensor(name="cond",
                                     data_type=TensorProto.BOOL,
                                     dims=cond.shape,
                                     vals=cond.astype(bool))
    res = onnx.helper.make_tensor_value_info('res', TensorProto.INT64, [2, 1])
    res2 = onnx.helper.make_tensor_value_info('res2', TensorProto.INT64,
                                              [2, 1])
    node = onnx.helper.make_node('If',
                                 inputs=['cond'],
                                 outputs=['res', 'res2'],
                                 then_branch=then_body,
                                 else_branch=else_body)
    return ([node], [y], [res, res2], [cond_tensor, empty_ts, yt_tensor])
@onnx_test
def if_then_empty_constant_test():
......@@ -2899,6 +2970,78 @@ def if_then_empty_constant_test():
return ([node], [y], [res], [cond_tensor, empty_ts, yt_tensor])
@onnx_test
def if_then_empty_constant_multi_output_test():
    # Mirror of if_else_empty_constant_multi_output_test with the branches
    # swapped: the then-branch returns two empty Constant outputs while the
    # else-branch returns two [2, 1] tensors.
    empty_val = np.array([]).astype(np.int64)
    empty_ts = helper.make_tensor(name='empty_tensor',
                                  data_type=TensorProto.INT64,
                                  dims=empty_val.shape,
                                  vals=empty_val.flatten().astype(np.int64))
    shape_const = helper.make_node(
        'Constant',
        inputs=[],
        outputs=['shape_const'],
        value=empty_ts,
    )
    shape_const2 = helper.make_node(
        'Constant',
        inputs=[],
        outputs=['shape_const2'],
        value=empty_ts,
    )
    y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.INT64, [2, 1])
    # Then outputs are rank-0/empty; parse_if must broadcast them up to the
    # else-branch output shapes.
    then_out = helper.make_tensor_value_info('shape_const', TensorProto.INT64,
                                             [])
    then_out2 = helper.make_tensor_value_info('shape_const2',
                                              TensorProto.INT64, [])
    else_out = onnx.helper.make_tensor_value_info('else_out',
                                                  onnx.TensorProto.INT64,
                                                  [2, 1])
    else_out2 = onnx.helper.make_tensor_value_info('else_out2',
                                                   onnx.TensorProto.INT64,
                                                   [2, 1])
    yt = np.random.randn(2, 1).astype(np.int64)
    yt_tensor = helper.make_tensor(name='yt',
                                   data_type=TensorProto.INT64,
                                   dims=yt.shape,
                                   vals=yt.flatten().astype(np.int64))
    else_mul_node = onnx.helper.make_node('Mul',
                                          inputs=['y', 'yt'],
                                          outputs=['else_out'])
    else_sub_node = onnx.helper.make_node('Sub',
                                          inputs=['y', 'yt'],
                                          outputs=['else_out2'])
    else_body = onnx.helper.make_graph([else_mul_node, else_sub_node],
                                       'else_body', [], [else_out, else_out2])
    then_body = onnx.helper.make_graph([shape_const, shape_const2],
                                       'then_body', [], [then_out, then_out2])
    # np.bool was deprecated in NumPy 1.20 and removed in 1.24; the builtin
    # bool is the documented replacement.
    cond = np.array([1]).astype(bool)
    cond_tensor = helper.make_tensor(name="cond",
                                     data_type=TensorProto.BOOL,
                                     dims=cond.shape,
                                     vals=cond.astype(bool))
    res = onnx.helper.make_tensor_value_info('res', TensorProto.INT64, [2, 1])
    res2 = onnx.helper.make_tensor_value_info('res2', TensorProto.INT64,
                                              [2, 1])
    node = onnx.helper.make_node('If',
                                 inputs=['cond'],
                                 outputs=['res', 'res2'],
                                 then_branch=then_body,
                                 else_branch=else_body)
    return ([node], [y], [res, res2], [cond_tensor, empty_ts, yt_tensor])
@onnx_test
def if_literal_test():
then_out = onnx.helper.make_tensor_value_info('then_out',
......
......@@ -2599,11 +2599,24 @@ TEST_CASE(if_then_empty_constant_test)
auto y = mm->add_parameter("y", s);
auto* then_mod = p.create_module("If_4_if");
then_mod->add_literal(migraphx::shape::int64_type);
migraphx::shape gen_s{s.type(), s.lens(), s.strides()};
auto then_lit = then_mod->add_literal(gen_s, gen_s.lens());
auto then_lit = then_mod->add_literal(migraphx::shape::int64_type);
then_mod->add_return({then_lit});
migraphx::shape gen_shape(migraphx::shape(s.type(), {1}, {0}));
auto literal_ins =
then_mod->insert_literal((std::prev(then_mod->end())), migraphx::literal(gen_shape, {0}));
auto unsqueeze_ins =
then_mod->insert_instruction(std::prev(then_mod->end()),
migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}),
literal_ins);
auto broad_ins =
then_mod->insert_instruction(std::prev(then_mod->end()),
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}),
unsqueeze_ins);
auto contig_out = then_mod->insert_instruction(
std::prev(then_mod->end()), migraphx::make_op("contiguous"), broad_ins);
then_mod->replace_instruction(std::prev(then_mod->end())->inputs().at(0), contig_out);
auto* else_mod = p.create_module("If_4_else");
auto re = else_mod->add_instruction(migraphx::make_op("mul"), y, l2);
else_mod->add_return({re});
......@@ -2616,6 +2629,68 @@ TEST_CASE(if_then_empty_constant_test)
EXPECT(p == prog);
}
// Expected IR for if_then_empty_constant_multi_output_test.onnx: the then
// branch returns two empty constants, which parse_if must replace with
// scalar -> multibroadcast -> contiguous chains shaped like the else-branch
// outputs. Instruction order matters: EXPECT(p == prog) compares programs.
TEST_CASE(if_then_empty_constant_multi_output_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
// Condition literal is true (1), so the then-branch is selected at runtime.
migraphx::shape sc{migraphx::shape::bool_type, {1}};
auto cond = mm->add_literal(migraphx::literal(sc, {1}));
migraphx::shape s{migraphx::shape::int64_type, {2, 1}};
std::vector<int> rand = {0, -1};
mm->add_literal(migraphx::shape::int64_type);
auto l2 = mm->add_literal(s, rand);
auto y = mm->add_parameter("y", s);
auto* then_mod = p.create_module("If_4_if");
// Placeholder empty literals initially referenced by the branch return.
auto lit = then_mod->add_literal(migraphx::shape::int64_type);
auto lit2 = then_mod->add_literal(migraphx::shape::int64_type);
then_mod->add_return({lit, lit2});
// First empty output: insert the zero-scalar broadcast chain just before the
// return and splice it in as return input 0 (mirrors handle_empty_branch in
// parse_if).
migraphx::shape gen_shape(migraphx::shape(s.type(), {1}, {0}));
auto literal_ins =
then_mod->insert_literal((std::prev(then_mod->end())), migraphx::literal(gen_shape, {0}));
auto unsqueeze_ins =
then_mod->insert_instruction(std::prev(then_mod->end()),
migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}),
literal_ins);
auto broad_ins =
then_mod->insert_instruction(std::prev(then_mod->end()),
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}),
unsqueeze_ins);
auto contig_out = then_mod->insert_instruction(
std::prev(then_mod->end()), migraphx::make_op("contiguous"), broad_ins);
then_mod->replace_instruction(std::prev(then_mod->end())->inputs().at(0), contig_out);
// Second empty output: identical chain, spliced in as return input 1.
migraphx::shape gen_shape2(migraphx::shape(s.type(), {1}, {0}));
auto literal_ins2 =
then_mod->insert_literal(std::prev(then_mod->end()), migraphx::literal(gen_shape2, {0}));
auto unsqueeze_ins2 =
then_mod->insert_instruction(std::prev(then_mod->end()),
migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}),
literal_ins2);
auto broad_ins2 =
then_mod->insert_instruction(std::prev(then_mod->end()),
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}),
unsqueeze_ins2);
auto contig_out2 = then_mod->insert_instruction(
std::prev(then_mod->end()), migraphx::make_op("contiguous"), broad_ins2);
then_mod->replace_instruction(std::prev(then_mod->end())->inputs().at(1), contig_out2);
// The else-branch computes both real outputs from the parameter.
auto* else_mod = p.create_module("If_4_else");
auto mul = else_mod->add_instruction(migraphx::make_op("mul"), y, l2);
auto sub = else_mod->add_instruction(migraphx::make_op("sub"), y, l2);
else_mod->add_return({mul, sub});
// Wire the if-instruction and extract both tuple elements as outputs.
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
auto r2 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
mm->add_return({r, r2});
auto prog = migraphx::parse_onnx("if_then_empty_constant_multi_output_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(if_else_empty_constant_test)
{
migraphx::program p;
......@@ -2633,11 +2708,24 @@ TEST_CASE(if_else_empty_constant_test)
then_mod->add_return({rt});
auto* else_mod = p.create_module("If_4_else");
else_mod->add_literal(migraphx::shape::int64_type);
migraphx::shape gen_s{s.type(), s.lens(), s.strides()};
auto else_lit = else_mod->add_literal(gen_s, gen_s.lens());
auto else_lit = else_mod->add_literal(s.type());
else_mod->add_return({else_lit});
migraphx::shape gen_shape(migraphx::shape(s.type(), {1}, {0}));
auto literal_ins =
else_mod->insert_literal((std::prev(else_mod->end())), migraphx::literal(gen_shape, {0}));
auto unsqueeze_ins =
else_mod->insert_instruction(std::prev(else_mod->end()),
migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}),
literal_ins);
auto broad_ins =
else_mod->insert_instruction(std::prev(else_mod->end()),
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}),
unsqueeze_ins);
auto contig_out = else_mod->insert_instruction(
std::prev(else_mod->end()), migraphx::make_op("contiguous"), broad_ins);
else_mod->replace_instruction(std::prev(else_mod->end())->inputs().at(0), contig_out);
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
......@@ -2646,6 +2734,68 @@ TEST_CASE(if_else_empty_constant_test)
EXPECT(p == prog);
}
// Expected IR for if_else_empty_constant_multi_output_test.onnx: the else
// branch returns two empty constants, which parse_if replaces with
// scalar -> multibroadcast -> contiguous chains shaped like the then-branch
// outputs. Statement order is significant: program equality is compared.
TEST_CASE(if_else_empty_constant_multi_output_test)
{
migraphx::program expected;
auto* main_mod = expected.get_main_module();
// Condition literal is false (0), so the else-branch is selected at runtime.
migraphx::shape cond_s{migraphx::shape::bool_type, {1}};
auto cond_lit = main_mod->add_literal(migraphx::literal(cond_s, {0}));
migraphx::shape out_s{migraphx::shape::int64_type, {2, 1}};
std::vector<int> data = {-1, 0};
main_mod->add_literal(migraphx::shape::int64_type);
auto data_lit = main_mod->add_literal(out_s, data);
auto y_param = main_mod->add_parameter("y", out_s);
// The then-branch computes both real outputs from the parameter.
auto* then_branch = expected.create_module("If_4_if");
auto prod = then_branch->add_instruction(migraphx::make_op("mul"), y_param, data_lit);
auto diff = then_branch->add_instruction(migraphx::make_op("sub"), y_param, data_lit);
then_branch->add_return({prod, diff});
// The else-branch starts as two placeholder empty literals in its return.
auto* else_branch = expected.create_module("If_4_else");
auto empty0 = else_branch->add_literal(migraphx::shape::int64_type);
auto empty1 = else_branch->add_literal(migraphx::shape::int64_type);
else_branch->add_return({empty0, empty1});
// First empty output: build the zero-scalar broadcast chain in front of the
// return and splice it in as return input 0 (mirrors handle_empty_branch).
migraphx::shape scalar_s(out_s.type(), {1}, {0});
auto zero0 =
else_branch->insert_literal(std::prev(else_branch->end()), migraphx::literal(scalar_s, {0}));
auto scal0 = else_branch->insert_instruction(
std::prev(else_branch->end()),
migraphx::make_op("scalar", {{"scalar_bcst_dims", out_s.lens()}}),
zero0);
auto bcast0 = else_branch->insert_instruction(
std::prev(else_branch->end()),
migraphx::make_op("multibroadcast", {{"out_lens", out_s.lens()}}),
scal0);
auto contig0 = else_branch->insert_instruction(
std::prev(else_branch->end()), migraphx::make_op("contiguous"), bcast0);
else_branch->replace_instruction(std::prev(else_branch->end())->inputs().at(0), contig0);
// Second empty output: identical chain, spliced in as return input 1.
auto zero1 =
else_branch->insert_literal(std::prev(else_branch->end()), migraphx::literal(scalar_s, {0}));
auto scal1 = else_branch->insert_instruction(
std::prev(else_branch->end()),
migraphx::make_op("scalar", {{"scalar_bcst_dims", out_s.lens()}}),
zero1);
auto bcast1 = else_branch->insert_instruction(
std::prev(else_branch->end()),
migraphx::make_op("multibroadcast", {{"out_lens", out_s.lens()}}),
scal1);
auto contig1 = else_branch->insert_instruction(
std::prev(else_branch->end()), migraphx::make_op("contiguous"), bcast1);
else_branch->replace_instruction(std::prev(else_branch->end())->inputs().at(1), contig1);
// Wire the if-instruction and extract both tuple elements as outputs.
auto if_ins = main_mod->add_instruction(migraphx::make_op("if"), {cond_lit}, {then_branch, else_branch});
auto elem0 = main_mod->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), if_ins);
auto elem1 = main_mod->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), if_ins);
main_mod->add_return({elem0, elem1});
auto parsed = migraphx::parse_onnx("if_else_empty_constant_multi_output_test.onnx");
EXPECT(expected == parsed);
}
TEST_CASE(if_then_else_multi_output_shapes_test)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment