"library/vscode:/vscode.git/clone" did not exist on "0dcb3496cf3e274386272e0a4430282f9ddf1169"
Commit 5c7bee3a authored by Paul's avatar Paul
Browse files

Rename batchnorm pass

parent cfd36b63
...@@ -12,7 +12,7 @@ add_library(migraphx ...@@ -12,7 +12,7 @@ add_library(migraphx
eliminate_concat.cpp eliminate_concat.cpp
eliminate_identity.cpp eliminate_identity.cpp
eliminate_pad.cpp eliminate_pad.cpp
fwd_conv_batchnorm_rewrite.cpp rewrite_batchnorm.cpp
rewrite_rnn.cpp rewrite_rnn.cpp
env.cpp env.cpp
generate.cpp generate.cpp
......
...@@ -13,9 +13,9 @@ struct program; ...@@ -13,9 +13,9 @@ struct program;
/** /**
* Rewrite batchnorm to a multiply and add. * Rewrite batchnorm to a multiply and add.
*/ */
struct fwd_conv_batchnorm_rewrite struct rewrite_batchnorm
{ {
std::string name() const { return "fwd_conv_batchnorm_rewrite"; } std::string name() const { return "rewrite_batchnorm"; }
void apply(program& p) const; void apply(program& p) const;
}; };
......
#include <migraphx/fwd_conv_batchnorm_rewrite.hpp> #include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/op/batch_norm.hpp> #include <migraphx/op/batch_norm.hpp>
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void fwd_conv_batchnorm_rewrite::apply(program& p) const void rewrite_batchnorm::apply(program& p) const
{ {
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include <migraphx/propagate_constant.hpp> #include <migraphx/propagate_constant.hpp>
#include <migraphx/eliminate_contiguous.hpp> #include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/common_subexpression_elimination.hpp> #include <migraphx/common_subexpression_elimination.hpp>
#include <migraphx/fwd_conv_batchnorm_rewrite.hpp> #include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/rewrite_rnn.hpp> #include <migraphx/rewrite_rnn.hpp>
#include <migraphx/eliminate_concat.hpp> #include <migraphx/eliminate_concat.hpp>
#include <migraphx/eliminate_identity.hpp> #include <migraphx/eliminate_identity.hpp>
...@@ -42,7 +42,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const ...@@ -42,7 +42,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
eliminate_identity{}, eliminate_identity{},
eliminate_pad{}, eliminate_pad{},
dead_code_elimination{}, dead_code_elimination{},
fwd_conv_batchnorm_rewrite{}, rewrite_batchnorm{},
dead_code_elimination{}, dead_code_elimination{},
rewrite_rnn{}, rewrite_rnn{},
dead_code_elimination{}, dead_code_elimination{},
......
#include <migraphx/fwd_conv_batchnorm_rewrite.hpp> #include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/cpu/target.hpp> #include <migraphx/cpu/target.hpp>
#include <migraphx/op/convolution.hpp> #include <migraphx/op/convolution.hpp>
...@@ -56,7 +56,7 @@ TEST_CASE(fwd_conv_batchnorm_rewrite_test) ...@@ -56,7 +56,7 @@ TEST_CASE(fwd_conv_batchnorm_rewrite_test)
migraphx::program p1 = create_program(); migraphx::program p1 = create_program();
migraphx::program p2 = create_program(); migraphx::program p2 = create_program();
migraphx::fwd_conv_batchnorm_rewrite opt; migraphx::rewrite_batchnorm opt;
opt.apply(p2); opt.apply(p2);
p1.compile(migraphx::cpu::target{}); p1.compile(migraphx::cpu::target{});
p2.compile(migraphx::cpu::target{}); p2.compile(migraphx::cpu::target{});
...@@ -93,7 +93,7 @@ TEST_CASE(non_literal) ...@@ -93,7 +93,7 @@ TEST_CASE(non_literal)
migraphx::program p1 = create_program(); migraphx::program p1 = create_program();
migraphx::program p2 = create_program(); migraphx::program p2 = create_program();
migraphx::fwd_conv_batchnorm_rewrite opt; migraphx::rewrite_batchnorm opt;
opt.apply(p2); opt.apply(p2);
EXPECT(any_of(p1, &is_batch_norm)); EXPECT(any_of(p1, &is_batch_norm));
EXPECT(none_of(p2, &is_batch_norm)); EXPECT(none_of(p2, &is_batch_norm));
...@@ -121,7 +121,7 @@ TEST_CASE(as_literal) ...@@ -121,7 +121,7 @@ TEST_CASE(as_literal)
migraphx::program p1 = create_program(); migraphx::program p1 = create_program();
migraphx::program p2 = create_program(); migraphx::program p2 = create_program();
migraphx::fwd_conv_batchnorm_rewrite opt; migraphx::rewrite_batchnorm opt;
opt.apply(p2); opt.apply(p2);
EXPECT(any_of(p1, &is_batch_norm)); EXPECT(any_of(p1, &is_batch_norm));
EXPECT(none_of(p2, &is_batch_norm)); EXPECT(none_of(p2, &is_batch_norm));
...@@ -159,7 +159,7 @@ TEST_CASE(literal_reshape) ...@@ -159,7 +159,7 @@ TEST_CASE(literal_reshape)
migraphx::program p1 = create_program(); migraphx::program p1 = create_program();
migraphx::program p2 = create_program(); migraphx::program p2 = create_program();
migraphx::fwd_conv_batchnorm_rewrite opt; migraphx::rewrite_batchnorm opt;
opt.apply(p2); opt.apply(p2);
EXPECT(any_of(p1, &is_batch_norm)); EXPECT(any_of(p1, &is_batch_norm));
EXPECT(none_of(p2, &is_batch_norm)); EXPECT(none_of(p2, &is_batch_norm));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment