Commit abd3d63e authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Get empty shapes working for parse_IF operator

- Update if_then/else_empty test protobuff and cases
- Need to update rand() vector used
- Make y empty instead of x for if_else_empty_test.onnx
- Regenerate protobufs with updates
- Add changes to handle empty/scalar input branch size to if operator.
- Add case where if both branches empty throw an error.
- Update verify tests with gold vectors and new shapes for empty input vec
  which we handle like a scalar before broadcasting
parent 7c8c3bee
...@@ -84,13 +84,31 @@ struct parse_if : op_parser<parse_if> ...@@ -84,13 +84,31 @@ struct parse_if : op_parser<parse_if>
{ {
auto then_shape = then_out_shapes.at(0).lens(); auto then_shape = then_out_shapes.at(0).lens();
auto else_shape = else_out_shapes.at(0).lens(); auto else_shape = else_out_shapes.at(0).lens();
int dim_delta = abs((static_cast<int>(then_shape.size() - else_shape.size())));
auto throw_shapes = [&]() { auto throw_shapes = [&]() {
MIGRAPHX_THROW("PARSE_IF: " + info.name + MIGRAPHX_THROW("PARSE_IF: " + info.name +
" then and else sub_graphs must compatible shapes "); " then and else sub_graphs must compatible shapes ");
}; };
// Throw error if both branches have zero output shapes. Not possible for static inputs
if(then_out_shapes.at(0).elements() == 0 && else_out_shapes.at(0).elements() == 0)
{
throw_shapes();
}
// Handle one empty branch by setting output identical to the other
if(then_out_shapes.at(0).elements() == 0)
{
then_mdl->add_outline(else_out_shapes.at(0));
}
if(else_out_shapes.at(0).elements() == 0)
{
else_mdl->add_outline(then_out_shapes.at(0));
}
// check equivilant length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn) // check equivilant length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn)
int dim_delta = abs((static_cast<int>(then_shape.size() - else_shape.size())));
if(dim_delta <= 1) if(dim_delta <= 1)
{ {
// make sure dims are equivalent in static shapes // make sure dims are equivalent in static shapes
......
...@@ -2208,8 +2208,8 @@ def if_else_test(): ...@@ -2208,8 +2208,8 @@ def if_else_test():
@onnx_test @onnx_test
def if_else_empty_shape_test(): def if_else_empty_shape_test():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, []) x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3]) y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [])
then_out = onnx.helper.make_tensor_value_info('then_out', then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT, onnx.TensorProto.FLOAT,
......
...@@ -2341,20 +2341,23 @@ TEST_CASE(if_else_empty_shape_test) ...@@ -2341,20 +2341,23 @@ TEST_CASE(if_else_empty_shape_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}}; migraphx::shape sc{migraphx::shape::bool_type, {1}};
auto cond = mm->add_literal(migraphx::literal(sc, {0})); auto cond = mm->add_literal(migraphx::literal(sc, {0}));
migraphx::shape s_else{migraphx::shape::float_type, {1}, {0}};
migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> ones(s.elements(), 1.0f); std::vector<float> ones(s.elements(), 1.0f);
auto l1 = mm->add_literal(s, ones); auto l1 = mm->add_literal(s, ones);
std::vector<float> rand = {-0.583375, 0.633757, 0.0668345, -0.479422, -0.604634, 0.0388589}; std::vector<float> rand = {0.382157, 0.527744, -1.79717, -1.1778, -0.305901, -0.0392257};
auto l2 = mm->add_literal(s, rand); auto l2 = mm->add_literal(s, rand);
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s_else);
auto* then_mod = p.create_module("If_5_if"); auto* then_mod = p.create_module("If_5_if");
auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1); auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1);
then_mod->add_return({rt}); then_mod->add_return({rt});
auto* else_mod = p.create_module("If_5_else"); auto* else_mod = p.create_module("If_5_else");
auto re = else_mod->add_instruction(migraphx::make_op("mul"), y, l2); auto broad_y =
else_mod->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3}}}), y);
auto re = else_mod->add_instruction(migraphx::make_op("mul"), broad_y, l2);
else_mod->add_return({re}); else_mod->add_return({re});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
...@@ -2521,15 +2524,18 @@ TEST_CASE(if_then_empty_shape_test) ...@@ -2521,15 +2524,18 @@ TEST_CASE(if_then_empty_shape_test)
migraphx::shape sc{migraphx::shape::bool_type, {1}}; migraphx::shape sc{migraphx::shape::bool_type, {1}};
auto cond = mm->add_literal(migraphx::literal(sc, {1})); auto cond = mm->add_literal(migraphx::literal(sc, {1}));
migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::shape s_then{migraphx::shape::float_type, {1}, {0}};
std::vector<float> ones(s.elements(), 1.0f); std::vector<float> ones(s.elements(), 1.0f);
auto l1 = mm->add_literal(s, ones); auto l1 = mm->add_literal(s, ones);
std::vector<float> rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946}; std::vector<float> rand = {1.0483, 0.687102, -1.7479, 1.59687, -0.0965695, -0.728357};
auto l2 = mm->add_literal(s, rand); auto l2 = mm->add_literal(s, rand);
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s_then);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
auto* then_mod = p.create_module("If_5_if"); auto* then_mod = p.create_module("If_5_if");
auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1); auto broad_x =
then_mod->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3}}}), x);
auto rt = then_mod->add_instruction(migraphx::make_op("add"), broad_x, l1);
then_mod->add_return({rt}); then_mod->add_return({rt});
auto* else_mod = p.create_module("If_5_else"); auto* else_mod = p.create_module("If_5_else");
...@@ -2574,7 +2580,6 @@ TEST_CASE(if_then_trailing_one_shape_test) ...@@ -2574,7 +2580,6 @@ TEST_CASE(if_then_trailing_one_shape_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(if_then_test) TEST_CASE(if_then_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -485,9 +485,11 @@ TEST_CASE(if_then_empty_shape_test) ...@@ -485,9 +485,11 @@ TEST_CASE(if_then_empty_shape_test)
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
migraphx::shape s_data{migraphx::shape::float_type, {2, 3}}; migraphx::shape s_data{migraphx::shape::float_type, {2, 3}};
std::vector<float> data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625}; std::vector<float> data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625};
migraphx::shape s_data_x{migraphx::shape::float_type, {1}, {0}};
std::vector<float> data_x = {0.1337};
migraphx::parameter_map pp; migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_data, data.data()); pp["x"] = migraphx::argument(s_data_x, data_x.data());
pp["y"] = migraphx::argument(s_data, data.data()); pp["y"] = migraphx::argument(s_data, data.data());
auto result = p.eval(pp).back(); auto result = p.eval(pp).back();
...@@ -495,7 +497,7 @@ TEST_CASE(if_then_empty_shape_test) ...@@ -495,7 +497,7 @@ TEST_CASE(if_then_empty_shape_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
// protobuff adds ones so result should be just + 1.0 // protobuff adds ones so result should be just + 1.0
std::vector<float> gold = {1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375}; std::vector<float> gold = {1.1337, 1.1337, 1.1337, 1.1337, 1.1337, 1.1337};
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
...@@ -546,16 +548,21 @@ TEST_CASE(if_else_empty_shape_test) ...@@ -546,16 +548,21 @@ TEST_CASE(if_else_empty_shape_test)
migraphx::shape s_data{migraphx::shape::float_type, {2, 3}}; migraphx::shape s_data{migraphx::shape::float_type, {2, 3}};
std::vector<float> data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625}; std::vector<float> data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625};
migraphx::shape s_data_y{migraphx::shape::float_type, {1}, {0}};
std::vector<float> data_y = {2.0};
migraphx::parameter_map pp; migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_data, data.data()); pp["x"] = migraphx::argument(s_data, data.data());
pp["y"] = migraphx::argument(s_data, data.data()); pp["y"] = migraphx::argument(s_data_y, data_y.data());
auto result = p.eval(pp).back(); auto result = p.eval(pp).back();
std::vector<float> result_vector; std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = { // protobuff multiplies things by a random vector that's baked in.
-0.0364609435, 0.475317657, -0.00417715637, -0.0599277429, 0.0755792186, -0.0218581557}; // Needs to be changed everytime we refresh the protobuf
std::vector<float> gold = {0.764314, 1.05549, -3.59435, -2.3556, -0.611802, -0.0784514};
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment