Unverified Commit 94addd98 authored by kahmed10's avatar kahmed10 Committed by GitHub
Browse files

Add simplify rsqrt pass (#516)



* add simplify_rsqrt and test

* formatting

* add used_once check

* move used_once

* formatting

* add multi_use test
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 233d4303
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <migraphx/op/broadcast.hpp> #include <migraphx/op/broadcast.hpp>
#include <migraphx/op/neg.hpp> #include <migraphx/op/neg.hpp>
#include <migraphx/op/recip.hpp> #include <migraphx/op/recip.hpp>
#include <migraphx/op/rsqrt.hpp>
#include <migraphx/matcher.hpp> #include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
...@@ -391,6 +392,23 @@ struct find_sub_const ...@@ -391,6 +392,23 @@ struct find_sub_const
} }
}; };
struct find_rsqrt
{
auto matcher() const
{
return match::name("recip")(match::args(
match::name("sqrt")(match::used_once(), match::args(match::any().bind("x")))));
}
void apply(program& p, match::matcher_result r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
p.replace_instruction(ins, op::rsqrt{}, x_ins);
}
};
void simplify_algebra::apply(program& p) const void simplify_algebra::apply(program& p) const
{ {
// Run simplifications multiple times // Run simplifications multiple times
...@@ -405,6 +423,7 @@ void simplify_algebra::apply(program& p) const ...@@ -405,6 +423,7 @@ void simplify_algebra::apply(program& p) const
find_mul_add{}, find_mul_add{},
find_div_const{}, find_div_const{},
find_sub_const{}, find_sub_const{},
find_rsqrt{},
find_concat_unary{}, find_concat_unary{},
find_concat_binary{}); find_concat_binary{});
dead_code_elimination{}.apply(p); dead_code_elimination{}.apply(p);
......
...@@ -551,4 +551,38 @@ TEST_CASE(simplify_sub_const) ...@@ -551,4 +551,38 @@ TEST_CASE(simplify_sub_const)
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
TEST_CASE(simplify_rsqrt)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto sqrt = p1.add_instruction(migraphx::op::sqrt{}, x);
p1.add_instruction(migraphx::op::recip{}, sqrt);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}});
p2.add_instruction(migraphx::op::rsqrt{}, x);
}
EXPECT(p1 == p2);
}
TEST_CASE(simplify_rsqrt_multi_use)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto sqrt = p1.add_instruction(migraphx::op::sqrt{}, x);
auto add = p1.add_instruction(migraphx::op::add{}, sqrt, sqrt);
auto rsqrt = p1.add_instruction(migraphx::op::recip{}, sqrt);
p1.add_instruction(migraphx::op::add{}, rsqrt, add);
}
migraphx::program p2{p1};
run_pass(p1);
EXPECT(p1 == p2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment