Commit 632f819e authored by Paul's avatar Paul
Browse files

Preserve outputs

parent 8004868b
......@@ -2,6 +2,8 @@
#include <migraphx/module.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/eliminate_contiguous.hpp>
......@@ -10,6 +12,36 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template<class Predicate>
std::vector<instruction_ref> find_lasts(const module& m, Predicate pred)
{
std::vector<instruction_ref> result;
fix([&](auto self, auto ins) {
if (pred(ins))
{
result.push_back(ins);
return;
}
for(auto input:ins->inputs())
self(input);
})(std::prev(m.end()));
return result;
}
void preserve_output_layout(module& m)
{
std::vector<instruction_ref> outputs = find_lasts(m, [](auto ins) {
return ins->get_shape().lens().size() == 4;
});
for(auto output:outputs)
{
auto permutation = find_permutation(output->get_shape());
auto layout = m.insert_instruction(std::next(output), make_op("layout", {{"permutation", permutation}}), output);
m.replace_instruction(output, layout);
}
}
void transform_convolutions(module& m)
{
for(auto ins : iterator_for(m))
......@@ -30,6 +62,7 @@ void transform_convolutions(module& m)
void layout_nhwc::apply(module& m) const
{
preserve_output_layout(m);
transform_convolutions(m);
dead_code_elimination{}.apply(m);
eliminate_contiguous{"contiguous"}.apply(m);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment