#include #include #include #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { void transform_convolutions(module& m) { for(auto ins : iterator_for(m)) { if(ins->name() != "convolution") continue; if(ins->get_shape().lens().size() != 4) continue; auto args = ins->inputs(); std::transform(args.begin(), args.end(), args.begin(), [&](auto& i) { return m.insert_instruction(ins, make_op("layout", {{"permutation", {0, 2, 3, 1}}}), i); }); auto conv = m.insert_instruction(ins, ins->get_operator(), args); auto c = m.insert_instruction(ins, make_op("contiguous"), conv); m.replace_instruction(ins, c); } } void layout_nhwc::apply(module& m) const { transform_convolutions(m); dead_code_elimination{}.apply(m); eliminate_contiguous{"contiguous"}.apply(m); } } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx