Commit c4658e53 authored by Paul's avatar Paul
Browse files

Use module_pass_manager

parent c055fda8
......@@ -8,7 +8,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
struct module_pass_manager;
/**
* Transform convolutions to nhwc
......@@ -16,7 +16,7 @@ struct module;
struct layout_nhwc
{
std::string name() const { return "layout_nhwc"; }
void apply(module& m) const;
void apply(module_pass_manager& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -8,6 +8,7 @@
#include <migraphx/make_op.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -78,15 +79,15 @@ void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_
}
}
void layout_nhwc::apply(module& m) const
void layout_nhwc::apply(module_pass_manager& mpm) const
{
std::unordered_set<instruction_ref> output_layouts = preserve_output_layout(m);
transform_convolutions(m);
dead_code_elimination{}.apply(m);
eliminate_contiguous{"contiguous"}.apply(m);
dead_code_elimination{}.apply(m);
remove_layout(m, output_layouts);
dead_code_elimination{}.apply(m);
std::unordered_set<instruction_ref> output_layouts = preserve_output_layout(mpm.get_module());
transform_convolutions(mpm.get_module());
mpm.run_pass(dead_code_elimination{});
mpm.run_pass(eliminate_contiguous{"contiguous"});
mpm.run_pass(dead_code_elimination{});
remove_layout(mpm.get_module(), output_layouts);
mpm.run_pass(dead_code_elimination{});
}
} // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment