Commit d8ee02b9 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

More review changes/fixes

- Handle checks for each IF output
- add const to inputs of all_but_last_dims_equal
- add std::equal instead of using equal
- Use .back() for vectors in getting last value
- Use input().front() instead of prev(prev()) when replacing the last value.
parent 34aeda23
...@@ -78,22 +78,25 @@ struct parse_if : op_parser<parse_if> ...@@ -78,22 +78,25 @@ struct parse_if : op_parser<parse_if>
throw_shapes(); throw_shapes();
} }
// Add checks for each output shape
for(int i = 0; i < then_out_shapes.size(); i++)
{
// Must have the same type for both if/else blocks by onnx spec // Must have the same type for both if/else blocks by onnx spec
// Add exception for empty constant scalars if(then_out_shapes.at(i).type() != else_out_shapes.at(i).type())
if(then_out_shapes.at(0).type() != else_out_shapes.at(0).type())
{ {
MIGRAPHX_THROW("PARSE_IF: " + info.name + MIGRAPHX_THROW("PARSE_IF: " + info.name +
" then and else sub_grahps must have same output type! " + " then and else sub_grahps must have same output type! " +
then_out_shapes.at(0).type_string() + " vs " + then_out_shapes.at(i).type_string() + " vs " +
else_out_shapes.at(0).type_string()); else_out_shapes.at(i).type_string());
} }
if(not then_out_shapes.at(0).dynamic() && not else_out_shapes.at(0).dynamic()) if(not then_out_shapes.at(i).dynamic() && not else_out_shapes.at(i).dynamic())
{ {
auto then_lens = then_out_shapes.at(0).lens(); auto then_lens = then_out_shapes.at(i).lens();
auto else_lens = else_out_shapes.at(0).lens(); auto else_lens = else_out_shapes.at(i).lens();
// Throw error if both branches have zero output shapes. Not possible for static inputs // Throw error if both branches have zero output shapes. Not possible for static
// inputs
if(then_lens.empty() && else_lens.empty()) if(then_lens.empty() && else_lens.empty())
{ {
throw_shapes(); throw_shapes();
...@@ -109,47 +112,48 @@ struct parse_if : op_parser<parse_if> ...@@ -109,47 +112,48 @@ struct parse_if : op_parser<parse_if>
// need to update the then_shape before we do further checks // need to update the then_shape before we do further checks
if(then_lens.empty()) if(then_lens.empty())
{ {
then_lens = handle_empty_branch(then_mdl, else_out_shapes.at(0)); then_lens = handle_empty_branch(then_mdl, else_out_shapes.at(i));
} }
else if(else_lens.empty()) else if(else_lens.empty())
{ {
else_lens = handle_empty_branch(else_mdl, then_out_shapes.at(0)); else_lens = handle_empty_branch(else_mdl, then_out_shapes.at(i));
} }
// check equivilant length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn) auto all_but_last_dims_equal = [](const std::vector<size_t>& lens_a,
int dim_delta = abs((static_cast<int>(then_lens.size() - else_lens.size()))); const std::vector<size_t>& lens_b) {
if(dim_delta == 1)
{
auto all_but_last_dims_equal = [](std::vector<size_t>& lens_a,
std::vector<size_t>& lens_b) {
if(lens_a.size() <= lens_b.size()) if(lens_a.size() <= lens_b.size())
{ {
return equal(lens_a.begin(), lens_a.end(), lens_b.begin()); return std::equal(lens_a.begin(), lens_a.end(), lens_b.begin());
} }
else else
{ {
return equal(lens_b.begin(), lens_b.end(), lens_a.begin()); return std::equal(lens_b.begin(), lens_b.end(), lens_a.begin());
} }
}; };
// check equivalent length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn)
int dim_delta = abs((static_cast<int>(then_lens.size() - else_lens.size())));
if(dim_delta == 1)
{
// make sure dims are equivalent in static shapes // make sure dims are equivalent in static shapes
if(not all_but_last_dims_equal(then_lens, else_lens)) if(not all_but_last_dims_equal(then_lens, else_lens))
{ {
throw_shapes(); throw_shapes();
} }
auto unsqueeze_last_op = [](module_ref& mdl, const std::vector<size_t>& out_shape) { auto last_then = then_lens.back();
auto last_else = else_lens.back();
auto unsqueeze_last_op = [](module_ref& mdl,
const std::vector<size_t>& out_shape) {
auto convert_ins = mdl->add_instruction( auto convert_ins = mdl->add_instruction(
make_op("unsqueeze", {{"axes", {out_shape.size() - 1}}}), make_op("unsqueeze", {{"axes", {out_shape.size() - 1}}}),
std::prev(std::prev(mdl->end()))); std::prev(mdl->end())->inputs().front());
mdl->replace_return({convert_ins}); mdl->replace_return({convert_ins});
mdl->remove_instruction({std::prev(convert_ins)}); mdl->remove_instruction({std::prev(convert_ins)});
}; };
auto last_then = *(std::prev(then_lens.end()));
auto last_else = *(std::prev(else_lens.end()));
// Find which dim to unsqueeze // Find which dim to unsqueeze
if((then_lens.size() < else_lens.size()) && (last_else == 1)) if((then_lens.size() < else_lens.size()) && (last_else == 1))
{ {
...@@ -165,6 +169,7 @@ struct parse_if : op_parser<parse_if> ...@@ -165,6 +169,7 @@ struct parse_if : op_parser<parse_if>
throw_shapes(); throw_shapes();
} }
} }
}
auto if_ret = info.add_instruction(make_op("if"), args, {then_mdl, else_mdl}); auto if_ret = info.add_instruction(make_op("if"), args, {then_mdl, else_mdl});
auto out_s = if_ret->get_shape(); auto out_s = if_ret->get_shape();
......
...@@ -2382,15 +2382,7 @@ TEST_CASE(if_else_empty_shape_test) ...@@ -2382,15 +2382,7 @@ TEST_CASE(if_else_empty_shape_test)
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r}); mm->add_return({r});
std::ifstream ifs("if_else_empty_shape_test.onnx", std::ios::binary); auto prog = migraphx::parse_onnx("if_else_empty_shape_test.onnx");
ifs.seekg(0, std::ios::end);
auto length = ifs.tellg();
ifs.seekg(0, std::ios::beg);
std::vector<char> onnx_buffer(length);
ifs.read(onnx_buffer.data(), length);
ifs.close();
auto prog = migraphx::parse_onnx_buffer(onnx_buffer.data(), length, {});
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -2422,15 +2414,7 @@ TEST_CASE(if_else_trailing_one_shape_test) ...@@ -2422,15 +2414,7 @@ TEST_CASE(if_else_trailing_one_shape_test)
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r}); mm->add_return({r});
std::ifstream ifs("if_else_trailing_one_shape_test.onnx", std::ios::binary); auto prog = migraphx::parse_onnx("if_else_trailing_one_shape_test.onnx");
ifs.seekg(0, std::ios::end);
auto length = ifs.tellg();
ifs.seekg(0, std::ios::beg);
std::vector<char> onnx_buffer(length);
ifs.read(onnx_buffer.data(), length);
ifs.close();
auto prog = migraphx::parse_onnx_buffer(onnx_buffer.data(), length, {});
EXPECT(p == prog); EXPECT(p == prog);
} }
......
...@@ -511,7 +511,7 @@ TEST_CASE(if_then_test) ...@@ -511,7 +511,7 @@ TEST_CASE(if_then_test)
std::vector<float> result_vector; std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
// protobuff adds ones so result should be just + 1.0 // onnx adds ones so result should be just + 1.0
std::vector<float> gold = {1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375}; std::vector<float> gold = {1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375};
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
...@@ -533,7 +533,7 @@ TEST_CASE(if_then_empty_shape_test) ...@@ -533,7 +533,7 @@ TEST_CASE(if_then_empty_shape_test)
std::vector<float> result_vector; std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
// protobuff adds ones so result should be just + 1.0 // onnx adds ones so result should be just + 1.0
std::vector<float> gold = {1.1337, 1.1337, 1.1337, 1.1337, 1.1337, 1.1337}; std::vector<float> gold = {1.1337, 1.1337, 1.1337, 1.1337, 1.1337, 1.1337};
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
...@@ -554,7 +554,7 @@ TEST_CASE(if_then_trailing_one_shape_test) ...@@ -554,7 +554,7 @@ TEST_CASE(if_then_trailing_one_shape_test)
std::vector<float> result_vector; std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
// protobuff adds ones so result should be just + 1.0 // onnx adds ones so result should be just + 1.0
std::vector<float> gold = {1.0625, 1.75}; std::vector<float> gold = {1.0625, 1.75};
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
...@@ -597,8 +597,8 @@ TEST_CASE(if_else_empty_shape_test) ...@@ -597,8 +597,8 @@ TEST_CASE(if_else_empty_shape_test)
std::vector<float> result_vector; std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
// protobuff multiplies things by a random vector that's baked in. // onnx multiplies things by a random vector that's baked in.
// Needs to be changed everytime we refresh the protobuf // Needs to be changed everytime we refresh the onnx file
std::vector<float> gold = {0.764314, 1.05549, -3.59435, -2.3556, -0.611802, -0.0784514}; std::vector<float> gold = {0.764314, 1.05549, -3.59435, -2.3556, -0.611802, -0.0784514};
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment