Commit c297ce5f authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Fixes to handle constants

parent b6ca9b26
......@@ -43,7 +43,8 @@ struct parse_constant : op_parser<parse_constant>
// return empty literal
if(v.get_shape().elements() == 0)
{
return info.add_literal(literal{v.get_shape().type()});
migraphx::shape empty_constant(v.get_shape().type(), {1}, {0});
return info.add_literal(literal{empty_constant, {0}});
}
auto dim_size = info.attributes.at("value").t().dims_size();
......
......@@ -68,7 +68,7 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
// empty input tensor, output is a scalar
if(args[0]->get_shape().elements() == 0)
{
s = migraphx::shape{type, {1}, {0}};
s = migraphx::shape{type, {1}, {}};
}
else
{
......@@ -84,8 +84,16 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
l_val.visit([&](auto val) {
using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
// l_val contains only one element
std::vector<val_type> out_vec(s.elements(), val.front());
l_out = literal(s, out_vec);
if(s.elements() > 0)
{
std::vector<val_type> out_vec(s.elements(), val.front());
l_out = literal(s, out_vec);
}
else
{
std::vector<val_type> out_vec{val.front()};
l_out = literal(s, out_vec);
}
});
return info.add_literal(l_out);
......
......@@ -31,6 +31,7 @@
#include <migraphx/reduce_dims.hpp>
#include <algorithm>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
......@@ -92,7 +93,9 @@ struct parse_if : op_parser<parse_if>
auto throw_shapes = [&]() {
MIGRAPHX_THROW("PARSE_IF: " + info.name +
" then and else sub_graphs must have compatible shapes ");
" then and else sub_graphs must have compatible shapes " +
to_string_range(then_out_shapes) + " vs " +
to_string_range(else_out_shapes));
};
if(then_out_shapes.size() != else_out_shapes.size())
......@@ -126,16 +129,14 @@ struct parse_if : op_parser<parse_if>
assert(not(then_lens.empty() and else_lens.empty()));
auto handle_empty_branch = [](module_ref& mdl, int index, const shape& out_shape) {
shape gen_shape(shape(out_shape.type(), {1}, {0}));
auto literal_ins = mdl->add_literal(literal(gen_shape, {0}));
auto unsqueeze_ins = mdl->insert_instruction(
std::prev(mdl->end()),
make_op("scalar", {{"scalar_bcst_dims", out_shape.lens()}}),
literal_ins);
auto scalar_ins =
mdl->insert_instruction(std::prev(mdl->end()),
make_op("scalar", {{"out_lens", out_shape.lens()}}),
std::prev(mdl->end()));
auto broad_ins = mdl->insert_instruction(
std::prev(mdl->end()),
make_op("multibroadcast", {{"out_lens", out_shape.lens()}}),
unsqueeze_ins);
scalar_ins);
auto contig_out = mdl->insert_instruction(
std::prev(mdl->end()), make_op("contiguous"), broad_ins);
mdl->replace_instruction(std::prev(mdl->end())->inputs().at(index), contig_out);
......@@ -144,11 +145,12 @@ struct parse_if : op_parser<parse_if>
// Handle one empty branch by setting output identical to the other
// need to update the then_shape before we do further checks
if(then_lens.empty())
if(then_out_shape.strides().empty())
{
then_lens = handle_empty_branch(then_mdl, i, else_out_shape);
}
else if(else_lens.empty())
else if(else_out_shape.strides().empty())
{
else_lens = handle_empty_branch(else_mdl, i, then_out_shape);
}
......@@ -183,6 +185,9 @@ struct parse_if : op_parser<parse_if>
}
}
then_mdl->debug_print();
else_mdl->debug_print();
auto if_ret = info.add_instruction(make_op("if"), args, {then_mdl, else_mdl});
auto out_s = if_ret->get_shape();
assert(out_s.type() == shape::tuple_type);
......
......@@ -761,7 +761,7 @@ TEST_CASE(constant_empty_scalar_int64_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
mm->add_literal(migraphx::literal{migraphx::shape::int64_type});
mm->add_literal(migraphx::literal{migraphx::shape::int64_type, {0}});
auto prog = optimize_onnx("constant_empty_scalar_int64_test.onnx");
EXPECT(p == prog);
......@@ -781,8 +781,8 @@ TEST_CASE(const_of_shape_empty_input_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
mm->add_literal(migraphx::literal(migraphx::shape::int32_type));
migraphx::shape s(migraphx::shape::int64_type, {1}, {0});
mm->add_literal(migraphx::literal(migraphx::shape::int32_type, {0}));
migraphx::shape s(migraphx::shape::int64_type, {1});
std::vector<int64_t> vec(s.elements(), 10);
mm->add_literal(migraphx::literal(s, vec));
......@@ -2425,17 +2425,18 @@ TEST_CASE(if_literal_test)
auto cond = mm->add_parameter("cond", cond_s);
migraphx::shape s{migraphx::shape::float_type, {5}};
migraphx::shape empty_const(migraphx::shape::float_type, {1}, {0});
auto* then_mod = p.create_module("If_1_if");
std::vector<float> data1 = {1, 2, 3, 4, 5};
auto l1 = then_mod->add_literal(migraphx::literal(s, data1));
then_mod->add_literal({});
then_mod->add_literal({empty_const, {0}});
then_mod->add_return({l1});
auto* else_mod = p.create_module("If_1_else");
std::vector<float> data2 = {5, 4, 3, 2, 1};
auto l2 = else_mod->add_literal(migraphx::literal(s, data2));
else_mod->add_literal({});
else_mod->add_literal({empty_const, {0}});
else_mod->add_return({l2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
......@@ -2599,8 +2600,6 @@ TEST_CASE(if_then_empty_constant_test)
auto* then_mod = p.create_module("If_4_if");
then_mod->add_literal(migraphx::shape::int64_type);
migraphx::shape gen_shape(migraphx::shape(s.type(), {1}, {0}));
auto literal_ins = then_mod->add_literal(migraphx::literal(gen_shape, {0}));
auto unsqueeze_ins = then_mod->add_instruction(
......@@ -2636,9 +2635,6 @@ TEST_CASE(if_then_empty_constant_multi_output_test)
auto* then_mod = p.create_module("If_4_if");
then_mod->add_literal(migraphx::shape::int64_type);
then_mod->add_literal(migraphx::shape::int64_type);
migraphx::shape gen_shape(migraphx::shape(s.type(), {1}, {0}));
auto literal_ins = then_mod->add_literal(migraphx::literal(gen_shape, {0}));
......@@ -2691,8 +2687,6 @@ TEST_CASE(if_else_empty_constant_test)
auto* else_mod = p.create_module("If_4_else");
else_mod->add_literal(s.type());
migraphx::shape gen_shape(migraphx::shape(s.type(), {1}, {0}));
auto literal_ins = else_mod->add_literal(migraphx::literal(gen_shape, {0}));
......@@ -2731,9 +2725,6 @@ TEST_CASE(if_else_empty_constant_multi_output_test)
auto* else_mod = p.create_module("If_4_else");
else_mod->add_literal(migraphx::shape::int64_type);
else_mod->add_literal(migraphx::shape::int64_type);
migraphx::shape gen_shape(migraphx::shape(s.type(), {1}, {0}));
auto literal_ins = else_mod->add_literal(migraphx::literal(gen_shape, {0}));
auto unsqueeze_ins = else_mod->add_instruction(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment