Commit 98b8dff1 authored by Khalique Ahmed
Browse files

workaround using two layout_nhwc passes

parent 7393cf1e
...@@ -15,6 +15,7 @@ struct module_pass_manager; ...@@ -15,6 +15,7 @@ struct module_pass_manager;
*/ */
struct layout_nhwc struct layout_nhwc
{ {
bool skip_elim_contiguous = false;
std::string name() const { return "layout_nhwc"; } std::string name() const { return "layout_nhwc"; }
void apply(module_pass_manager& m) const; void apply(module_pass_manager& m) const;
}; };
......
...@@ -44,7 +44,7 @@ std::unordered_set<instruction_ref> preserve_output_layout(module& m) ...@@ -44,7 +44,7 @@ std::unordered_set<instruction_ref> preserve_output_layout(module& m)
return result; return result;
} }
void transform_convolutions(module& m) void transform_convolutions(module& m, bool skip_elim_contiguous)
{ {
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
{ {
...@@ -56,11 +56,29 @@ void transform_convolutions(module& m) ...@@ -56,11 +56,29 @@ void transform_convolutions(module& m)
if(v.at("group").to<int>() > 1) if(v.at("group").to<int>() > 1)
continue; continue;
auto args = ins->inputs(); auto args = ins->inputs();
std::transform(args.begin(), args.end(), args.begin(), [&](auto& i) { if(skip_elim_contiguous)
return m.insert_instruction(ins, make_op("layout", {{"permutation", {0, 2, 3, 1}}}), i); {
}); // std::cout << "HERE" << std::endl;
for(auto i = 0; i < args.size(); i++)
{
// std::cout << args[i]->name() << std::endl;
if(args[i]->name() != "layout" and args[i]->get_shape().standard())
{
// std::cout << "HERE2" << std::endl;
args[i] = m.insert_instruction(ins, make_op("layout", {{"permutation", {0, 2, 3, 1}}}), args[i]);
// m.debug_print(args);
}
}
}
else
std::transform(args.begin(), args.end(), args.begin(), [&](auto& i) {
return m.insert_instruction(ins, make_op("layout", {{"permutation", {0, 2, 3, 1}}}), i);
});
auto conv = m.insert_instruction(ins, ins->get_operator(), args); auto conv = m.insert_instruction(ins, ins->get_operator(), args);
auto c = m.insert_instruction(ins, make_op("contiguous"), conv); auto c = conv;
if(not skip_elim_contiguous)
c = m.insert_instruction(ins, make_op("contiguous"), conv);
m.replace_instruction(ins, c); m.replace_instruction(ins, c);
} }
} }
...@@ -82,9 +100,10 @@ void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_ ...@@ -82,9 +100,10 @@ void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_
void layout_nhwc::apply(module_pass_manager& mpm) const void layout_nhwc::apply(module_pass_manager& mpm) const
{ {
std::unordered_set<instruction_ref> output_layouts = preserve_output_layout(mpm.get_module()); std::unordered_set<instruction_ref> output_layouts = preserve_output_layout(mpm.get_module());
transform_convolutions(mpm.get_module()); transform_convolutions(mpm.get_module(), this->skip_elim_contiguous);
mpm.run_pass(dead_code_elimination{}); mpm.run_pass(dead_code_elimination{});
mpm.run_pass(eliminate_contiguous{"contiguous"}); if(not this->skip_elim_contiguous)
mpm.run_pass(eliminate_contiguous{"contiguous"});
mpm.run_pass(dead_code_elimination{}); mpm.run_pass(dead_code_elimination{});
remove_layout(mpm.get_module(), output_layouts); remove_layout(mpm.get_module(), output_layouts);
mpm.run_pass(dead_code_elimination{}); mpm.run_pass(dead_code_elimination{});
......
...@@ -63,17 +63,33 @@ argument miopen_convolution::compute(context& ctx, ...@@ -63,17 +63,33 @@ argument miopen_convolution::compute(context& ctx,
if(solution_id == 0) if(solution_id == 0)
MIGRAPHX_THROW("MIOpen Convolution: invalid solution ID"); MIGRAPHX_THROW("MIOpen Convolution: invalid solution ID");
auto status = miopenConvolutionForwardImmediate(ctx.get_stream().get_miopen(), // auto status = miopenConvolutionForwardImmediate(ctx.get_stream().get_miopen(),
w_desc.get(), // w_desc.get(),
args[1].implicit(), // args[1].implicit(),
x_desc.get(), // x_desc.get(),
args[0].implicit(), // args[0].implicit(),
cd.get(), // cd.get(),
y_desc.get(), // y_desc.get(),
args[3].implicit(), // args[3].implicit(),
args[2].implicit(), // args[2].implicit(),
args[2].get_shape().bytes(), // args[2].get_shape().bytes(),
solution_id); // solution_id);
float alpha = 1;
float beta = 0;
auto status = miopenConvolutionForward(ctx.get_stream().get_miopen(),
&alpha,
x_desc.get(),
args[0].implicit(),
w_desc.get(),
args[1].implicit(),
cd.get(),
algo,
&beta,
y_desc.get(),
args[3].implicit(),
args[2].implicit(),
args[2].get_shape().bytes());
if(status != miopenStatusSuccess) if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Convolution: running convolution failed"); MIGRAPHX_THROW("MIOpen Convolution: running convolution failed");
......
...@@ -132,6 +132,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -132,6 +132,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
prefuse_ops{}, prefuse_ops{},
dead_code_elimination{}, dead_code_elimination{},
auto_contiguous{}, auto_contiguous{},
enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{true}),
simplify_reshapes{}, simplify_reshapes{},
propagate_constant{}, propagate_constant{},
dead_code_elimination{}, dead_code_elimination{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment