Commit d8ee02b9 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

More review changes/fixes

- Handle checks for each IF output
- add const to inputs of all_but_last_dims_equal
- add std::equal instead of using equal
- Use .back() for vectors in getting last value
- Use input().front() instead of prev(prev()) when replacing the last value.
parent 34aeda23
......@@ -78,92 +78,97 @@ struct parse_if : op_parser<parse_if>
throw_shapes();
}
// Must have the same type for both if/else blocks by onnx spec
// Add exception for empty constant scalars
if(then_out_shapes.at(0).type() != else_out_shapes.at(0).type())
// Add checks for each output shape
for(int i = 0; i < then_out_shapes.size(); i++)
{
MIGRAPHX_THROW("PARSE_IF: " + info.name +
" then and else sub_grahps must have same output type! " +
then_out_shapes.at(0).type_string() + " vs " +
else_out_shapes.at(0).type_string());
}
if(not then_out_shapes.at(0).dynamic() && not else_out_shapes.at(0).dynamic())
{
auto then_lens = then_out_shapes.at(0).lens();
auto else_lens = else_out_shapes.at(0).lens();
// Throw error if both branches have zero output shapes. Not possible for static inputs
if(then_lens.empty() && else_lens.empty())
// Must have the same type for both if/else blocks by onnx spec
if(then_out_shapes.at(i).type() != else_out_shapes.at(i).type())
{
throw_shapes();
MIGRAPHX_THROW("PARSE_IF: " + info.name +
" then and else sub_grahps must have same output type! " +
then_out_shapes.at(i).type_string() + " vs " +
else_out_shapes.at(i).type_string());
}
auto handle_empty_branch = [](module_ref& mdl, const shape& out_shape) {
auto outline_ins = mdl->add_outline(out_shape);
mdl->replace_return({outline_ins});
return out_shape.lens();
};
// Handle one empty branch by setting output identical to the other
// need to update the then_shape before we do further checks
if(then_lens.empty())
{
then_lens = handle_empty_branch(then_mdl, else_out_shapes.at(0));
}
else if(else_lens.empty())
if(not then_out_shapes.at(i).dynamic() && not else_out_shapes.at(i).dynamic())
{
else_lens = handle_empty_branch(else_mdl, then_out_shapes.at(0));
}
auto then_lens = then_out_shapes.at(i).lens();
auto else_lens = else_out_shapes.at(i).lens();
// check equivilant length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn)
int dim_delta = abs((static_cast<int>(then_lens.size() - else_lens.size())));
// Throw error if both branches have zero output shapes. Not possible for static
// inputs
if(then_lens.empty() && else_lens.empty())
{
throw_shapes();
}
if(dim_delta == 1)
{
auto all_but_last_dims_equal = [](std::vector<size_t>& lens_a,
std::vector<size_t>& lens_b) {
auto handle_empty_branch = [](module_ref& mdl, const shape& out_shape) {
auto outline_ins = mdl->add_outline(out_shape);
mdl->replace_return({outline_ins});
return out_shape.lens();
};
// Handle one empty branch by setting output identical to the other
// need to update the then_shape before we do further checks
if(then_lens.empty())
{
then_lens = handle_empty_branch(then_mdl, else_out_shapes.at(i));
}
else if(else_lens.empty())
{
else_lens = handle_empty_branch(else_mdl, then_out_shapes.at(i));
}
auto all_but_last_dims_equal = [](const std::vector<size_t>& lens_a,
const std::vector<size_t>& lens_b) {
if(lens_a.size() <= lens_b.size())
{
return equal(lens_a.begin(), lens_a.end(), lens_b.begin());
return std::equal(lens_a.begin(), lens_a.end(), lens_b.begin());
}
else
{
return equal(lens_b.begin(), lens_b.end(), lens_a.begin());
return std::equal(lens_b.begin(), lens_b.end(), lens_a.begin());
}
};
// make sure dims are equivalent in static shapes
if(not all_but_last_dims_equal(then_lens, else_lens))
// check equivalent length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn)
int dim_delta = abs((static_cast<int>(then_lens.size() - else_lens.size())));
if(dim_delta == 1)
{
throw_shapes();
}
// make sure dims are equivalent in static shapes
if(not all_but_last_dims_equal(then_lens, else_lens))
{
throw_shapes();
}
auto unsqueeze_last_op = [](module_ref& mdl, const std::vector<size_t>& out_shape) {
auto convert_ins = mdl->add_instruction(
make_op("unsqueeze", {{"axes", {out_shape.size() - 1}}}),
std::prev(std::prev(mdl->end())));
mdl->replace_return({convert_ins});
mdl->remove_instruction({std::prev(convert_ins)});
};
auto last_then = then_lens.back();
auto last_else = else_lens.back();
auto last_then = *(std::prev(then_lens.end()));
auto last_else = *(std::prev(else_lens.end()));
auto unsqueeze_last_op = [](module_ref& mdl,
const std::vector<size_t>& out_shape) {
auto convert_ins = mdl->add_instruction(
make_op("unsqueeze", {{"axes", {out_shape.size() - 1}}}),
std::prev(mdl->end())->inputs().front());
mdl->replace_return({convert_ins});
mdl->remove_instruction({std::prev(convert_ins)});
};
// Find which dim to unsqueeze
if((then_lens.size() < else_lens.size()) && (last_else == 1))
{
unsqueeze_last_op(then_mdl, else_lens);
// Find which dim to unsqueeze
if((then_lens.size() < else_lens.size()) && (last_else == 1))
{
unsqueeze_last_op(then_mdl, else_lens);
}
else if((then_lens.size() > else_lens.size()) && (last_then == 1))
{
unsqueeze_last_op(else_mdl, then_lens);
}
}
else if((then_lens.size() > else_lens.size()) && (last_then == 1))
else if(dim_delta > 1)
{
unsqueeze_last_op(else_mdl, then_lens);
throw_shapes();
}
}
else if(dim_delta > 1)
{
throw_shapes();
}
}
auto if_ret = info.add_instruction(make_op("if"), args, {then_mdl, else_mdl});
......
......@@ -2382,15 +2382,7 @@ TEST_CASE(if_else_empty_shape_test)
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
std::ifstream ifs("if_else_empty_shape_test.onnx", std::ios::binary);
ifs.seekg(0, std::ios::end);
auto length = ifs.tellg();
ifs.seekg(0, std::ios::beg);
std::vector<char> onnx_buffer(length);
ifs.read(onnx_buffer.data(), length);
ifs.close();
auto prog = migraphx::parse_onnx_buffer(onnx_buffer.data(), length, {});
auto prog = migraphx::parse_onnx("if_else_empty_shape_test.onnx");
EXPECT(p == prog);
}
......@@ -2422,15 +2414,7 @@ TEST_CASE(if_else_trailing_one_shape_test)
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
std::ifstream ifs("if_else_trailing_one_shape_test.onnx", std::ios::binary);
ifs.seekg(0, std::ios::end);
auto length = ifs.tellg();
ifs.seekg(0, std::ios::beg);
std::vector<char> onnx_buffer(length);
ifs.read(onnx_buffer.data(), length);
ifs.close();
auto prog = migraphx::parse_onnx_buffer(onnx_buffer.data(), length, {});
auto prog = migraphx::parse_onnx("if_else_trailing_one_shape_test.onnx");
EXPECT(p == prog);
}
......
......@@ -511,7 +511,7 @@ TEST_CASE(if_then_test)
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
// protobuff adds ones so result should be just + 1.0
// onnx adds ones so result should be just + 1.0
std::vector<float> gold = {1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375};
EXPECT(migraphx::verify_range(result_vector, gold));
}
......@@ -533,7 +533,7 @@ TEST_CASE(if_then_empty_shape_test)
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
// protobuff adds ones so result should be just + 1.0
// onnx adds ones so result should be just + 1.0
std::vector<float> gold = {1.1337, 1.1337, 1.1337, 1.1337, 1.1337, 1.1337};
EXPECT(migraphx::verify_range(result_vector, gold));
}
......@@ -554,7 +554,7 @@ TEST_CASE(if_then_trailing_one_shape_test)
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
// protobuff adds ones so result should be just + 1.0
// onnx adds ones so result should be just + 1.0
std::vector<float> gold = {1.0625, 1.75};
EXPECT(migraphx::verify_range(result_vector, gold));
}
......@@ -597,8 +597,8 @@ TEST_CASE(if_else_empty_shape_test)
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
// protobuff multiplies things by a random vector that's baked in.
// Needs to be changed everytime we refresh the protobuf
// onnx multiplies things by a random vector that's baked in.
// Needs to be changed everytime we refresh the onnx file
std::vector<float> gold = {0.764314, 1.05549, -3.59435, -2.3556, -0.611802, -0.0784514};
EXPECT(migraphx::verify_range(result_vector, gold));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment