Commit 7ed60279 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

testing

parent 30c49503
......@@ -46,6 +46,7 @@ add_library(migraphx
eliminate_contiguous.cpp
eliminate_data_type.cpp
eliminate_identity.cpp
eliminate_layout.cpp
eliminate_pad.cpp
env.cpp
file_buffer.cpp
......
......@@ -36,36 +36,36 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class Predicate>
std::vector<instruction_ref> find_lasts(const module& m, Predicate pred)
{
std::vector<instruction_ref> result;
fix([&](auto self, auto ins) {
if(pred(ins))
{
result.push_back(ins);
return;
}
for(auto input : ins->inputs())
self(input);
})(std::prev(m.end()));
return result;
}
// template <class Predicate>
// std::vector<instruction_ref> find_lasts(const module& m, Predicate pred)
// {
// std::vector<instruction_ref> result;
// fix([&](auto self, auto ins) {
// if(pred(ins))
// {
// result.push_back(ins);
// return;
// }
// for(auto input : ins->inputs())
// self(input);
// })(std::prev(m.end()));
// return result;
// }
std::unordered_set<instruction_ref> preserve_output_layout(module& m)
{
std::unordered_set<instruction_ref> result;
std::vector<instruction_ref> outputs =
find_lasts(m, [](auto ins) { return ins->get_shape().lens().size() == 4; });
for(auto output : outputs)
{
auto permutation = find_permutation(output->get_shape());
auto layout = m.insert_instruction(
std::next(output), make_op("layout", {{"permutation", permutation}}), output);
result.insert(m.replace_instruction(output, layout));
}
return result;
}
// std::unordered_set<instruction_ref> preserve_output_layout(module& m)
// {
// std::unordered_set<instruction_ref> result;
// std::vector<instruction_ref> outputs =
// find_lasts(m, [](auto ins) { return ins->get_shape().lens().size() == 4; });
// for(auto output : outputs)
// {
// auto permutation = find_permutation(output->get_shape());
// auto layout = m.insert_instruction(
// std::next(output), make_op("layout", {{"permutation", permutation}}), output);
// result.insert(m.replace_instruction(output, layout));
// }
// return result;
// }
void transform_convolutions(module& m, bool skip_elim_contiguous)
{
......@@ -108,30 +108,30 @@ void transform_convolutions(module& m, bool skip_elim_contiguous)
}
}
void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_layouts)
{
for(auto ins : iterator_for(m))
{
if(ins->name() != "layout")
continue;
if(ins->get_shape() != ins->inputs().front()->get_shape())
continue;
if(contains(output_layouts, ins))
continue;
m.replace_instruction(ins, ins->inputs().front());
}
}
// void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_layouts)
// {
// for(auto ins : iterator_for(m))
// {
// if(ins->name() != "layout")
// continue;
// if(ins->get_shape() != ins->inputs().front()->get_shape())
// continue;
// if(contains(output_layouts, ins))
// continue;
// m.replace_instruction(ins, ins->inputs().front());
// }
// }
void layout_nhwc::apply(module_pass_manager& mpm) const
{
std::unordered_set<instruction_ref> output_layouts = preserve_output_layout(mpm.get_module());
// std::unordered_set<instruction_ref> output_layouts = preserve_output_layout(mpm.get_module());
transform_convolutions(mpm.get_module(), this->skip_elim_contiguous);
mpm.run_pass(dead_code_elimination{});
if(not this->skip_elim_contiguous)
mpm.run_pass(eliminate_contiguous{"contiguous"});
mpm.run_pass(dead_code_elimination{});
remove_layout(mpm.get_module(), output_layouts);
mpm.run_pass(dead_code_elimination{});
// remove_layout(mpm.get_module(), output_layouts);
// mpm.run_pass(dead_code_elimination{});
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -31,6 +31,7 @@
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/eliminate_data_type.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_layout.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/inline_module.hpp>
......@@ -125,7 +126,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
prefuse_ops{},
dead_code_elimination{},
auto_contiguous{},
enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{true}),
// enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{true}),
simplify_reshapes{},
propagate_constant{},
dead_code_elimination{},
......@@ -134,6 +135,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
fuse_mlir{&ctx},
dead_code_elimination{},
lowering{&ctx, options.offload_copy},
enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), eliminate_layout{}),
eliminate_contiguous{"gpu::contiguous"},
dead_code_elimination{},
eliminate_concat{concat_gpu_optimization{}},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment