"doc/vscode:/vscode.git/clone" did not exist on "d5bdfed0890eba7764691690a4aba6902e6fa02f"
Commit c297ce5f authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Fixes to handle constants

parent b6ca9b26
...@@ -43,7 +43,8 @@ struct parse_constant : op_parser<parse_constant> ...@@ -43,7 +43,8 @@ struct parse_constant : op_parser<parse_constant>
// return empty literal // return empty literal
if(v.get_shape().elements() == 0) if(v.get_shape().elements() == 0)
{ {
return info.add_literal(literal{v.get_shape().type()}); migraphx::shape empty_constant(v.get_shape().type(), {1}, {0});
return info.add_literal(literal{empty_constant, {0}});
} }
auto dim_size = info.attributes.at("value").t().dims_size(); auto dim_size = info.attributes.at("value").t().dims_size();
......
...@@ -68,7 +68,7 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape> ...@@ -68,7 +68,7 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
// empty input tensor, output is a scalar // empty input tensor, output is a scalar
if(args[0]->get_shape().elements() == 0) if(args[0]->get_shape().elements() == 0)
{ {
s = migraphx::shape{type, {1}, {0}}; s = migraphx::shape{type, {1}, {}};
} }
else else
{ {
...@@ -84,8 +84,16 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape> ...@@ -84,8 +84,16 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
l_val.visit([&](auto val) { l_val.visit([&](auto val) {
using val_type = std::remove_cv_t<typename decltype(val)::value_type>; using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
// l_val contains only one element // l_val contains only one element
if(s.elements() > 0)
{
std::vector<val_type> out_vec(s.elements(), val.front()); std::vector<val_type> out_vec(s.elements(), val.front());
l_out = literal(s, out_vec); l_out = literal(s, out_vec);
}
else
{
std::vector<val_type> out_vec{val.front()};
l_out = literal(s, out_vec);
}
}); });
return info.add_literal(l_out); return info.add_literal(l_out);
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include <migraphx/reduce_dims.hpp> #include <migraphx/reduce_dims.hpp>
#include <algorithm> #include <algorithm>
#include <migraphx/stringutils.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace onnx { namespace onnx {
...@@ -92,7 +93,9 @@ struct parse_if : op_parser<parse_if> ...@@ -92,7 +93,9 @@ struct parse_if : op_parser<parse_if>
auto throw_shapes = [&]() { auto throw_shapes = [&]() {
MIGRAPHX_THROW("PARSE_IF: " + info.name + MIGRAPHX_THROW("PARSE_IF: " + info.name +
" then and else sub_graphs must have compatible shapes "); " then and else sub_graphs must have compatible shapes " +
to_string_range(then_out_shapes) + " vs " +
to_string_range(else_out_shapes));
}; };
if(then_out_shapes.size() != else_out_shapes.size()) if(then_out_shapes.size() != else_out_shapes.size())
...@@ -126,16 +129,14 @@ struct parse_if : op_parser<parse_if> ...@@ -126,16 +129,14 @@ struct parse_if : op_parser<parse_if>
assert(not(then_lens.empty() and else_lens.empty())); assert(not(then_lens.empty() and else_lens.empty()));
auto handle_empty_branch = [](module_ref& mdl, int index, const shape& out_shape) { auto handle_empty_branch = [](module_ref& mdl, int index, const shape& out_shape) {
shape gen_shape(shape(out_shape.type(), {1}, {0})); auto scalar_ins =
auto literal_ins = mdl->add_literal(literal(gen_shape, {0})); mdl->insert_instruction(std::prev(mdl->end()),
auto unsqueeze_ins = mdl->insert_instruction( make_op("scalar", {{"out_lens", out_shape.lens()}}),
std::prev(mdl->end()), std::prev(mdl->end()));
make_op("scalar", {{"scalar_bcst_dims", out_shape.lens()}}),
literal_ins);
auto broad_ins = mdl->insert_instruction( auto broad_ins = mdl->insert_instruction(
std::prev(mdl->end()), std::prev(mdl->end()),
make_op("multibroadcast", {{"out_lens", out_shape.lens()}}), make_op("multibroadcast", {{"out_lens", out_shape.lens()}}),
unsqueeze_ins); scalar_ins);
auto contig_out = mdl->insert_instruction( auto contig_out = mdl->insert_instruction(
std::prev(mdl->end()), make_op("contiguous"), broad_ins); std::prev(mdl->end()), make_op("contiguous"), broad_ins);
mdl->replace_instruction(std::prev(mdl->end())->inputs().at(index), contig_out); mdl->replace_instruction(std::prev(mdl->end())->inputs().at(index), contig_out);
...@@ -144,11 +145,12 @@ struct parse_if : op_parser<parse_if> ...@@ -144,11 +145,12 @@ struct parse_if : op_parser<parse_if>
// Handle one empty branch by setting output identical to the other // Handle one empty branch by setting output identical to the other
// need to update the then_shape before we do further checks // need to update the then_shape before we do further checks
if(then_lens.empty())
if(then_out_shape.strides().empty())
{ {
then_lens = handle_empty_branch(then_mdl, i, else_out_shape); then_lens = handle_empty_branch(then_mdl, i, else_out_shape);
} }
else if(else_lens.empty()) else if(else_out_shape.strides().empty())
{ {
else_lens = handle_empty_branch(else_mdl, i, then_out_shape); else_lens = handle_empty_branch(else_mdl, i, then_out_shape);
} }
...@@ -183,6 +185,9 @@ struct parse_if : op_parser<parse_if> ...@@ -183,6 +185,9 @@ struct parse_if : op_parser<parse_if>
} }
} }
then_mdl->debug_print();
else_mdl->debug_print();
auto if_ret = info.add_instruction(make_op("if"), args, {then_mdl, else_mdl}); auto if_ret = info.add_instruction(make_op("if"), args, {then_mdl, else_mdl});
auto out_s = if_ret->get_shape(); auto out_s = if_ret->get_shape();
assert(out_s.type() == shape::tuple_type); assert(out_s.type() == shape::tuple_type);
......
...@@ -761,7 +761,7 @@ TEST_CASE(constant_empty_scalar_int64_test) ...@@ -761,7 +761,7 @@ TEST_CASE(constant_empty_scalar_int64_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
mm->add_literal(migraphx::literal{migraphx::shape::int64_type}); mm->add_literal(migraphx::literal{migraphx::shape::int64_type, {0}});
auto prog = optimize_onnx("constant_empty_scalar_int64_test.onnx"); auto prog = optimize_onnx("constant_empty_scalar_int64_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -781,8 +781,8 @@ TEST_CASE(const_of_shape_empty_input_test) ...@@ -781,8 +781,8 @@ TEST_CASE(const_of_shape_empty_input_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
mm->add_literal(migraphx::literal(migraphx::shape::int32_type)); mm->add_literal(migraphx::literal(migraphx::shape::int32_type, {0}));
migraphx::shape s(migraphx::shape::int64_type, {1}, {0}); migraphx::shape s(migraphx::shape::int64_type, {1});
std::vector<int64_t> vec(s.elements(), 10); std::vector<int64_t> vec(s.elements(), 10);
mm->add_literal(migraphx::literal(s, vec)); mm->add_literal(migraphx::literal(s, vec));
...@@ -2425,17 +2425,18 @@ TEST_CASE(if_literal_test) ...@@ -2425,17 +2425,18 @@ TEST_CASE(if_literal_test)
auto cond = mm->add_parameter("cond", cond_s); auto cond = mm->add_parameter("cond", cond_s);
migraphx::shape s{migraphx::shape::float_type, {5}}; migraphx::shape s{migraphx::shape::float_type, {5}};
migraphx::shape empty_const(migraphx::shape::float_type, {1}, {0});
auto* then_mod = p.create_module("If_1_if"); auto* then_mod = p.create_module("If_1_if");
std::vector<float> data1 = {1, 2, 3, 4, 5}; std::vector<float> data1 = {1, 2, 3, 4, 5};
auto l1 = then_mod->add_literal(migraphx::literal(s, data1)); auto l1 = then_mod->add_literal(migraphx::literal(s, data1));
then_mod->add_literal({}); then_mod->add_literal({empty_const, {0}});
then_mod->add_return({l1}); then_mod->add_return({l1});
auto* else_mod = p.create_module("If_1_else"); auto* else_mod = p.create_module("If_1_else");
std::vector<float> data2 = {5, 4, 3, 2, 1}; std::vector<float> data2 = {5, 4, 3, 2, 1};
auto l2 = else_mod->add_literal(migraphx::literal(s, data2)); auto l2 = else_mod->add_literal(migraphx::literal(s, data2));
else_mod->add_literal({}); else_mod->add_literal({empty_const, {0}});
else_mod->add_return({l2}); else_mod->add_return({l2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
...@@ -2599,8 +2600,6 @@ TEST_CASE(if_then_empty_constant_test) ...@@ -2599,8 +2600,6 @@ TEST_CASE(if_then_empty_constant_test)
auto* then_mod = p.create_module("If_4_if"); auto* then_mod = p.create_module("If_4_if");
then_mod->add_literal(migraphx::shape::int64_type);
migraphx::shape gen_shape(migraphx::shape(s.type(), {1}, {0})); migraphx::shape gen_shape(migraphx::shape(s.type(), {1}, {0}));
auto literal_ins = then_mod->add_literal(migraphx::literal(gen_shape, {0})); auto literal_ins = then_mod->add_literal(migraphx::literal(gen_shape, {0}));
auto unsqueeze_ins = then_mod->add_instruction( auto unsqueeze_ins = then_mod->add_instruction(
...@@ -2636,9 +2635,6 @@ TEST_CASE(if_then_empty_constant_multi_output_test) ...@@ -2636,9 +2635,6 @@ TEST_CASE(if_then_empty_constant_multi_output_test)
auto* then_mod = p.create_module("If_4_if"); auto* then_mod = p.create_module("If_4_if");
then_mod->add_literal(migraphx::shape::int64_type);
then_mod->add_literal(migraphx::shape::int64_type);
migraphx::shape gen_shape(migraphx::shape(s.type(), {1}, {0})); migraphx::shape gen_shape(migraphx::shape(s.type(), {1}, {0}));
auto literal_ins = then_mod->add_literal(migraphx::literal(gen_shape, {0})); auto literal_ins = then_mod->add_literal(migraphx::literal(gen_shape, {0}));
...@@ -2691,8 +2687,6 @@ TEST_CASE(if_else_empty_constant_test) ...@@ -2691,8 +2687,6 @@ TEST_CASE(if_else_empty_constant_test)
auto* else_mod = p.create_module("If_4_else"); auto* else_mod = p.create_module("If_4_else");
else_mod->add_literal(s.type());
migraphx::shape gen_shape(migraphx::shape(s.type(), {1}, {0})); migraphx::shape gen_shape(migraphx::shape(s.type(), {1}, {0}));
auto literal_ins = else_mod->add_literal(migraphx::literal(gen_shape, {0})); auto literal_ins = else_mod->add_literal(migraphx::literal(gen_shape, {0}));
...@@ -2731,9 +2725,6 @@ TEST_CASE(if_else_empty_constant_multi_output_test) ...@@ -2731,9 +2725,6 @@ TEST_CASE(if_else_empty_constant_multi_output_test)
auto* else_mod = p.create_module("If_4_else"); auto* else_mod = p.create_module("If_4_else");
else_mod->add_literal(migraphx::shape::int64_type);
else_mod->add_literal(migraphx::shape::int64_type);
migraphx::shape gen_shape(migraphx::shape(s.type(), {1}, {0})); migraphx::shape gen_shape(migraphx::shape(s.type(), {1}, {0}));
auto literal_ins = else_mod->add_literal(migraphx::literal(gen_shape, {0})); auto literal_ins = else_mod->add_literal(migraphx::literal(gen_shape, {0}));
auto unsqueeze_ins = else_mod->add_instruction( auto unsqueeze_ins = else_mod->add_instruction(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment