Commit 8b1b6a5e authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'nhwc_workaround' of...

Merge branch 'nhwc_workaround' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into nhwc_workaround
parents 35dcd00d 01b5ff4e
......@@ -76,8 +76,6 @@ void auto_contiguous::apply(module& m) const
{
continue;
}
if(ins->name() == "pointwise")
std::cout << "HERE" << std::endl;
auto c = m.insert_instruction(std::next(ins), make_op("contiguous"), ins);
m.replace_instruction(ins, c);
}
......
......@@ -81,7 +81,7 @@ static bool try_compute_shape(instruction_ref ins,
return (arg == ins) ? new_shape : arg->get_shape();
});
if(not try_compute_shape(output, input_shapes, mods))
if(not try_compute_shape(output, input_shapes, output->module_inputs()))
{
return false;
}
......
......@@ -40,75 +40,51 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class Predicate>
std::vector<instruction_ref> find_lasts(const module& m, Predicate pred)
{
std::vector<instruction_ref> result;
fix([&](auto self, auto ins) {
if(pred(ins))
{
result.push_back(ins);
return;
}
for(auto input : ins->inputs())
self(input);
})(std::prev(m.end()));
return result;
}
// Pins the memory layout of every trailing 4-d output of the module:
// each such output is wrapped in a "layout" instruction matching its current
// permutation, which is then lowered to a gpu::precompile_op with an explicit
// output "allocate" buffer appended to its inputs.
// Returns the set of inserted precompiled layout instructions so that later
// cleanup (remove_layout) can avoid eliminating them.
// NOTE(review): the commented-out blocks below are diff remnants of the
// previous implementation — confirm against the repository before removing.
std::unordered_set<instruction_ref> preserve_output_layout(module& m)
{
    std::unordered_set<instruction_ref> result;
    // Only 4-d outputs are considered; presumably these are the tensors whose
    // NCHW/NHWC ordering matters — TODO confirm.
    std::vector<instruction_ref> outputs =
        find_lasts(m, [](auto ins) { return ins->get_shape().lens().size() == 4; });
    for(auto output : outputs)
    {
        // Describe the output's current layout as a permutation.
        auto permutation = find_permutation(output->get_shape());
        // template <class Predicate>
        // std::vector<instruction_ref> find_lasts(const module& m, Predicate pred)
        // {
        //     std::vector<instruction_ref> result;
        //     fix([&](auto self, auto ins) {
        //         if(pred(ins))
        //         {
        //             result.push_back(ins);
        //             return;
        //         }
        //         for(auto input : ins->inputs())
        //             self(input);
        //     })(std::prev(m.end()));
        //     return result;
        // }
        // Insert a layout op right after the output that pins its permutation.
        auto layout_ins = m.insert_instruction(
            std::next(output), make_op("layout", {{"permutation", permutation}}), output);
        // std::unordered_set<instruction_ref> preserve_output_layout(module& m)
        // {
        //     std::unordered_set<instruction_ref> result;
        //     std::vector<instruction_ref> outputs =
        //         find_lasts(m, [](auto ins) { return ins->get_shape().lens().size() == 4; });
        //     for(auto output : outputs)
        //     {
        //         auto permutation = find_permutation(output->get_shape());
        // Allocate the destination buffer for the precompiled op; the buffer
        // becomes the last input by convention.
        auto output1 = m.insert_instruction(
            layout_ins, make_op("allocate", {{"shape", to_value(layout_ins->get_shape())}}));
        std::vector<instruction_ref> refs = layout_ins->inputs();
        refs.push_back(output1);
        // auto layout_ins = m.insert_instruction(
        //     std::next(output), make_op("layout", {{"permutation", permutation}}), output);
        // Lower the plain layout op to its precompiled GPU form in place.
        auto layout = m.replace_instruction(
            layout_ins,
            make_op("gpu::precompile_op", {{"op", to_value(layout_ins->get_operator())}}),
            refs,
            layout_ins->module_inputs());
        // auto output1 = m.insert_instruction(
        //     layout_ins, make_op("allocate", {{"shape", to_value(layout_ins->get_shape())}}));
        // std::vector<instruction_ref> refs = layout_ins->inputs();
        // refs.push_back(output1);
        result.insert(layout);
        // m.debug_print(layout);
    }
    return result;
}
// auto layout = m.replace_instruction(
// layout_ins,
// make_op("gpu::precompile_op", {{"op", to_value(layout_ins->get_operator())}}),
// refs,
// layout_ins->module_inputs());
// Eliminate no-op "layout" instructions: when the input already has the
// requested shape the layout does nothing, so callers can use the input
// directly.
void remove_layout(module& m)
{
    for(auto ins : iterator_for(m))
    {
        if(ins->name() != "layout")
            continue;
        auto input = ins->inputs().front();
        if(ins->get_shape() == input->get_shape())
            m.replace_instruction(ins, input);
    }
}
// std::vector<instruction_ref> find_convs(const module& m)
// {
// std::vector<instruction_ref> convs;
// for(auto ins : iterator_for(m))
// {
// if(ins->name() == "gpu::miopen_op")
// convs.push_back(ins);
// result.insert(layout);
// }
// return convs;
// return result;
// }
void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_layouts)
void remove_layout(module& m)
{
for(auto ins : iterator_for(m))
{
......@@ -120,18 +96,14 @@ void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_
if(val["op"].at("name").to<std::string>() != "layout")
{
// std::cout << val["op"].at("name").to<std::string>() << std::endl;
continue;
}
// m.debug_print(ins);
if(ins->get_shape() != ins->inputs().front()->get_shape())
{
// std::cout << ins->get_shape() << " " << ins->inputs().front()->get_shape() <<
// std::endl;
continue;
}
if(contains(output_layouts, ins))
continue;
// if(contains(output_layouts, ins))
// continue;
m.replace_instruction(ins, ins->inputs().front());
}
......@@ -139,10 +111,9 @@ void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_
/// Pass entry point: strips redundant layout conversions from the module and
/// then removes any instructions orphaned by that rewrite.
void eliminate_layout::apply(module_pass_manager& mpm) const
{
    // Previous design tracked output layouts to protect them from removal;
    // that bookkeeping was retired along with the two-argument remove_layout.
    // std::unordered_set<instruction_ref> output_layouts =
    //     preserve_output_layout(mpm.get_module());
    // The merged diff left both the old two-argument call and the new
    // one-argument call in place, running layout removal twice and
    // referencing a deleted overload; only the one-argument form remains.
    remove_layout(mpm.get_module());
    mpm.run_pass(dead_code_elimination{});
}
......
......@@ -38,9 +38,8 @@ struct module_pass_manager;
*/
struct MIGRAPHX_EXPORT layout_nhwc
{
    // When set, layout insertion is presumably limited and the contiguous
    // elimination cleanup is skipped — TODO confirm against the pass
    // implementation before relying on this flag.
    bool skip_elim_contiguous = false;
    // Pass name reported to the pass manager.
    std::string name() const { return "layout_nhwc"; }
    // Rewrites eligible convolutions in the module to NHWC layout.
    // The merged diff declared apply twice with an identical signature
    // (only the parameter name differed), which is an invalid
    // redeclaration; a single declaration remains.
    void apply(module_pass_manager& mpm) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -30,47 +30,43 @@
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <stdexcept>
#include <system_error>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// template <class Predicate>
// std::vector<instruction_ref> find_lasts(const module& m, Predicate pred)
// {
// std::vector<instruction_ref> result;
// fix([&](auto self, auto ins) {
// if(pred(ins))
// {
// result.push_back(ins);
// return;
// }
// for(auto input : ins->inputs())
// self(input);
// })(std::prev(m.end()));
// return result;
// }
// Finds the "last" instructions matching pred: walks backwards from the
// module's final instruction along input edges and records the first
// matching instruction on each chain, stopping the descent at a match.
// NOTE(review): visited nodes are not deduplicated, so an instruction
// reachable through several consumers may be visited (and recorded) more
// than once — confirm whether callers tolerate duplicates.
template <class Predicate>
std::vector<instruction_ref> find_lasts(const module& m, Predicate pred)
{
    std::vector<instruction_ref> result;
    fix([&](auto self, auto ins) {
        if(pred(ins))
        {
            result.push_back(ins);
            return;
        }
        // Not a match: keep searching upstream through every input.
        for(auto input : ins->inputs())
            self(input);
    })(std::prev(m.end()));
    return result;
}
// std::unordered_set<instruction_ref> preserve_output_layout(module& m)
// {
// std::unordered_set<instruction_ref> result;
// std::vector<instruction_ref> outputs =
// find_lasts(m, [](auto ins) { return ins->get_shape().lens().size() == 4; });
// for(auto output : outputs)
// {
// auto permutation = find_permutation(output->get_shape());
// auto layout = m.insert_instruction(
// std::next(output), make_op("layout", {{"permutation", permutation}}), output);
// result.insert(m.replace_instruction(output, layout));
// }
// return result;
// }
// Pin the layout of each trailing 4-d convolution: insert a "layout"
// instruction carrying the convolution's current permutation right after it
// and reroute its users through that instruction.
void preserve_output_layout(module& m)
{
    auto is_conv4d = [](auto ins) {
        return ins->name() == "convolution" and ins->get_shape().lens().size() == 4;
    };
    for(auto conv : find_lasts(m, is_conv4d))
    {
        auto perm   = find_permutation(conv->get_shape());
        auto pinned = m.insert_instruction(
            std::next(conv), make_op("layout", {{"permutation", perm}}), conv);
        m.replace_instruction(conv, pinned);
    }
}
void transform_convolutions(module& m, bool skip_elim_contiguous)
void transform_convolutions(module& m)
{
for(auto ins : iterator_for(m))
{
......@@ -82,79 +78,21 @@ void transform_convolutions(module& m, bool skip_elim_contiguous)
if(v.at("group").to<int>() > 1)
continue;
auto args = ins->inputs();
if(skip_elim_contiguous)
{
for(auto i = 0; i < args.size(); i++)
{
if(args[i]->name() != "layout" and args[i]->get_shape().standard())
{
args[i] = m.insert_instruction(
ins, make_op("layout", {{"permutation", {0, 2, 3, 1}}}), args[i]);
}
}
}
else
std::transform(args.begin(), args.end(), args.begin(), [&](auto& i) {
return m.insert_instruction(
ins, make_op("layout", {{"permutation", {0, 2, 3, 1}}}), i);
});
std::transform(args.begin(), args.end(), args.begin(), [&](const auto& i) {
return m.insert_instruction(ins, make_op("layout", {{"permutation", {0, 2, 3, 1}}}), i);
});
auto conv = m.insert_instruction(ins, ins->get_operator(), args);
// m.debug_print(conv);
// auto c = conv;
// auto nchw = m.insert_instruction(ins, make_op("layout", {{"permutation", {0, 1, 2, 3}}}),
// conv); m.debug_print(); if(not skip_elim_contiguous)
auto c = m.insert_instruction(ins, make_op("contiguous"), conv);
auto c = m.insert_instruction(ins, make_op("contiguous"), conv);
m.replace_instruction(ins, c);
}
}
// Force the first input of every reshape and pooling instruction to be
// contiguous: re-emit the op on a "contiguous" copy of its input and swap
// the original out.
void insert_contiguous(module& m)
{
    for(auto ins : iterator_for(m))
    {
        const auto& op_name = ins->name();
        if(op_name == "reshape" or op_name == "pooling")
        {
            auto made_contiguous =
                m.insert_instruction(ins, make_op("contiguous"), ins->inputs().front());
            auto replacement = m.insert_instruction(ins, ins->get_operator(), made_contiguous);
            m.replace_instruction(ins, replacement);
        }
    }
}
// void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_layouts)
// {
// for(auto ins : iterator_for(m))
// {
// if(ins->name() != "layout")
// continue;
// if(ins->get_shape() != ins->inputs().front()->get_shape())
// continue;
// if(contains(output_layouts, ins))
// continue;
// m.replace_instruction(ins, ins->inputs().front());
// }
// }
/// Pass entry point: rewrites eligible convolutions to NHWC layout.
/// The merged diff retained both the removed two-argument
/// transform_convolutions(..., this->skip_elim_contiguous) call and the
/// added one-argument call, which would run the rewrite twice with
/// conflicting signatures; only the one-argument form remains.
void layout_nhwc::apply(module_pass_manager& mpm) const
{
    // Prune dead instructions first so only live convolutions are rewritten.
    mpm.run_pass(dead_code_elimination{});
    module& m = mpm.get_module();
    // preserve_output_layout(m); // NOTE(review): disabled upstream — confirm
    // before re-enabling.
    transform_convolutions(m);
    // Remove instructions orphaned by the layout rewrite.
    mpm.run_pass(dead_code_elimination{});
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -782,6 +782,8 @@ struct find_contiguous_pointwise
auto args = pw->inputs();
args.back() = alloc;
if(ins->get_shape() != pw->get_shape())
return;
m.replace_instruction(ins, pw->get_operator(), args, pw->module_inputs());
}
};
......@@ -835,24 +837,23 @@ struct find_concat_pointwise
auto op = concat->get_operator();
op.from_value({{"additional_args", ins->inputs().size() - 1}, {"ignore_modules", true}});
m.replace_instruction(ins, op, inputs, {pm});
}
};
// GPU fusion driver: runs each family of pattern matchers in sequence,
// pruning dead instructions between stages so later matchers see the
// simplified graph.
void fuse_ops::apply(module& m) const
{
    // match::find_matches(m, find_contiguous_pointwise{});
    // Fold contiguous copies into adjacent pointwise ops.
    match::find_matches(m, find_contiguous_pointwise{});
    run_passes(m, {dead_code_elimination{}});
    // Fuse convolutions with trailing pointwise/bias/relu consumers
    // (context-dependent matchers).
    match::find_matches(m, find_conv_pointwise{ctx}, find_conv_bias_relu{ctx}, find_conv_bias{ctx});
    run_passes(m, {dead_code_elimination{}});
    // Remaining structural fusions run in one matching sweep.
    match::find_matches(m,
                        find_layernorm_pointwise{},
                        find_concat_pointwise{},
                        // find_concat_pointwise{},
                        find_gemm_pointwise{},
                        find_contiguous_tranpose_gemm{},
                        find_commutative_broadcast{});
    // match::find_matches(m, find_contiguous{});
    // Final pass: lower any contiguous ops that survived fusion.
    match::find_matches(m, find_contiguous{});
}
} // namespace gpu
......
......@@ -131,7 +131,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
optimize_module{},
enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{}),
dead_code_elimination{},
enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), eliminate_layout{}),
prefuse_ops{},
dead_code_elimination{},
optimize_module{},
......@@ -150,7 +149,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
eliminate_contiguous{"gpu::contiguous"},
dead_code_elimination{},
enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), eliminate_layout{}),
// dead_code_elimination{},
eliminate_concat{concat_gpu_optimization{}},
dead_code_elimination{},
compile_miopen{&gctx},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment