"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "21ec113b35a50bb1a3de67ab582464b81418bab7"
Commit 9f25ffb7 authored by Khalique's avatar Khalique
Browse files

initial progress on pad_rewrite, fixes inceptionv3 onnx perf

parent 9f434a2b
...@@ -11,6 +11,7 @@ add_library(migraphx ...@@ -11,6 +11,7 @@ add_library(migraphx
eliminate_contiguous.cpp eliminate_contiguous.cpp
eliminate_concat.cpp eliminate_concat.cpp
eliminate_identity.cpp eliminate_identity.cpp
pad_rewrite.cpp
fwd_conv_batchnorm_rewrite.cpp fwd_conv_batchnorm_rewrite.cpp
rewrite_rnn.cpp rewrite_rnn.cpp
env.cpp env.cpp
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <utility> #include <utility>
...@@ -30,7 +29,7 @@ void eliminate_identity::apply(program& p) const ...@@ -30,7 +29,7 @@ void eliminate_identity::apply(program& p) const
{ {
if(ins->name() == "identity") if(ins->name() == "identity")
{ {
const instruction_ref& identity_input = i->inputs().front(); const instruction_ref& identity_input = ins->inputs().front();
if(identity_input->outputs().size() == 1) if(identity_input->outputs().size() == 1)
{ {
p.move_instruction(identity_input, i); p.move_instruction(identity_input, i);
......
...@@ -714,6 +714,17 @@ struct pad ...@@ -714,6 +714,17 @@ struct pad
shape s{inputs.front().type(), rdims}; shape s{inputs.front().type(), rdims};
return s; return s;
} }
bool symmetric() const
{
std::size_t num_dims = pads.size()/2;
for(std::size_t i = 0; i < num_dims; i++)
{
if(pads.at(i) != pads.at(i+num_dims))
return false;
}
return true;
}
}; };
struct as_shape struct as_shape
......
#ifndef MIGRAPHX_GUARD_RTGLIB_PAD_REWRITE_HPP
#define MIGRAPHX_GUARD_RTGLIB_PAD_REWRITE_HPP
#include <string>
#include <vector>
#include <array>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
/**
* Rewrite pads to use attribute from other instructions instead.
*/
struct pad_rewrite
{
std::string name() const { return "pad_rewrite"; }
void apply(program& p) const;
template <class T>
void update_op(T, instruction_ref ins, instruction_ref output, program& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#include <migraphx/pad_rewrite.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void pad_rewrite::apply(program& p) const
{
for(auto ins : iterator_for(p))
{
if(ins->name() != "pad")
continue;
for (auto output : ins->outputs())
{
auto op_name = output->name();
if(op_name == "convolution")
update_op(op::convolution{}, ins, output, p);
else if(op_name == "im2col")
update_op(op::im2col{}, ins, output, p);
else if(op_name == "pooling")
update_op(op::pooling{}, ins, output, p);
}
}
}
template<class T>
void pad_rewrite::update_op(T, instruction_ref ins, instruction_ref output, program& p) const
{
auto pad_op = any_cast<op::pad>(ins->get_operator());
if(!pad_op.symmetric())
return;
std::vector<int64_t> pads = pad_op.pads;
assert(pads.size() == 8); // ensure input being padded has 4 dims (*2 for font and back padding)
std::array<size_t, 2> new_pads{static_cast<size_t>(pads[2]),static_cast<size_t>(pads[3])};
T op = any_cast<T>(output->get_operator());
op.padding = new_pads;
std::vector<instruction_ref> new_inputs{output->inputs()};
new_inputs.front() = ins->inputs().front();
p.replace_instruction(output, op, new_inputs);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <migraphx/eliminate_identity.hpp> #include <migraphx/eliminate_identity.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp> #include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/schedule_model.hpp> #include <migraphx/gpu/schedule_model.hpp>
#include <migraphx/pad_rewrite.hpp>
#include <migraphx/schedule.hpp> #include <migraphx/schedule.hpp>
namespace migraphx { namespace migraphx {
...@@ -34,6 +35,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const ...@@ -34,6 +35,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
{ {
dead_code_elimination{}, dead_code_elimination{},
eliminate_identity{}, eliminate_identity{},
pad_rewrite{},
dead_code_elimination{},
fwd_conv_batchnorm_rewrite{}, fwd_conv_batchnorm_rewrite{},
dead_code_elimination{}, dead_code_elimination{},
rewrite_rnn{}, rewrite_rnn{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment