"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "fc9e0d9dab0c0cfc14241b990c1cce6ea7a750c0"
Commit 5e36c210 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

minor additional changes for improving the eliminate contiguous.

parent 6626861d
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
bool try_compute_shape(instruction_ref ins, const std::vector<shape>& inputs) static bool try_compute_shape(instruction_ref ins, const std::vector<shape>& inputs)
{ {
try try
{ {
...@@ -21,8 +21,10 @@ bool try_compute_shape(instruction_ref ins, const std::vector<shape>& inputs) ...@@ -21,8 +21,10 @@ bool try_compute_shape(instruction_ref ins, const std::vector<shape>& inputs)
} }
auto outputs = ins->outputs(); auto outputs = ins->outputs();
// If the current instruction has no output, it means the last output shape // If the current instruction has no output, it means it is the last
// is non-standard, then we cannot eliminate the contiguous // instruction and generates a non-standard output. But for unary
// and binary operators, we can still remove it and reshape the output
// to be standard since these operator can handle non-standard inputs
if(outputs.empty()) if(outputs.empty())
{ {
return true; return true;
...@@ -51,7 +53,7 @@ bool try_compute_shape(instruction_ref ins, const std::vector<shape>& inputs) ...@@ -51,7 +53,7 @@ bool try_compute_shape(instruction_ref ins, const std::vector<shape>& inputs)
return true; return true;
} }
bool try_compute_shape(instruction_ref ins, const std::vector<instruction_ref>& args) static bool try_compute_shape(instruction_ref ins, const std::vector<instruction_ref>& args)
{ {
auto inputs = to_shapes(args); auto inputs = to_shapes(args);
return try_compute_shape(ins, inputs); return try_compute_shape(ins, inputs);
......
...@@ -21,14 +21,7 @@ struct binary ...@@ -21,14 +21,7 @@ struct binary
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(2).same_type().same_dims(); check_shapes{inputs}.has(2).same_type().same_dims();
if(inputs.at(0) == inputs.at(1) and inputs.at(0).packed() and inputs.at(1).packed()) return {inputs.at(0).type(), inputs.at(0).lens()};
{
return inputs.at(0);
}
else
{
return {inputs.at(0).type(), inputs.at(0).lens()};
}
} }
}; };
......
...@@ -21,14 +21,7 @@ struct unary ...@@ -21,14 +21,7 @@ struct unary
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1);
if(inputs.front().packed()) return {inputs.at(0).type(), inputs.at(0).lens()};
{
return inputs.at(0);
}
else
{
return {inputs.at(0).type(), inputs.at(0).lens()};
}
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment