Commit 8b1b6a5e authored by Khalique Ahmed
Browse files

Merge branch 'nhwc_workaround' of...

Merge branch 'nhwc_workaround' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into nhwc_workaround
parents 35dcd00d 01b5ff4e
...@@ -76,8 +76,6 @@ void auto_contiguous::apply(module& m) const ...@@ -76,8 +76,6 @@ void auto_contiguous::apply(module& m) const
{ {
continue; continue;
} }
if(ins->name() == "pointwise")
std::cout << "HERE" << std::endl;
auto c = m.insert_instruction(std::next(ins), make_op("contiguous"), ins); auto c = m.insert_instruction(std::next(ins), make_op("contiguous"), ins);
m.replace_instruction(ins, c); m.replace_instruction(ins, c);
} }
......
...@@ -81,7 +81,7 @@ static bool try_compute_shape(instruction_ref ins, ...@@ -81,7 +81,7 @@ static bool try_compute_shape(instruction_ref ins,
return (arg == ins) ? new_shape : arg->get_shape(); return (arg == ins) ? new_shape : arg->get_shape();
}); });
if(not try_compute_shape(output, input_shapes, mods)) if(not try_compute_shape(output, input_shapes, output->module_inputs()))
{ {
return false; return false;
} }
......
...@@ -40,75 +40,51 @@ ...@@ -40,75 +40,51 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class Predicate> // template <class Predicate>
std::vector<instruction_ref> find_lasts(const module& m, Predicate pred) // std::vector<instruction_ref> find_lasts(const module& m, Predicate pred)
{ // {
std::vector<instruction_ref> result; // std::vector<instruction_ref> result;
fix([&](auto self, auto ins) { // fix([&](auto self, auto ins) {
if(pred(ins)) // if(pred(ins))
{ // {
result.push_back(ins); // result.push_back(ins);
return; // return;
} // }
for(auto input : ins->inputs()) // for(auto input : ins->inputs())
self(input); // self(input);
})(std::prev(m.end())); // })(std::prev(m.end()));
return result; // return result;
} // }
std::unordered_set<instruction_ref> preserve_output_layout(module& m)
{
std::unordered_set<instruction_ref> result;
std::vector<instruction_ref> outputs =
find_lasts(m, [](auto ins) { return ins->get_shape().lens().size() == 4; });
for(auto output : outputs)
{
auto permutation = find_permutation(output->get_shape());
auto layout_ins = m.insert_instruction( // std::unordered_set<instruction_ref> preserve_output_layout(module& m)
std::next(output), make_op("layout", {{"permutation", permutation}}), output); // {
// std::unordered_set<instruction_ref> result;
// std::vector<instruction_ref> outputs =
// find_lasts(m, [](auto ins) { return ins->get_shape().lens().size() == 4; });
// for(auto output : outputs)
// {
// auto permutation = find_permutation(output->get_shape());
auto output1 = m.insert_instruction( // auto layout_ins = m.insert_instruction(
layout_ins, make_op("allocate", {{"shape", to_value(layout_ins->get_shape())}})); // std::next(output), make_op("layout", {{"permutation", permutation}}), output);
std::vector<instruction_ref> refs = layout_ins->inputs();
refs.push_back(output1);
auto layout = m.replace_instruction( // auto output1 = m.insert_instruction(
layout_ins, // layout_ins, make_op("allocate", {{"shape", to_value(layout_ins->get_shape())}}));
make_op("gpu::precompile_op", {{"op", to_value(layout_ins->get_operator())}}), // std::vector<instruction_ref> refs = layout_ins->inputs();
refs, // refs.push_back(output1);
layout_ins->module_inputs());
result.insert(layout); // auto layout = m.replace_instruction(
// m.debug_print(layout); // layout_ins,
} // make_op("gpu::precompile_op", {{"op", to_value(layout_ins->get_operator())}}),
return result; // refs,
} // layout_ins->module_inputs());
void remove_layout(module& m) // result.insert(layout);
{
for(auto ins : iterator_for(m))
{
if(ins->name() != "layout")
continue;
auto in_shape = ins->inputs().front()->get_shape();
if(in_shape == ins->get_shape())
m.replace_instruction(ins, ins->inputs().front());
}
}
// std::vector<instruction_ref> find_convs(const module& m)
// {
// std::vector<instruction_ref> convs;
// for(auto ins : iterator_for(m))
// {
// if(ins->name() == "gpu::miopen_op")
// convs.push_back(ins);
// } // }
// return convs; // return result;
// } // }
void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_layouts) void remove_layout(module& m)
{ {
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
{ {
...@@ -120,18 +96,14 @@ void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_ ...@@ -120,18 +96,14 @@ void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_
if(val["op"].at("name").to<std::string>() != "layout") if(val["op"].at("name").to<std::string>() != "layout")
{ {
// std::cout << val["op"].at("name").to<std::string>() << std::endl;
continue; continue;
} }
// m.debug_print(ins);
if(ins->get_shape() != ins->inputs().front()->get_shape()) if(ins->get_shape() != ins->inputs().front()->get_shape())
{ {
// std::cout << ins->get_shape() << " " << ins->inputs().front()->get_shape() <<
// std::endl;
continue; continue;
} }
if(contains(output_layouts, ins)) // if(contains(output_layouts, ins))
continue; // continue;
m.replace_instruction(ins, ins->inputs().front()); m.replace_instruction(ins, ins->inputs().front());
} }
...@@ -139,10 +111,9 @@ void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_ ...@@ -139,10 +111,9 @@ void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_
void eliminate_layout::apply(module_pass_manager& mpm) const void eliminate_layout::apply(module_pass_manager& mpm) const
{ {
std::unordered_set<instruction_ref> output_layouts = preserve_output_layout(mpm.get_module()); // std::unordered_set<instruction_ref> output_layouts =
remove_layout(mpm.get_module(), output_layouts); // preserve_output_layout(mpm.get_module());
// find_convs(mpm.get_module())); remove_layout(mpm.get_module());
// remove_layout(mpm.get_module());
mpm.run_pass(dead_code_elimination{}); mpm.run_pass(dead_code_elimination{});
} }
......
...@@ -38,9 +38,8 @@ struct module_pass_manager; ...@@ -38,9 +38,8 @@ struct module_pass_manager;
*/ */
struct MIGRAPHX_EXPORT layout_nhwc struct MIGRAPHX_EXPORT layout_nhwc
{ {
bool skip_elim_contiguous = false;
std::string name() const { return "layout_nhwc"; } std::string name() const { return "layout_nhwc"; }
void apply(module_pass_manager& m) const; void apply(module_pass_manager& mpm) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -30,47 +30,43 @@ ...@@ -30,47 +30,43 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/eliminate_contiguous.hpp> #include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <stdexcept>
#include <system_error>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
// template <class Predicate> template <class Predicate>
// std::vector<instruction_ref> find_lasts(const module& m, Predicate pred) std::vector<instruction_ref> find_lasts(const module& m, Predicate pred)
// { {
// std::vector<instruction_ref> result; std::vector<instruction_ref> result;
// fix([&](auto self, auto ins) { fix([&](auto self, auto ins) {
// if(pred(ins)) if(pred(ins))
// { {
// result.push_back(ins); result.push_back(ins);
// return; return;
// } }
// for(auto input : ins->inputs()) for(auto input : ins->inputs())
// self(input); self(input);
// })(std::prev(m.end())); })(std::prev(m.end()));
// return result; return result;
// } }
// std::unordered_set<instruction_ref> preserve_output_layout(module& m) void preserve_output_layout(module& m)
// { {
// std::unordered_set<instruction_ref> result; std::vector<instruction_ref> outputs = find_lasts(m, [](auto ins) {
// std::vector<instruction_ref> outputs = return ins->name() == "convolution" and ins->get_shape().lens().size() == 4;
// find_lasts(m, [](auto ins) { return ins->get_shape().lens().size() == 4; }); });
// for(auto output : outputs) for(auto output : outputs)
// { {
// auto permutation = find_permutation(output->get_shape()); auto permutation = find_permutation(output->get_shape());
// auto layout = m.insert_instruction( auto layout = m.insert_instruction(
// std::next(output), make_op("layout", {{"permutation", permutation}}), output); std::next(output), make_op("layout", {{"permutation", permutation}}), output);
// result.insert(m.replace_instruction(output, layout)); m.replace_instruction(output, layout);
// } }
// return result; }
// }
void transform_convolutions(module& m, bool skip_elim_contiguous) void transform_convolutions(module& m)
{ {
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
{ {
...@@ -82,79 +78,21 @@ void transform_convolutions(module& m, bool skip_elim_contiguous) ...@@ -82,79 +78,21 @@ void transform_convolutions(module& m, bool skip_elim_contiguous)
if(v.at("group").to<int>() > 1) if(v.at("group").to<int>() > 1)
continue; continue;
auto args = ins->inputs(); auto args = ins->inputs();
if(skip_elim_contiguous) std::transform(args.begin(), args.end(), args.begin(), [&](const auto& i) {
{ return m.insert_instruction(ins, make_op("layout", {{"permutation", {0, 2, 3, 1}}}), i);
for(auto i = 0; i < args.size(); i++) });
{
if(args[i]->name() != "layout" and args[i]->get_shape().standard())
{
args[i] = m.insert_instruction(
ins, make_op("layout", {{"permutation", {0, 2, 3, 1}}}), args[i]);
}
}
}
else
std::transform(args.begin(), args.end(), args.begin(), [&](auto& i) {
return m.insert_instruction(
ins, make_op("layout", {{"permutation", {0, 2, 3, 1}}}), i);
});
auto conv = m.insert_instruction(ins, ins->get_operator(), args); auto conv = m.insert_instruction(ins, ins->get_operator(), args);
// m.debug_print(conv); auto c = m.insert_instruction(ins, make_op("contiguous"), conv);
// auto c = conv;
// auto nchw = m.insert_instruction(ins, make_op("layout", {{"permutation", {0, 1, 2, 3}}}),
// conv); m.debug_print(); if(not skip_elim_contiguous)
auto c = m.insert_instruction(ins, make_op("contiguous"), conv);
m.replace_instruction(ins, c); m.replace_instruction(ins, c);
} }
} }
void insert_contiguous(module& m)
{
for(auto ins : iterator_for(m))
{
if(ins->name() != "reshape" and ins->name() != "pooling")
continue;
auto c = m.insert_instruction(ins, make_op("contiguous"), ins->inputs().front());
auto reshape = m.insert_instruction(ins, ins->get_operator(), c);
m.replace_instruction(ins, reshape);
}
// m.debug_print();
}
// void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_layouts)
// {
// for(auto ins : iterator_for(m))
// {
// if(ins->name() != "layout")
// continue;
// if(ins->get_shape() != ins->inputs().front()->get_shape())
// continue;
// if(contains(output_layouts, ins))
// continue;
// m.replace_instruction(ins, ins->inputs().front());
// }
// }
void layout_nhwc::apply(module_pass_manager& mpm) const void layout_nhwc::apply(module_pass_manager& mpm) const
{ {
// std::unordered_set<instruction_ref> output_layouts = module& m = mpm.get_module();
// preserve_output_layout(mpm.get_module()); // preserve_output_layout(m);
// insert_contiguous(mpm.get_module()); transform_convolutions(m);
mpm.run_pass(dead_code_elimination{});
// mpm.get_module().debug_print();
transform_convolutions(mpm.get_module(), this->skip_elim_contiguous);
mpm.run_pass(dead_code_elimination{}); mpm.run_pass(dead_code_elimination{});
// std::cout << "after layout" << std::endl;
// mpm.get_module().debug_print();
// if(not this->skip_elim_contiguous)
// mpm.run_pass(eliminate_contiguous{"contiguous"});
// mpm.run_pass(dead_code_elimination{});
// mpm.run_pass(auto_contiguous{});
// mpm.run_pass(dead_code_elimination{});
// remove_layout(mpm.get_module(), output_layouts);
// mpm.run_pass(dead_code_elimination{});
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -782,6 +782,8 @@ struct find_contiguous_pointwise ...@@ -782,6 +782,8 @@ struct find_contiguous_pointwise
auto args = pw->inputs(); auto args = pw->inputs();
args.back() = alloc; args.back() = alloc;
if(ins->get_shape() != pw->get_shape())
return;
m.replace_instruction(ins, pw->get_operator(), args, pw->module_inputs()); m.replace_instruction(ins, pw->get_operator(), args, pw->module_inputs());
} }
}; };
...@@ -835,24 +837,23 @@ struct find_concat_pointwise ...@@ -835,24 +837,23 @@ struct find_concat_pointwise
auto op = concat->get_operator(); auto op = concat->get_operator();
op.from_value({{"additional_args", ins->inputs().size() - 1}, {"ignore_modules", true}}); op.from_value({{"additional_args", ins->inputs().size() - 1}, {"ignore_modules", true}});
m.replace_instruction(ins, op, inputs, {pm}); m.replace_instruction(ins, op, inputs, {pm});
} }
}; };
void fuse_ops::apply(module& m) const void fuse_ops::apply(module& m) const
{ {
// match::find_matches(m, find_contiguous_pointwise{}); match::find_matches(m, find_contiguous_pointwise{});
run_passes(m, {dead_code_elimination{}}); run_passes(m, {dead_code_elimination{}});
match::find_matches(m, find_conv_pointwise{ctx}, find_conv_bias_relu{ctx}, find_conv_bias{ctx}); match::find_matches(m, find_conv_pointwise{ctx}, find_conv_bias_relu{ctx}, find_conv_bias{ctx});
run_passes(m, {dead_code_elimination{}}); run_passes(m, {dead_code_elimination{}});
match::find_matches(m, match::find_matches(m,
find_layernorm_pointwise{}, find_layernorm_pointwise{},
find_concat_pointwise{}, // find_concat_pointwise{},
find_gemm_pointwise{}, find_gemm_pointwise{},
find_contiguous_tranpose_gemm{}, find_contiguous_tranpose_gemm{},
find_commutative_broadcast{}); find_commutative_broadcast{});
// match::find_matches(m, find_contiguous{}); match::find_matches(m, find_contiguous{});
} }
} // namespace gpu } // namespace gpu
......
...@@ -131,7 +131,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -131,7 +131,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
optimize_module{}, optimize_module{},
enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{}), enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{}),
dead_code_elimination{}, dead_code_elimination{},
enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), eliminate_layout{}),
prefuse_ops{}, prefuse_ops{},
dead_code_elimination{}, dead_code_elimination{},
optimize_module{}, optimize_module{},
...@@ -150,7 +149,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -150,7 +149,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
eliminate_contiguous{"gpu::contiguous"}, eliminate_contiguous{"gpu::contiguous"},
dead_code_elimination{}, dead_code_elimination{},
enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), eliminate_layout{}), enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), eliminate_layout{}),
// dead_code_elimination{},
eliminate_concat{concat_gpu_optimization{}}, eliminate_concat{concat_gpu_optimization{}},
dead_code_elimination{}, dead_code_elimination{},
compile_miopen{&gctx}, compile_miopen{&gctx},
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment