Commit 4eaf846f authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add an attribute input_standard shape for the reshape operator

parent 9c5f6324
#include <algorithm>
#include <migraphx/auto_contiguous.hpp> #include <migraphx/auto_contiguous.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
...@@ -10,13 +11,22 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,13 +11,22 @@ inline namespace MIGRAPHX_INLINE_NS {
void auto_contiguous::apply(module& p) const void auto_contiguous::apply(module& p) const
{ {
for(auto ins : iterator_for(p)) std::string key = "standard_input_shape";
for(auto ins : reverse_iterator_for(p))
{ {
shape s = ins->get_shape(); auto&& attr = ins->get_operator().attributes();
if(not s.standard() and s.elements() != 0) if((attr.contains(key) and attr.at(key).to<bool>()))
{ {
auto c = p.insert_instruction(std::next(ins), make_op("contiguous"), ins); auto args = ins->inputs();
p.replace_instruction(ins, c); auto new_args = args;
std::transform(args.begin(), args.end(), new_args.begin(), [&](auto in) {
return p.replace_instruction(ins, make_op("contiguous"), in);
});
if(new_args != args)
{
p.replace_instruction(ins, ins->get_operator(), new_args);
}
} }
} }
} }
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp> #include <migraphx/lifetime.hpp>
#include <migraphx/value.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -26,6 +27,11 @@ struct reshape ...@@ -26,6 +27,11 @@ struct reshape
return pack(f(self.dims, "dims")); return pack(f(self.dims, "dims"));
} }
value attributes() const
{
return {{"standard_input_shape", true}};
}
std::string name() const { return "reshape"; } std::string name() const { return "reshape"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
......
...@@ -70,7 +70,7 @@ static literal from_repeated(shape::type_t t, const T& r) ...@@ -70,7 +70,7 @@ static literal from_repeated(shape::type_t t, const T& r)
instruction_ref onnx_parser::node_info::make_contiguous(instruction_ref ins) const instruction_ref onnx_parser::node_info::make_contiguous(instruction_ref ins) const
{ {
if(ins->name() == "contiguous") if(ins->get_shape().standard())
{ {
return ins; return ins;
} }
......
...@@ -93,10 +93,10 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -93,10 +93,10 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{}, dead_code_elimination{},
eliminate_common_subexpression{}, eliminate_common_subexpression{},
dead_code_elimination{}, dead_code_elimination{},
auto_contiguous{},
simplify_algebra{}, simplify_algebra{},
simplify_reshapes{}, simplify_reshapes{},
simplify_algebra{}, simplify_algebra{},
auto_contiguous{},
simplify_reshapes{}, simplify_reshapes{},
propagate_constant{}, propagate_constant{},
dead_code_elimination{}, dead_code_elimination{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment