Commit 5c7bee3a authored by Paul's avatar Paul
Browse files

Rename batchnorm pass

parent cfd36b63
......@@ -12,7 +12,7 @@ add_library(migraphx
eliminate_concat.cpp
eliminate_identity.cpp
eliminate_pad.cpp
fwd_conv_batchnorm_rewrite.cpp
rewrite_batchnorm.cpp
rewrite_rnn.cpp
env.cpp
generate.cpp
......
......@@ -13,9 +13,9 @@ struct program;
/**
* Rewrite batchnorm to a multiply and add.
*/
struct fwd_conv_batchnorm_rewrite
struct rewrite_batchnorm
{
std::string name() const { return "fwd_conv_batchnorm_rewrite"; }
std::string name() const { return "rewrite_batchnorm"; }
void apply(program& p) const;
};
......
#include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/batch_norm.hpp>
......@@ -12,7 +12,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void fwd_conv_batchnorm_rewrite::apply(program& p) const
void rewrite_batchnorm::apply(program& p) const
{
for(auto ins : iterator_for(p))
{
......
......@@ -14,7 +14,7 @@
#include <migraphx/propagate_constant.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/common_subexpression_elimination.hpp>
#include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/eliminate_concat.hpp>
#include <migraphx/eliminate_identity.hpp>
......@@ -42,7 +42,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
eliminate_identity{},
eliminate_pad{},
dead_code_elimination{},
fwd_conv_batchnorm_rewrite{},
rewrite_batchnorm{},
dead_code_elimination{},
rewrite_rnn{},
dead_code_elimination{},
......
#include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/program.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/op/convolution.hpp>
......@@ -56,7 +56,7 @@ TEST_CASE(fwd_conv_batchnorm_rewrite_test)
migraphx::program p1 = create_program();
migraphx::program p2 = create_program();
migraphx::fwd_conv_batchnorm_rewrite opt;
migraphx::rewrite_batchnorm opt;
opt.apply(p2);
p1.compile(migraphx::cpu::target{});
p2.compile(migraphx::cpu::target{});
......@@ -93,7 +93,7 @@ TEST_CASE(non_literal)
migraphx::program p1 = create_program();
migraphx::program p2 = create_program();
migraphx::fwd_conv_batchnorm_rewrite opt;
migraphx::rewrite_batchnorm opt;
opt.apply(p2);
EXPECT(any_of(p1, &is_batch_norm));
EXPECT(none_of(p2, &is_batch_norm));
......@@ -121,7 +121,7 @@ TEST_CASE(as_literal)
migraphx::program p1 = create_program();
migraphx::program p2 = create_program();
migraphx::fwd_conv_batchnorm_rewrite opt;
migraphx::rewrite_batchnorm opt;
opt.apply(p2);
EXPECT(any_of(p1, &is_batch_norm));
EXPECT(none_of(p2, &is_batch_norm));
......@@ -159,7 +159,7 @@ TEST_CASE(literal_reshape)
migraphx::program p1 = create_program();
migraphx::program p2 = create_program();
migraphx::fwd_conv_batchnorm_rewrite opt;
migraphx::rewrite_batchnorm opt;
opt.apply(p2);
EXPECT(any_of(p1, &is_batch_norm));
EXPECT(none_of(p2, &is_batch_norm));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment