Commit faddc14e authored by Khalique's avatar Khalique
Browse files

renamed to eliminate_pad, changed symmetric function

parent 4ce71f3a
...@@ -11,7 +11,7 @@ add_library(migraphx ...@@ -11,7 +11,7 @@ add_library(migraphx
eliminate_contiguous.cpp eliminate_contiguous.cpp
eliminate_concat.cpp eliminate_concat.cpp
eliminate_identity.cpp eliminate_identity.cpp
pad_rewrite.cpp eliminate_pad.cpp
fwd_conv_batchnorm_rewrite.cpp fwd_conv_batchnorm_rewrite.cpp
rewrite_rnn.cpp rewrite_rnn.cpp
env.cpp env.cpp
......
#include <migraphx/pad_rewrite.hpp> #include <migraphx/eliminate_pad.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void pad_rewrite::apply(program& p) const void eliminate_pad::apply(program& p) const
{ {
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
...@@ -28,7 +28,7 @@ void pad_rewrite::apply(program& p) const ...@@ -28,7 +28,7 @@ void pad_rewrite::apply(program& p) const
} }
template <class T> template <class T>
void pad_rewrite::update_op(T, void eliminate_pad::update_op(T,
const instruction_ref& input, const instruction_ref& input,
const instruction_ref& ins, const instruction_ref& ins,
program& p) const program& p) const
......
...@@ -11,9 +11,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -11,9 +11,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
/** /**
* Remove identity instructions. Currently when used as the last pass, it will * Remove identity instructions.
* preserve the semantics of previous program state, therefore dead code elimination
* should not be used afterwards.
*/ */
struct eliminate_identity struct eliminate_identity
{ {
......
#ifndef MIGRAPHX_GUARD_RTGLIB_PAD_REWRITE_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_ELIMINATE_PAD_HPP
#define MIGRAPHX_GUARD_RTGLIB_PAD_REWRITE_HPP #define MIGRAPHX_GUARD_RTGLIB_ELIMINATE_PAD_HPP
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -13,13 +13,12 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -13,13 +13,12 @@ inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
/** /**
* Remove identity instructions. Currently when used as the last pass, it will * Remove pads if they can be written as an
* preserve the semantics of previous program state, therefore dead code elimination * attribute to another op (im2col, convolution, pooling)
* should not be used afterwards.
*/ */
struct pad_rewrite struct eliminate_pad
{ {
std::string name() const { return "pad_rewrite"; } std::string name() const { return "eliminate_pad"; }
void apply(program& p) const; void apply(program& p) const;
template <class T> template <class T>
void update_op(T, const instruction_ref& input, const instruction_ref& ins, program& p) const; void update_op(T, const instruction_ref& input, const instruction_ref& ins, program& p) const;
......
...@@ -718,12 +718,7 @@ struct pad ...@@ -718,12 +718,7 @@ struct pad
bool symmetric() const bool symmetric() const
{ {
std::size_t num_dims = pads.size() / 2; std::size_t num_dims = pads.size() / 2;
for(std::size_t i = 0; i < num_dims; i++) return std::equal(pads.begin(), pads.begin() + num_dims, pads.begin() + num_dims, pads.end());
{
if(pads.at(i) != pads.at(i + num_dims))
return false;
}
return true;
} }
}; };
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
#include <migraphx/eliminate_identity.hpp> #include <migraphx/eliminate_identity.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp> #include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/schedule_model.hpp> #include <migraphx/gpu/schedule_model.hpp>
#include <migraphx/pad_rewrite.hpp> #include <migraphx/eliminate_pad.hpp>
#include <migraphx/schedule.hpp> #include <migraphx/schedule.hpp>
namespace migraphx { namespace migraphx {
...@@ -37,7 +37,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const ...@@ -37,7 +37,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
{ {
dead_code_elimination{}, dead_code_elimination{},
eliminate_identity{}, eliminate_identity{},
pad_rewrite{}, eliminate_pad{},
dead_code_elimination{}, dead_code_elimination{},
fwd_conv_batchnorm_rewrite{}, fwd_conv_batchnorm_rewrite{},
dead_code_elimination{}, dead_code_elimination{},
......
...@@ -59,7 +59,7 @@ TEST_CASE(simple_test_end_dependency) ...@@ -59,7 +59,7 @@ TEST_CASE(simple_test_end_dependency)
p.add_instruction(sum_op{}, ans, three); p.add_instruction(sum_op{}, ans, three);
p.add_instruction(migraphx::op::identity{}, ans); p.add_instruction(migraphx::op::identity{}, ans);
p.compile(eliminate_identity_target{}); p.compile(eliminate_identity_target{});
EXPECT(std::any_of(p.begin(), p.end(), [](const migraphx::instruction& ins) { EXPECT(!std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity"; return ins.name() == "identity";
})); }));
auto result = p.eval({}); auto result = p.eval({});
......
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pad_rewrite.hpp> #include <migraphx/eliminate_pad.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
#include <test.hpp> #include <test.hpp>
struct pad_rewrite_target struct eliminate_pad_target
{ {
std::string name() const { return "pad_rewrite"; } std::string name() const { return "eliminate_pad"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const std::vector<migraphx::pass> get_passes(migraphx::context&) const
{ {
return {migraphx::pad_rewrite{}, migraphx::dead_code_elimination{}}; return {migraphx::eliminate_pad{}, migraphx::dead_code_elimination{}};
} }
migraphx::context get_context() const { return {}; } migraphx::context get_context() const { return {}; }
}; };
...@@ -54,7 +54,7 @@ TEST_CASE(rewrite_test) ...@@ -54,7 +54,7 @@ TEST_CASE(rewrite_test)
auto l2 = p.add_instruction(migraphx::op::pooling{}, padded_img); auto l2 = p.add_instruction(migraphx::op::pooling{}, padded_img);
p.add_instruction(migraphx::op::identity{}, l0, l1, l2); p.add_instruction(migraphx::op::identity{}, l0, l1, l2);
p.compile(pad_rewrite_target{}); p.compile(eliminate_pad_target{});
EXPECT(std::none_of( EXPECT(std::none_of(
p.begin(), p.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; })); p.begin(), p.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; }));
} }
...@@ -73,7 +73,7 @@ TEST_CASE(rewrite_test_asymmetric) ...@@ -73,7 +73,7 @@ TEST_CASE(rewrite_test_asymmetric)
create_im2col(padded_img, channels, p); create_im2col(padded_img, channels, p);
p.compile(pad_rewrite_target{}); p.compile(eliminate_pad_target{});
EXPECT(std::any_of( EXPECT(std::any_of(
p.begin(), p.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; })); p.begin(), p.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; }));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment