"vscode:/vscode.git/clone" did not exist on "62ee1d5a3cb911d9e4bc0a49341794f3c2159742"
Commit 379ef733 authored by Paul's avatar Paul
Browse files

Add a pass to transform pooling to reduce_mean

parent aa7b76b5
......@@ -14,6 +14,7 @@ add_library(migraphx
eliminate_pad.cpp
fwd_conv_batchnorm_rewrite.cpp
rewrite_rnn.cpp
rewrite_pooling.cpp
env.cpp
generate.cpp
instruction.cpp
......
#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_POOLING_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_POOLING_HPP
#include <string>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
/**
* Rewrite pooling to reduce_mean
*/
struct rewrite_pooling
{
std::string name() const { return "rewrite_pooling"; }
void apply(program& prog) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/program.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void rewrite_pooling::apply(program& prog) const
{
for(auto ins : iterator_for(prog))
{
if (ins->name() != "pooling")
continue;
if (ins->get_shape().lens().size() != 4)
continue;
if (ins->inputs().empty())
continue;
auto&& s = ins->inputs().front()->get_shape();
auto&& op = any_cast<op::pooling>(ins->get_operator());
if (op.mode != "average")
continue;
if (op.padding[0] != 0 and op.padding[1] != 0)
continue;
if (op.stride[0] != 1 and op.stride[1] != 1)
continue;
if (s.lens()[2] != op.lengths[0] and s.lens()[3] != op.lengths[1])
continue;
prog.replace_instruction(ins, op::reduce_mean{{2, 3}}, ins->inputs().front());
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -16,6 +16,7 @@
#include <migraphx/common_subexpression_elimination.hpp>
#include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/eliminate_concat.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp>
......@@ -45,6 +46,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
fwd_conv_batchnorm_rewrite{},
dead_code_elimination{},
rewrite_rnn{},
rewrite_pooling{},
dead_code_elimination{},
//common_subexpression_elimination{},
//dead_code_elimination{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment