Commit 90967138 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

update to remove extra contiguous by requires_std_shape, update tests

parent e7e73c8c
...@@ -61,8 +61,14 @@ void auto_contiguous::apply(module& m) const ...@@ -61,8 +61,14 @@ void auto_contiguous::apply(module& m) const
{ {
if(contains({"layout", "contiguous", "@return", "@param", "@outline"}, ins->name())) if(contains({"layout", "contiguous", "@return", "@param", "@outline"}, ins->name()))
continue; continue;
auto outputs = ins->outputs();
// for last instruction that is NOT a return // for last instruction that is NOT a return
if(ins->outputs().empty() and ins != last) if(outputs.empty() and ins != last)
continue;
if(not outputs.empty())
// if contiguous was already inserted, skip
if(std::all_of(outputs.begin(), outputs.end(), [](auto output) {
return output->name() == "contiguous";}))
continue; continue;
shape s = ins->get_shape(); shape s = ins->get_shape();
if(s.dynamic()) if(s.dynamic())
...@@ -72,9 +78,9 @@ void auto_contiguous::apply(module& m) const ...@@ -72,9 +78,9 @@ void auto_contiguous::apply(module& m) const
if(s.standard() and ins->name() == "@literal") if(s.standard() and ins->name() == "@literal")
continue; continue;
if(s.scalar() and not contains(ins->name(), "broadcast")) if(s.scalar() and not contains(ins->name(), "broadcast"))
{
continue; continue;
}
auto c = m.insert_instruction(std::next(ins), make_op("contiguous"), ins); auto c = m.insert_instruction(std::next(ins), make_op("contiguous"), ins);
m.replace_instruction(ins, c); m.replace_instruction(ins, c);
} }
......
...@@ -179,7 +179,8 @@ TEST_CASE(standard_reshape_lazy) ...@@ -179,7 +179,8 @@ TEST_CASE(standard_reshape_lazy)
auto ca = m2.add_instruction(migraphx::make_op("contiguous"), add); auto ca = m2.add_instruction(migraphx::make_op("contiguous"), add);
auto r = auto r =
m2.add_instruction(migraphx::make_op("reshape_lazy", {{"dims", {2, 1, 12, 5}}}), ca); m2.add_instruction(migraphx::make_op("reshape_lazy", {{"dims", {2, 1, 12, 5}}}), ca);
m2.add_return({r}); auto cr = m2.add_instruction(migraphx::make_op("contiguous"), r);
m2.add_return({cr});
} }
EXPECT(m1 == m2); EXPECT(m1 == m2);
...@@ -201,9 +202,7 @@ TEST_CASE(standard_reshape) ...@@ -201,9 +202,7 @@ TEST_CASE(standard_reshape)
auto data = m2.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}}); auto data = m2.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
auto add = m2.add_instruction(migraphx::make_op("add"), data, data); auto add = m2.add_instruction(migraphx::make_op("add"), data, data);
auto ca = m2.add_instruction(migraphx::make_op("contiguous"), add); auto ca = m2.add_instruction(migraphx::make_op("contiguous"), add);
// extra contiguous coming from reshape logic which has "requires_std_shape" attribute auto r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 12, 5}}}), ca);
auto cb = m2.add_instruction(migraphx::make_op("contiguous"), ca);
auto r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 12, 5}}}), cb);
auto cr = m2.add_instruction(migraphx::make_op("contiguous"), r); auto cr = m2.add_instruction(migraphx::make_op("contiguous"), r);
m2.add_return({cr}); m2.add_return({cr});
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment