Unverified Commit 98dfdf15 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Keep std shape (#1059)

This PR is the resolve two problems in the issue#999, i.e., non_standard_shape input to reshape and reduce_mean.
Three fixes:

Any operator that has a standard shape requirement will add a contiguous input for its input.
Eliminate_contiguous, when computing whether a contiguous can be removed, we should use all the updated args, not just the one that is being checked.
In two optimization in the simplify_reshape, we remove the contiguous in the reshaper name list, since eliminate_contiguous will remove the contiguous if it can be removed.
the solution is add an attribute to the operator that requires standard input shape, then in the auto_contiguous pass, add a contiguous to every input of such operators.
parent 9b829528
...@@ -10,8 +10,35 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,8 +10,35 @@ inline namespace MIGRAPHX_INLINE_NS {
void auto_contiguous::apply(module& p) const void auto_contiguous::apply(module& p) const
{ {
std::string key = "require_std_shape";
for(auto ins : reverse_iterator_for(p))
{
auto&& attr = ins->get_operator().attributes();
if((attr.get(key, false)))
{
auto args = ins->inputs();
auto new_args = args;
std::transform(args.begin(), args.end(), new_args.begin(), [&](auto in) {
if(in->name() == "contiguous")
{
return in;
}
return p.insert_instruction(ins, make_op("contiguous"), in);
});
if(new_args != args)
{
p.replace_instruction(ins, ins->get_operator(), new_args);
}
}
}
auto last = std::prev(p.end());
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
// for last instruction that is NOT a return
if(ins->outputs().empty() and ins != last)
continue;
shape s = ins->get_shape(); shape s = ins->get_shape();
if(not s.standard() and s.elements() != 0) if(not s.standard() and s.elements() != 0)
{ {
......
...@@ -78,15 +78,16 @@ void eliminate_contiguous::apply(module& p) const ...@@ -78,15 +78,16 @@ void eliminate_contiguous::apply(module& p) const
continue; continue;
// Make a copy so we can modify it while we iterate // Make a copy so we can modify it while we iterate
auto args = ins->inputs(); auto args = ins->inputs();
auto new_args = args;
auto mod_args = ins->module_inputs();
for(auto arg : ins->inputs()) for(auto arg : ins->inputs())
{ {
if(arg->name() == op_name) if(arg->name() == op_name)
{ {
auto new_args = args; auto prev = arg->inputs().front();
auto prev = arg->inputs().front();
replace(new_args, arg, prev); replace(new_args, arg, prev);
if(try_compute_shape(ins, new_args, ins->module_inputs())) if(try_compute_shape(ins, new_args, mod_args))
{ {
instruction::replace_argument(ins, arg, prev); instruction::replace_argument(ins, arg, prev);
} }
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp> #include <migraphx/lifetime.hpp>
#include <migraphx/value.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -26,6 +27,8 @@ struct reshape ...@@ -26,6 +27,8 @@ struct reshape
return pack(f(self.dims, "dims")); return pack(f(self.dims, "dims"));
} }
value attributes() const { return {{"require_std_shape", true}}; }
std::string name() const { return "reshape"; } std::string name() const { return "reshape"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
......
...@@ -70,12 +70,14 @@ static literal from_repeated(shape::type_t t, const T& r) ...@@ -70,12 +70,14 @@ static literal from_repeated(shape::type_t t, const T& r)
instruction_ref onnx_parser::node_info::make_contiguous(instruction_ref ins) const instruction_ref onnx_parser::node_info::make_contiguous(instruction_ref ins) const
{ {
if(ins->get_shape().standard()) auto attr = ins->get_operator().to_value();
std::string key = "require_std_shape";
if((attr.get(key, false)) or (not ins->get_shape().standard()))
{ {
return ins; return add_instruction(make_op("contiguous"), ins);
} }
return add_instruction(make_op("contiguous"), ins); return ins;
} }
instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_ref>& args, instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_ref>& args,
......
...@@ -135,4 +135,52 @@ TEST_CASE(two_transpose_gather) ...@@ -135,4 +135,52 @@ TEST_CASE(two_transpose_gather)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(standard_reshape)
{
migraphx::module m1;
{
auto data = m1.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
auto add = m1.add_instruction(migraphx::make_op("add"), data, data);
auto r = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 12, 5}}}), add);
m1.add_return({r});
}
run_pass(m1);
migraphx::module m2;
{
auto data = m2.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
auto add = m2.add_instruction(migraphx::make_op("add"), data, data);
auto ca = m2.add_instruction(migraphx::make_op("contiguous"), add);
auto r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 12, 5}}}), ca);
m2.add_return({r});
}
EXPECT(m1 == m2);
}
TEST_CASE(dead_instruction)
{
migraphx::module m1;
{
auto data = m1.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1, 3}}}), data);
auto r = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1, 3}}}),
data);
m1.add_return({r});
}
run_pass(m1);
migraphx::module m2;
{
auto data = m2.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1, 3}}}), data);
auto r = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1, 3}}}),
data);
auto cr = m2.add_instruction(migraphx::make_op("contiguous"), r);
m2.add_return({cr});
}
EXPECT(m1 == m2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment