Commit d8ee02b9 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

More review changes/fixes

- Handle checks for each IF output
- add const to inputs of all_but_last_dims_equal
- add std::equal instead of using equal
- Use .back() for vectors in getting last value
- Use input().front() instead of prev(prev()) when replacing the last value.
parent 34aeda23
...@@ -78,92 +78,97 @@ struct parse_if : op_parser<parse_if> ...@@ -78,92 +78,97 @@ struct parse_if : op_parser<parse_if>
throw_shapes(); throw_shapes();
} }
// Must have the same type for both if/else blocks by onnx spec // Add checks for each output shape
// Add exception for empty constant scalars for(int i = 0; i < then_out_shapes.size(); i++)
if(then_out_shapes.at(0).type() != else_out_shapes.at(0).type())
{ {
MIGRAPHX_THROW("PARSE_IF: " + info.name + // Must have the same type for both if/else blocks by onnx spec
" then and else sub_grahps must have same output type! " + if(then_out_shapes.at(i).type() != else_out_shapes.at(i).type())
then_out_shapes.at(0).type_string() + " vs " +
else_out_shapes.at(0).type_string());
}
if(not then_out_shapes.at(0).dynamic() && not else_out_shapes.at(0).dynamic())
{
auto then_lens = then_out_shapes.at(0).lens();
auto else_lens = else_out_shapes.at(0).lens();
// Throw error if both branches have zero output shapes. Not possible for static inputs
if(then_lens.empty() && else_lens.empty())
{ {
throw_shapes(); MIGRAPHX_THROW("PARSE_IF: " + info.name +
" then and else sub_grahps must have same output type! " +
then_out_shapes.at(i).type_string() + " vs " +
else_out_shapes.at(i).type_string());
} }
auto handle_empty_branch = [](module_ref& mdl, const shape& out_shape) { if(not then_out_shapes.at(i).dynamic() && not else_out_shapes.at(i).dynamic())
auto outline_ins = mdl->add_outline(out_shape);
mdl->replace_return({outline_ins});
return out_shape.lens();
};
// Handle one empty branch by setting output identical to the other
// need to update the then_shape before we do further checks
if(then_lens.empty())
{
then_lens = handle_empty_branch(then_mdl, else_out_shapes.at(0));
}
else if(else_lens.empty())
{ {
else_lens = handle_empty_branch(else_mdl, then_out_shapes.at(0)); auto then_lens = then_out_shapes.at(i).lens();
} auto else_lens = else_out_shapes.at(i).lens();
// check equivilant length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn) // Throw error if both branches have zero output shapes. Not possible for static
int dim_delta = abs((static_cast<int>(then_lens.size() - else_lens.size()))); // inputs
if(then_lens.empty() && else_lens.empty())
{
throw_shapes();
}
if(dim_delta == 1) auto handle_empty_branch = [](module_ref& mdl, const shape& out_shape) {
{ auto outline_ins = mdl->add_outline(out_shape);
auto all_but_last_dims_equal = [](std::vector<size_t>& lens_a, mdl->replace_return({outline_ins});
std::vector<size_t>& lens_b) { return out_shape.lens();
};
// Handle one empty branch by setting output identical to the other
// need to update the then_shape before we do further checks
if(then_lens.empty())
{
then_lens = handle_empty_branch(then_mdl, else_out_shapes.at(i));
}
else if(else_lens.empty())
{
else_lens = handle_empty_branch(else_mdl, then_out_shapes.at(i));
}
auto all_but_last_dims_equal = [](const std::vector<size_t>& lens_a,
const std::vector<size_t>& lens_b) {
if(lens_a.size() <= lens_b.size()) if(lens_a.size() <= lens_b.size())
{ {
return equal(lens_a.begin(), lens_a.end(), lens_b.begin()); return std::equal(lens_a.begin(), lens_a.end(), lens_b.begin());
} }
else else
{ {
return equal(lens_b.begin(), lens_b.end(), lens_a.begin()); return std::equal(lens_b.begin(), lens_b.end(), lens_a.begin());
} }
}; };
// make sure dims are equivalent in static shapes // check equivalent length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn)
if(not all_but_last_dims_equal(then_lens, else_lens)) int dim_delta = abs((static_cast<int>(then_lens.size() - else_lens.size())));
if(dim_delta == 1)
{ {
throw_shapes(); // make sure dims are equivalent in static shapes
} if(not all_but_last_dims_equal(then_lens, else_lens))
{
throw_shapes();
}
auto unsqueeze_last_op = [](module_ref& mdl, const std::vector<size_t>& out_shape) { auto last_then = then_lens.back();
auto convert_ins = mdl->add_instruction( auto last_else = else_lens.back();
make_op("unsqueeze", {{"axes", {out_shape.size() - 1}}}),
std::prev(std::prev(mdl->end())));
mdl->replace_return({convert_ins});
mdl->remove_instruction({std::prev(convert_ins)});
};
auto last_then = *(std::prev(then_lens.end())); auto unsqueeze_last_op = [](module_ref& mdl,
auto last_else = *(std::prev(else_lens.end())); const std::vector<size_t>& out_shape) {
auto convert_ins = mdl->add_instruction(
make_op("unsqueeze", {{"axes", {out_shape.size() - 1}}}),
std::prev(mdl->end())->inputs().front());
mdl->replace_return({convert_ins});
mdl->remove_instruction({std::prev(convert_ins)});
};
// Find which dim to unsqueeze // Find which dim to unsqueeze
if((then_lens.size() < else_lens.size()) && (last_else == 1)) if((then_lens.size() < else_lens.size()) && (last_else == 1))
{ {
unsqueeze_last_op(then_mdl, else_lens); unsqueeze_last_op(then_mdl, else_lens);
}
else if((then_lens.size() > else_lens.size()) && (last_then == 1))
{
unsqueeze_last_op(else_mdl, then_lens);
}
} }
else if((then_lens.size() > else_lens.size()) && (last_then == 1)) else if(dim_delta > 1)
{ {
unsqueeze_last_op(else_mdl, then_lens); throw_shapes();
} }
} }
else if(dim_delta > 1)
{
throw_shapes();
}
} }
auto if_ret = info.add_instruction(make_op("if"), args, {then_mdl, else_mdl}); auto if_ret = info.add_instruction(make_op("if"), args, {then_mdl, else_mdl});
......
...@@ -2382,15 +2382,7 @@ TEST_CASE(if_else_empty_shape_test) ...@@ -2382,15 +2382,7 @@ TEST_CASE(if_else_empty_shape_test)
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r}); mm->add_return({r});
std::ifstream ifs("if_else_empty_shape_test.onnx", std::ios::binary); auto prog = migraphx::parse_onnx("if_else_empty_shape_test.onnx");
ifs.seekg(0, std::ios::end);
auto length = ifs.tellg();
ifs.seekg(0, std::ios::beg);
std::vector<char> onnx_buffer(length);
ifs.read(onnx_buffer.data(), length);
ifs.close();
auto prog = migraphx::parse_onnx_buffer(onnx_buffer.data(), length, {});
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -2422,15 +2414,7 @@ TEST_CASE(if_else_trailing_one_shape_test) ...@@ -2422,15 +2414,7 @@ TEST_CASE(if_else_trailing_one_shape_test)
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r}); mm->add_return({r});
std::ifstream ifs("if_else_trailing_one_shape_test.onnx", std::ios::binary); auto prog = migraphx::parse_onnx("if_else_trailing_one_shape_test.onnx");
ifs.seekg(0, std::ios::end);
auto length = ifs.tellg();
ifs.seekg(0, std::ios::beg);
std::vector<char> onnx_buffer(length);
ifs.read(onnx_buffer.data(), length);
ifs.close();
auto prog = migraphx::parse_onnx_buffer(onnx_buffer.data(), length, {});
EXPECT(p == prog); EXPECT(p == prog);
} }
......
...@@ -511,7 +511,7 @@ TEST_CASE(if_then_test) ...@@ -511,7 +511,7 @@ TEST_CASE(if_then_test)
std::vector<float> result_vector; std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
// protobuff adds ones so result should be just + 1.0 // onnx adds ones so result should be just + 1.0
std::vector<float> gold = {1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375}; std::vector<float> gold = {1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375};
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
...@@ -533,7 +533,7 @@ TEST_CASE(if_then_empty_shape_test) ...@@ -533,7 +533,7 @@ TEST_CASE(if_then_empty_shape_test)
std::vector<float> result_vector; std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
// protobuff adds ones so result should be just + 1.0 // onnx adds ones so result should be just + 1.0
std::vector<float> gold = {1.1337, 1.1337, 1.1337, 1.1337, 1.1337, 1.1337}; std::vector<float> gold = {1.1337, 1.1337, 1.1337, 1.1337, 1.1337, 1.1337};
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
...@@ -554,7 +554,7 @@ TEST_CASE(if_then_trailing_one_shape_test) ...@@ -554,7 +554,7 @@ TEST_CASE(if_then_trailing_one_shape_test)
std::vector<float> result_vector; std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
// protobuff adds ones so result should be just + 1.0 // onnx adds ones so result should be just + 1.0
std::vector<float> gold = {1.0625, 1.75}; std::vector<float> gold = {1.0625, 1.75};
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
...@@ -597,8 +597,8 @@ TEST_CASE(if_else_empty_shape_test) ...@@ -597,8 +597,8 @@ TEST_CASE(if_else_empty_shape_test)
std::vector<float> result_vector; std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
// protobuff multiplies things by a random vector that's baked in. // onnx multiplies things by a random vector that's baked in.
// Needs to be changed everytime we refresh the protobuf // Needs to be changed everytime we refresh the onnx file
std::vector<float> gold = {0.764314, 1.05549, -3.59435, -2.3556, -0.611802, -0.0784514}; std::vector<float> gold = {0.764314, 1.05549, -3.59435, -2.3556, -0.611802, -0.0784514};
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment