Commit 379ef733 authored by Paul's avatar Paul
Browse files

Add a pass to transform pooling to reduce_mean

parent aa7b76b5
...@@ -14,6 +14,7 @@ add_library(migraphx ...@@ -14,6 +14,7 @@ add_library(migraphx
eliminate_pad.cpp eliminate_pad.cpp
fwd_conv_batchnorm_rewrite.cpp fwd_conv_batchnorm_rewrite.cpp
rewrite_rnn.cpp rewrite_rnn.cpp
rewrite_pooling.cpp
env.cpp env.cpp
generate.cpp generate.cpp
instruction.cpp instruction.cpp
......
#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_POOLING_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_POOLING_HPP
#include <string>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
/**
* Rewrite pooling to reduce_mean
*/
struct rewrite_pooling
{
std::string name() const { return "rewrite_pooling"; }
void apply(program& prog) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/program.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void rewrite_pooling::apply(program& prog) const
{
for(auto ins : iterator_for(prog))
{
if (ins->name() != "pooling")
continue;
if (ins->get_shape().lens().size() != 4)
continue;
if (ins->inputs().empty())
continue;
auto&& s = ins->inputs().front()->get_shape();
auto&& op = any_cast<op::pooling>(ins->get_operator());
if (op.mode != "average")
continue;
if (op.padding[0] != 0 and op.padding[1] != 0)
continue;
if (op.stride[0] != 1 and op.stride[1] != 1)
continue;
if (s.lens()[2] != op.lengths[0] and s.lens()[3] != op.lengths[1])
continue;
prog.replace_instruction(ins, op::reduce_mean{{2, 3}}, ins->inputs().front());
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <migraphx/common_subexpression_elimination.hpp> #include <migraphx/common_subexpression_elimination.hpp>
#include <migraphx/fwd_conv_batchnorm_rewrite.hpp> #include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
#include <migraphx/rewrite_rnn.hpp> #include <migraphx/rewrite_rnn.hpp>
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/eliminate_concat.hpp> #include <migraphx/eliminate_concat.hpp>
#include <migraphx/eliminate_identity.hpp> #include <migraphx/eliminate_identity.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp> #include <migraphx/gpu/concat_gpu_opt.hpp>
...@@ -45,6 +46,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const ...@@ -45,6 +46,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
fwd_conv_batchnorm_rewrite{}, fwd_conv_batchnorm_rewrite{},
dead_code_elimination{}, dead_code_elimination{},
rewrite_rnn{}, rewrite_rnn{},
rewrite_pooling{},
dead_code_elimination{}, dead_code_elimination{},
//common_subexpression_elimination{}, //common_subexpression_elimination{},
//dead_code_elimination{}, //dead_code_elimination{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment