"scripts/wan/run_wan_t2v.sh" did not exist on "a50bcc535aaf63e209302ea89be5dc832980a402"
Commit 05e60c54 authored by Paul's avatar Paul
Browse files

Add nhwc pass

parent a4f8d30b
......@@ -31,6 +31,7 @@ add_library(migraphx
insert_pad.cpp
instruction.cpp
json.cpp
layout_nhwc.cpp
load_save.cpp
make_op.cpp
module.cpp
......
#ifndef MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
/**
* Transform convolutions to nhwc
*/
struct layout_nhwc
{
std::string name() const { return "layout_nhwc"; }
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
#include <migraphx/layout_nhwc.hpp>
#include <migraphx/module.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void transform_convolutions(module& m)
{
for(auto ins : iterator_for(m))
{
if (ins->name() != "convolution")
continue;
if (ins->get_shape().lens().size() != 4)
continue;
auto args = ins->inputs();
std::transform(args.begin(), args.end(), args.begin(), [&](auto& i) {
return m.insert_instruction(ins, make_op("layout", {{"permutation", {0, 2, 3, 1}}}), i);
});
auto conv = m.insert_instruction(ins, ins->get_operator(), args);
auto c = m.insert_instruction(ins, make_op("contiguous"), conv);
m.replace_instruction(ins, c);
}
}
void layout_nhwc::apply(module& m) const
{
transform_convolutions(m);
dead_code_elimination{}.apply(m);
eliminate_contiguous{"contiguous"}.apply(m);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -110,6 +110,7 @@ struct find_nop_reshapes
reshapes.insert("broadcast");
reshapes.insert("concat");
reshapes.insert("convert");
reshapes.insert("layout");
reshapes.insert("multibroadcast");
reshapes.insert("pad");
reshapes.insert("slice");
......
......@@ -11,6 +11,7 @@
#include <migraphx/eliminate_data_type.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/layout_nhwc.hpp>
#include <migraphx/memory_coloring.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/register_target.hpp>
......@@ -65,6 +66,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{},
simplify_algebra{},
simplify_reshapes{},
layout_nhwc{},
simplify_reshapes{},
simplify_algebra{},
auto_contiguous{},
simplify_reshapes{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment