Commit db3f1478 authored by Paul's avatar Paul
Browse files

Check for adds beign used once

parent 05676328
...@@ -60,7 +60,8 @@ struct find_mul_add ...@@ -60,7 +60,8 @@ struct find_mul_add
match::either_arg(0, 1)( match::either_arg(0, 1)(
match::any().bind("x"), match::any().bind("x"),
match::any_of(conv_const_weights(), match::is_constant()).bind("y")), match::any_of(conv_const_weights(), match::is_constant()).bind("y")),
match::none_of(match::args(match::is_constant(), match::is_constant()))), match::none_of(match::args(match::is_constant(), match::is_constant())),
match::used_once_recursive(4)),
match::is_constant().bind("a"))); match::is_constant().bind("a")));
} }
...@@ -137,7 +138,7 @@ struct find_double_add_lit_broadcast ...@@ -137,7 +138,7 @@ struct find_double_add_lit_broadcast
void simplify_algebra::apply(program& p) const void simplify_algebra::apply(program& p) const
{ {
// Run simplifications multiple times // Run simplifications multiple times
for(int i = 0; i < 4; i++) for(int i = 0; i < 2; i++)
match::find_matches(p, match::find_matches(p,
match::skip_matches(match::is_unused(), match::is_constant()), match::skip_matches(match::is_unused(), match::is_constant()),
find_double_add_lit_broadcast{}, find_double_add_lit_broadcast{},
......
...@@ -46,8 +46,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const ...@@ -46,8 +46,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
dead_code_elimination{}, dead_code_elimination{},
rewrite_rnn{}, rewrite_rnn{},
dead_code_elimination{}, dead_code_elimination{},
//common_subexpression_elimination{}, // common_subexpression_elimination{},
//dead_code_elimination{}, // dead_code_elimination{},
simplify_algebra{}, simplify_algebra{},
dead_code_elimination{}, dead_code_elimination{},
auto_contiguous{}, auto_contiguous{},
......
...@@ -102,24 +102,6 @@ TEST_CASE(simplify_add3) ...@@ -102,24 +102,6 @@ TEST_CASE(simplify_add3)
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
TEST_CASE(simplify_mul_conv1)
{
migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::int32_type, {1, 128, 28, 28}});
auto w =
p.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256, 128, 3, 3}}));
auto conv = p.add_instruction(migraphx::op::convolution{{1, 1}, {2, 2}, {1, 1}}, x, w);
auto a = p.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256}}));
auto b = p.add_instruction(migraphx::op::broadcast{1, {1, 256, 14, 14}}, a);
auto mul = p.add_instruction(migraphx::op::mul{}, conv, b);
p.add_instruction(pass_op{}, mul);
EXPECT(conv->outputs().front()->name() == "mul");
p.compile(simplify_algebra_target{});
auto new_conv =
std::find_if(p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; });
EXPECT(new_conv->outputs().front()->name() != "mul");
}
// TODO: Add test case // TODO: Add test case
void simplify_add4() void simplify_add4()
{ {
...@@ -150,4 +132,22 @@ void simplify_add4() ...@@ -150,4 +132,22 @@ void simplify_add4()
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
TEST_CASE(simplify_mul_conv1)
{
migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::int32_type, {1, 128, 28, 28}});
auto w =
p.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256, 128, 3, 3}}));
auto conv = p.add_instruction(migraphx::op::convolution{{1, 1}, {2, 2}, {1, 1}}, x, w);
auto a = p.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256}}));
auto b = p.add_instruction(migraphx::op::broadcast{1, {1, 256, 14, 14}}, a);
auto mul = p.add_instruction(migraphx::op::mul{}, conv, b);
p.add_instruction(pass_op{}, mul);
EXPECT(conv->outputs().front()->name() == "mul");
p.compile(simplify_algebra_target{});
auto new_conv =
std::find_if(p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; });
EXPECT(new_conv->outputs().front()->name() != "mul");
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment