Commit 7ed60279 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

testing

parent 30c49503
...@@ -46,6 +46,7 @@ add_library(migraphx ...@@ -46,6 +46,7 @@ add_library(migraphx
eliminate_contiguous.cpp eliminate_contiguous.cpp
eliminate_data_type.cpp eliminate_data_type.cpp
eliminate_identity.cpp eliminate_identity.cpp
eliminate_layout.cpp
eliminate_pad.cpp eliminate_pad.cpp
env.cpp env.cpp
file_buffer.cpp file_buffer.cpp
......
...@@ -36,36 +36,36 @@ ...@@ -36,36 +36,36 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class Predicate> // template <class Predicate>
std::vector<instruction_ref> find_lasts(const module& m, Predicate pred) // std::vector<instruction_ref> find_lasts(const module& m, Predicate pred)
{ // {
std::vector<instruction_ref> result; // std::vector<instruction_ref> result;
fix([&](auto self, auto ins) { // fix([&](auto self, auto ins) {
if(pred(ins)) // if(pred(ins))
{ // {
result.push_back(ins); // result.push_back(ins);
return; // return;
} // }
for(auto input : ins->inputs()) // for(auto input : ins->inputs())
self(input); // self(input);
})(std::prev(m.end())); // })(std::prev(m.end()));
return result; // return result;
} // }
std::unordered_set<instruction_ref> preserve_output_layout(module& m) // std::unordered_set<instruction_ref> preserve_output_layout(module& m)
{ // {
std::unordered_set<instruction_ref> result; // std::unordered_set<instruction_ref> result;
std::vector<instruction_ref> outputs = // std::vector<instruction_ref> outputs =
find_lasts(m, [](auto ins) { return ins->get_shape().lens().size() == 4; }); // find_lasts(m, [](auto ins) { return ins->get_shape().lens().size() == 4; });
for(auto output : outputs) // for(auto output : outputs)
{ // {
auto permutation = find_permutation(output->get_shape()); // auto permutation = find_permutation(output->get_shape());
auto layout = m.insert_instruction( // auto layout = m.insert_instruction(
std::next(output), make_op("layout", {{"permutation", permutation}}), output); // std::next(output), make_op("layout", {{"permutation", permutation}}), output);
result.insert(m.replace_instruction(output, layout)); // result.insert(m.replace_instruction(output, layout));
} // }
return result; // return result;
} // }
void transform_convolutions(module& m, bool skip_elim_contiguous) void transform_convolutions(module& m, bool skip_elim_contiguous)
{ {
...@@ -108,30 +108,30 @@ void transform_convolutions(module& m, bool skip_elim_contiguous) ...@@ -108,30 +108,30 @@ void transform_convolutions(module& m, bool skip_elim_contiguous)
} }
} }
void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_layouts) // void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_layouts)
{ // {
for(auto ins : iterator_for(m)) // for(auto ins : iterator_for(m))
{ // {
if(ins->name() != "layout") // if(ins->name() != "layout")
continue; // continue;
if(ins->get_shape() != ins->inputs().front()->get_shape()) // if(ins->get_shape() != ins->inputs().front()->get_shape())
continue; // continue;
if(contains(output_layouts, ins)) // if(contains(output_layouts, ins))
continue; // continue;
m.replace_instruction(ins, ins->inputs().front()); // m.replace_instruction(ins, ins->inputs().front());
} // }
} // }
void layout_nhwc::apply(module_pass_manager& mpm) const void layout_nhwc::apply(module_pass_manager& mpm) const
{ {
std::unordered_set<instruction_ref> output_layouts = preserve_output_layout(mpm.get_module()); // std::unordered_set<instruction_ref> output_layouts = preserve_output_layout(mpm.get_module());
transform_convolutions(mpm.get_module(), this->skip_elim_contiguous); transform_convolutions(mpm.get_module(), this->skip_elim_contiguous);
mpm.run_pass(dead_code_elimination{}); mpm.run_pass(dead_code_elimination{});
if(not this->skip_elim_contiguous) if(not this->skip_elim_contiguous)
mpm.run_pass(eliminate_contiguous{"contiguous"}); mpm.run_pass(eliminate_contiguous{"contiguous"});
mpm.run_pass(dead_code_elimination{}); mpm.run_pass(dead_code_elimination{});
remove_layout(mpm.get_module(), output_layouts); // remove_layout(mpm.get_module(), output_layouts);
mpm.run_pass(dead_code_elimination{}); // mpm.run_pass(dead_code_elimination{});
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include <migraphx/eliminate_contiguous.hpp> #include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/eliminate_data_type.hpp> #include <migraphx/eliminate_data_type.hpp>
#include <migraphx/eliminate_identity.hpp> #include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_layout.hpp>
#include <migraphx/eliminate_pad.hpp> #include <migraphx/eliminate_pad.hpp>
#include <migraphx/fuse_pointwise.hpp> #include <migraphx/fuse_pointwise.hpp>
#include <migraphx/inline_module.hpp> #include <migraphx/inline_module.hpp>
...@@ -125,7 +126,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -125,7 +126,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
prefuse_ops{}, prefuse_ops{},
dead_code_elimination{}, dead_code_elimination{},
auto_contiguous{}, auto_contiguous{},
enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{true}), // enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{true}),
simplify_reshapes{}, simplify_reshapes{},
propagate_constant{}, propagate_constant{},
dead_code_elimination{}, dead_code_elimination{},
...@@ -134,6 +135,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -134,6 +135,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
fuse_mlir{&ctx}, fuse_mlir{&ctx},
dead_code_elimination{}, dead_code_elimination{},
lowering{&ctx, options.offload_copy}, lowering{&ctx, options.offload_copy},
enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), eliminate_layout{}),
eliminate_contiguous{"gpu::contiguous"}, eliminate_contiguous{"gpu::contiguous"},
dead_code_elimination{}, dead_code_elimination{},
eliminate_concat{concat_gpu_optimization{}}, eliminate_concat{concat_gpu_optimization{}},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment