"src/targets/gpu/lowering.cpp" did not exist on "8e0fff81ab2707932903aca276eec2723a88c0cd"
Commit db4bc970 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Change the algorithm of eliminate_contiguous.

parent ce8139e5
......@@ -9,19 +9,54 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Checks whether `ins` can still compute a shape when fed `inputs`, and
// whether the resulting shape remains acceptable to the instructions that
// consume `ins`.
//
// Returns true as soon as the recomputed shape is standard. Otherwise every
// output of `ins` is re-checked recursively, with the new shape substituted
// for `ins` in that output's input list. Returns false when `ins` has no
// outputs while its new shape is non-standard, when any output fails the
// check, or when anything inside the try block throws.
//
// NOTE(review): this text is a rendered diff with the +/- markers stripped.
// The first signature below and the `compute_shape(op, args);` call appear to
// be the removed pre-change lines (they use `op`/`args`, which the second
// signature does not declare) - confirm against the repository.
bool try_compute_shape(const operation& op, const std::vector<instruction_ref>& args)
bool try_compute_shape(instruction_ref ins, const std::vector<shape>& inputs)
{
try
{
compute_shape(op, args);
// Recompute the output shape of `ins` for the candidate inputs.
shape new_shape = ins->get_operator().compute_shape(inputs);
// If the output shape is a standard shape, no need to try its output
if (new_shape.standard())
{
return true;
}
auto outputs = ins->outputs();
// If the current instruction has no output, the non-standard shape would
// be the last output shape, so the contiguous cannot be eliminated
if (outputs.empty())
{
return false;
}
// Every consumer of `ins` must accept the new, non-standard shape.
for (auto output : outputs)
{
auto args = output->inputs();
std::vector<shape> input_shapes;
for (auto arg : args)
{
// Use the recomputed shape in place of `ins`; other inputs keep their current shape.
input_shapes.push_back((arg == ins) ? new_shape : arg->get_shape());
}
if (!try_compute_shape(output, input_shapes))
{
return false;
}
}
}
// Any exception is treated as "this substitution is not valid".
catch(...)
{
return false;
}
return true;
}
// Convenience overload: converts the argument instructions into shapes with
// to_shapes() and forwards to the shape-based try_compute_shape.
bool try_compute_shape(instruction_ref ins, const std::vector<instruction_ref>& args)
{
    const auto arg_shapes = to_shapes(args);
    return try_compute_shape(ins, arg_shapes);
}
void eliminate_contiguous::apply(program& p) const
{
for(auto ins : iterator_for(p))
......@@ -44,7 +79,7 @@ void eliminate_contiguous::apply(program& p) const
auto new_args = args;
auto prev = arg->inputs().front();
replace(new_args, arg, prev);
if(try_compute_shape(ins->get_operator(), new_args))
if(try_compute_shape(ins, new_args))
{
instruction::replace_argument(ins, arg, prev);
}
......
......@@ -795,7 +795,7 @@ struct cpu_binary
std::string name() const { return op.name(); }
shape compute_shape(const std::vector<shape>& inputs) const
{
if(inputs.at(0) == inputs.at(1) and inputs.at(0).packed() and inputs.at(1).packed())
if(inputs.at(0) == inputs.at(1) and inputs.at(0).packed())
{
return inputs.at(0);
}
......@@ -811,7 +811,7 @@ struct cpu_binary
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
auto s1 = input1.get_shape();
auto s2 = input2.get_shape();
if(s1 == s2 and s1.packed() and s2.packed())
if(s1 == s2 and s1.packed())
{
std::transform(
input1.begin(), input1.end(), input2.begin(), output.begin(), op.fcn());
......
......@@ -70,7 +70,7 @@ struct binary_device : oper<Derived>
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(3);
if(inputs.at(0) == inputs.at(1) and inputs.at(0).packed() and inputs.at(1).packed())
if(inputs.at(0) == inputs.at(1) and inputs.at(0).packed())
{
return inputs.at(0);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment