Commit 39cd4850 authored by Paul's avatar Paul
Browse files

Dont remove layouts in simplify_reshapes

parent cf3dbd36
...@@ -28,8 +28,9 @@ std::vector<instruction_ref> find_lasts(const module& m, Predicate pred) ...@@ -28,8 +28,9 @@ std::vector<instruction_ref> find_lasts(const module& m, Predicate pred)
return result; return result;
} }
void preserve_output_layout(module& m) std::unordered_set<instruction_ref> preserve_output_layout(module& m)
{ {
std::unordered_set<instruction_ref> result;
std::vector<instruction_ref> outputs = std::vector<instruction_ref> outputs =
find_lasts(m, [](auto ins) { return ins->get_shape().lens().size() == 4; }); find_lasts(m, [](auto ins) { return ins->get_shape().lens().size() == 4; });
for(auto output : outputs) for(auto output : outputs)
...@@ -37,8 +38,9 @@ void preserve_output_layout(module& m) ...@@ -37,8 +38,9 @@ void preserve_output_layout(module& m)
auto permutation = find_permutation(output->get_shape()); auto permutation = find_permutation(output->get_shape());
auto layout = m.insert_instruction( auto layout = m.insert_instruction(
std::next(output), make_op("layout", {{"permutation", permutation}}), output); std::next(output), make_op("layout", {{"permutation", permutation}}), output);
m.replace_instruction(output, layout); result.insert(m.replace_instruction(output, layout));
} }
return result;
} }
void transform_convolutions(module& m) void transform_convolutions(module& m)
...@@ -59,12 +61,29 @@ void transform_convolutions(module& m) ...@@ -59,12 +61,29 @@ void transform_convolutions(module& m)
} }
} }
void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_layouts)
{
for(auto ins : iterator_for(m))
{
if(ins->name() != "layout")
continue;
if (ins->get_shape() != ins->inputs().front()->get_shape())
continue;
if (contains(output_layouts, ins))
continue;
m.replace_instruction(ins, ins->inputs().front());
}
}
void layout_nhwc::apply(module& m) const void layout_nhwc::apply(module& m) const
{ {
preserve_output_layout(m); std::unordered_set<instruction_ref> output_layouts = preserve_output_layout(m);
transform_convolutions(m); transform_convolutions(m);
dead_code_elimination{}.apply(m); dead_code_elimination{}.apply(m);
eliminate_contiguous{"contiguous"}.apply(m); eliminate_contiguous{"contiguous"}.apply(m);
dead_code_elimination{}.apply(m);
remove_layout(m, output_layouts);
dead_code_elimination{}.apply(m);
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -110,7 +110,6 @@ struct find_nop_reshapes ...@@ -110,7 +110,6 @@ struct find_nop_reshapes
reshapes.insert("broadcast"); reshapes.insert("broadcast");
reshapes.insert("concat"); reshapes.insert("concat");
reshapes.insert("convert"); reshapes.insert("convert");
reshapes.insert("layout");
reshapes.insert("multibroadcast"); reshapes.insert("multibroadcast");
reshapes.insert("pad"); reshapes.insert("pad");
reshapes.insert("slice"); reshapes.insert("slice");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment