Commit 98b8dff1 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

workaround using two layout_nhwc passes

parent 7393cf1e
......@@ -15,6 +15,7 @@ struct module_pass_manager;
*/
struct layout_nhwc
{
bool skip_elim_contiguous = false;
std::string name() const { return "layout_nhwc"; }
void apply(module_pass_manager& m) const;
};
......
......@@ -44,7 +44,7 @@ std::unordered_set<instruction_ref> preserve_output_layout(module& m)
return result;
}
void transform_convolutions(module& m)
void transform_convolutions(module& m, bool skip_elim_contiguous)
{
for(auto ins : iterator_for(m))
{
......@@ -56,11 +56,29 @@ void transform_convolutions(module& m)
if(v.at("group").to<int>() > 1)
continue;
auto args = ins->inputs();
std::transform(args.begin(), args.end(), args.begin(), [&](auto& i) {
return m.insert_instruction(ins, make_op("layout", {{"permutation", {0, 2, 3, 1}}}), i);
});
if(skip_elim_contiguous)
{
// std::cout << "HERE" << std::endl;
for(auto i = 0; i < args.size(); i++)
{
// std::cout << args[i]->name() << std::endl;
if(args[i]->name() != "layout" and args[i]->get_shape().standard())
{
// std::cout << "HERE2" << std::endl;
args[i] = m.insert_instruction(ins, make_op("layout", {{"permutation", {0, 2, 3, 1}}}), args[i]);
// m.debug_print(args);
}
}
}
else
std::transform(args.begin(), args.end(), args.begin(), [&](auto& i) {
return m.insert_instruction(ins, make_op("layout", {{"permutation", {0, 2, 3, 1}}}), i);
});
auto conv = m.insert_instruction(ins, ins->get_operator(), args);
auto c = m.insert_instruction(ins, make_op("contiguous"), conv);
auto c = conv;
if(not skip_elim_contiguous)
c = m.insert_instruction(ins, make_op("contiguous"), conv);
m.replace_instruction(ins, c);
}
}
......@@ -82,9 +100,10 @@ void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_
void layout_nhwc::apply(module_pass_manager& mpm) const
{
std::unordered_set<instruction_ref> output_layouts = preserve_output_layout(mpm.get_module());
transform_convolutions(mpm.get_module());
transform_convolutions(mpm.get_module(), this->skip_elim_contiguous);
mpm.run_pass(dead_code_elimination{});
mpm.run_pass(eliminate_contiguous{"contiguous"});
if(not this->skip_elim_contiguous)
mpm.run_pass(eliminate_contiguous{"contiguous"});
mpm.run_pass(dead_code_elimination{});
remove_layout(mpm.get_module(), output_layouts);
mpm.run_pass(dead_code_elimination{});
......
......@@ -63,17 +63,33 @@ argument miopen_convolution::compute(context& ctx,
if(solution_id == 0)
MIGRAPHX_THROW("MIOpen Convolution: invalid solution ID");
auto status = miopenConvolutionForwardImmediate(ctx.get_stream().get_miopen(),
w_desc.get(),
args[1].implicit(),
x_desc.get(),
args[0].implicit(),
cd.get(),
y_desc.get(),
args[3].implicit(),
args[2].implicit(),
args[2].get_shape().bytes(),
solution_id);
// auto status = miopenConvolutionForwardImmediate(ctx.get_stream().get_miopen(),
// w_desc.get(),
// args[1].implicit(),
// x_desc.get(),
// args[0].implicit(),
// cd.get(),
// y_desc.get(),
// args[3].implicit(),
// args[2].implicit(),
// args[2].get_shape().bytes(),
// solution_id);
float alpha = 1;
float beta = 0;
auto status = miopenConvolutionForward(ctx.get_stream().get_miopen(),
&alpha,
x_desc.get(),
args[0].implicit(),
w_desc.get(),
args[1].implicit(),
cd.get(),
algo,
&beta,
y_desc.get(),
args[3].implicit(),
args[2].implicit(),
args[2].get_shape().bytes());
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Convolution: running convolution failed");
......
......@@ -132,6 +132,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
prefuse_ops{},
dead_code_elimination{},
auto_contiguous{},
enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{true}),
simplify_reshapes{},
propagate_constant{},
dead_code_elimination{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment