Commit faddc14e authored by Khalique's avatar Khalique
Browse files

renamed to eliminate_pad, changed symmetric function

parent 4ce71f3a
......@@ -11,7 +11,7 @@ add_library(migraphx
eliminate_contiguous.cpp
eliminate_concat.cpp
eliminate_identity.cpp
pad_rewrite.cpp
eliminate_pad.cpp
fwd_conv_batchnorm_rewrite.cpp
rewrite_rnn.cpp
env.cpp
......
#include <migraphx/pad_rewrite.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
......@@ -8,7 +8,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void pad_rewrite::apply(program& p) const
void eliminate_pad::apply(program& p) const
{
for(auto ins : iterator_for(p))
{
......@@ -28,7 +28,7 @@ void pad_rewrite::apply(program& p) const
}
template <class T>
void pad_rewrite::update_op(T,
void eliminate_pad::update_op(T,
const instruction_ref& input,
const instruction_ref& ins,
program& p) const
......
......@@ -11,9 +11,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct program;
/**
* Remove identity instructions. Currently when used as the last pass, it will
* preserve the semantics of previous program state, therefore dead code elimination
* should not be used afterwards.
* Remove identity instructions.
*/
struct eliminate_identity
{
......
#ifndef MIGRAPHX_GUARD_RTGLIB_PAD_REWRITE_HPP
#define MIGRAPHX_GUARD_RTGLIB_PAD_REWRITE_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_ELIMINATE_PAD_HPP
#define MIGRAPHX_GUARD_RTGLIB_ELIMINATE_PAD_HPP
#include <string>
#include <vector>
......@@ -13,13 +13,12 @@ inline namespace MIGRAPHX_INLINE_NS {
struct program;
/**
* Remove identity instructions. Currently when used as the last pass, it will
* preserve the semantics of previous program state, therefore dead code elimination
* should not be used afterwards.
* Remove pads if they can be written as an
* attribute to another op (im2col, convolution, pooling)
*/
struct pad_rewrite
struct eliminate_pad
{
std::string name() const { return "pad_rewrite"; }
std::string name() const { return "eliminate_pad"; }
void apply(program& p) const;
template <class T>
void update_op(T, const instruction_ref& input, const instruction_ref& ins, program& p) const;
......
......@@ -718,12 +718,7 @@ struct pad
bool symmetric() const
{
std::size_t num_dims = pads.size() / 2;
for(std::size_t i = 0; i < num_dims; i++)
{
if(pads.at(i) != pads.at(i + num_dims))
return false;
}
return true;
return std::equal(pads.begin(), pads.begin() + num_dims, pads.begin() + num_dims, pads.end());
}
};
......
......@@ -20,7 +20,7 @@
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/schedule_model.hpp>
#include <migraphx/pad_rewrite.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/schedule.hpp>
namespace migraphx {
......@@ -37,7 +37,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
{
dead_code_elimination{},
eliminate_identity{},
pad_rewrite{},
eliminate_pad{},
dead_code_elimination{},
fwd_conv_batchnorm_rewrite{},
dead_code_elimination{},
......
......@@ -59,7 +59,7 @@ TEST_CASE(simple_test_end_dependency)
p.add_instruction(sum_op{}, ans, three);
p.add_instruction(migraphx::op::identity{}, ans);
p.compile(eliminate_identity_target{});
EXPECT(std::any_of(p.begin(), p.end(), [](const migraphx::instruction& ins) {
EXPECT(!std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity";
}));
auto result = p.eval({});
......
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pad_rewrite.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <migraphx/operators.hpp>
#include <test.hpp>
struct pad_rewrite_target
struct eliminate_pad_target
{
std::string name() const { return "pad_rewrite"; }
std::string name() const { return "eliminate_pad"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const
{
return {migraphx::pad_rewrite{}, migraphx::dead_code_elimination{}};
return {migraphx::eliminate_pad{}, migraphx::dead_code_elimination{}};
}
migraphx::context get_context() const { return {}; }
};
......@@ -54,7 +54,7 @@ TEST_CASE(rewrite_test)
auto l2 = p.add_instruction(migraphx::op::pooling{}, padded_img);
p.add_instruction(migraphx::op::identity{}, l0, l1, l2);
p.compile(pad_rewrite_target{});
p.compile(eliminate_pad_target{});
EXPECT(std::none_of(
p.begin(), p.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; }));
}
......@@ -73,7 +73,7 @@ TEST_CASE(rewrite_test_asymmetric)
create_im2col(padded_img, channels, p);
p.compile(pad_rewrite_target{});
p.compile(eliminate_pad_target{});
EXPECT(std::any_of(
p.begin(), p.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; }));
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment