Commit db4bc970 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Change the algorithm of the eliminate_contiguous pass.

parent ce8139e5
...@@ -9,19 +9,54 @@ ...@@ -9,19 +9,54 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
bool try_compute_shape(const operation& op, const std::vector<instruction_ref>& args) bool try_compute_shape(instruction_ref ins, const std::vector<shape>& inputs)
{ {
try try
{ {
compute_shape(op, args); shape new_shape = ins->get_operator().compute_shape(inputs);
// If the output shape is standard, there is no need to check this instruction's outputs
if (new_shape.standard())
{
return true;
}
auto outputs = ins->outputs();
// If the current instruction has no outputs, the final output shape would be
// non-standard, so the contiguous cannot be eliminated
if (outputs.empty())
{
return false;
}
for (auto output : outputs)
{
auto args = output->inputs();
std::vector<shape> input_shapes;
for (auto arg : args)
{
input_shapes.push_back((arg == ins) ? new_shape : arg->get_shape());
}
if (!try_compute_shape(output, input_shapes))
{
return false;
}
}
} }
catch(...) catch(...)
{ {
return false; return false;
} }
return true; return true;
} }
bool try_compute_shape(instruction_ref ins, const std::vector<instruction_ref>& args)
{
auto inputs = to_shapes(args);
return try_compute_shape(ins, inputs);
}
void eliminate_contiguous::apply(program& p) const void eliminate_contiguous::apply(program& p) const
{ {
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
...@@ -44,7 +79,7 @@ void eliminate_contiguous::apply(program& p) const ...@@ -44,7 +79,7 @@ void eliminate_contiguous::apply(program& p) const
auto new_args = args; auto new_args = args;
auto prev = arg->inputs().front(); auto prev = arg->inputs().front();
replace(new_args, arg, prev); replace(new_args, arg, prev);
if(try_compute_shape(ins->get_operator(), new_args)) if(try_compute_shape(ins, new_args))
{ {
instruction::replace_argument(ins, arg, prev); instruction::replace_argument(ins, arg, prev);
} }
......
...@@ -795,7 +795,7 @@ struct cpu_binary ...@@ -795,7 +795,7 @@ struct cpu_binary
std::string name() const { return op.name(); } std::string name() const { return op.name(); }
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
if(inputs.at(0) == inputs.at(1) and inputs.at(0).packed() and inputs.at(1).packed()) if(inputs.at(0) == inputs.at(1) and inputs.at(0).packed())
{ {
return inputs.at(0); return inputs.at(0);
} }
...@@ -811,7 +811,7 @@ struct cpu_binary ...@@ -811,7 +811,7 @@ struct cpu_binary
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) { visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
auto s1 = input1.get_shape(); auto s1 = input1.get_shape();
auto s2 = input2.get_shape(); auto s2 = input2.get_shape();
if(s1 == s2 and s1.packed() and s2.packed()) if(s1 == s2 and s1.packed())
{ {
std::transform( std::transform(
input1.begin(), input1.end(), input2.begin(), output.begin(), op.fcn()); input1.begin(), input1.end(), input2.begin(), output.begin(), op.fcn());
......
...@@ -70,7 +70,7 @@ struct binary_device : oper<Derived> ...@@ -70,7 +70,7 @@ struct binary_device : oper<Derived>
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(3); check_shapes{inputs, *this}.has(3);
if(inputs.at(0) == inputs.at(1) and inputs.at(0).packed() and inputs.at(1).packed()) if(inputs.at(0) == inputs.at(1) and inputs.at(0).packed())
{ {
return inputs.at(0); return inputs.at(0);
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment