Unverified Commit 3c457a3c authored by kahmed10's avatar kahmed10 Committed by GitHub
Browse files

simplify div and sub by constants (#470)



* fix pad calc

* progress on div

* formatting

* continue work on pass

* continue testing

* formatting

* add recip and sub matcher

* formatting

* add tests

* formatting

* fix review comments

* remove unnecessary header

* remove headers
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 8fa33f1a
#ifndef MIGRAPHX_GUARD_OPERATORS_RECIP_HPP
#define MIGRAPHX_GUARD_OPERATORS_RECIP_HPP
#include <migraphx/op/unary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct recip : unary<recip>
{
auto apply() const
{
return [](auto x) { return 1 / x; };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -56,6 +56,7 @@
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/pow.hpp>
#include <migraphx/op/recip.hpp>
#include <migraphx/op/reduce_max.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/reduce_min.hpp>
......
......@@ -7,6 +7,8 @@
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/as_shape.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/neg.hpp>
#include <migraphx/op/recip.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
......@@ -346,6 +348,46 @@ struct find_add_convs
}
};
struct find_div_const
{
auto matcher() const
{
return match::name("div")(match::arg(1)(match::is_constant().bind("c")));
}
void apply(program& p, match::matcher_result r) const
{
auto ins = r.result;
auto c_ins = r.instructions["c"];
auto recip = p.insert_instruction(std::next(c_ins), op::recip{}, c_ins);
auto args = ins->inputs();
p.replace_instruction(ins, op::mul{}, args.front(), recip);
}
};
struct find_sub_const
{
auto matcher() const
{
return match::name("sub")(match::arg(1)(match::is_constant().bind("c")));
}
void apply(program& p, match::matcher_result r) const
{
auto ins = r.result;
auto c_ins = r.instructions["c"];
auto neg = p.insert_instruction(std::next(c_ins), op::neg{}, c_ins);
auto args = ins->inputs();
p.replace_instruction(ins, op::add{}, args.front(), neg);
}
};
void simplify_algebra::apply(program& p) const
{
// Run simplifications multiple times
......@@ -358,6 +400,8 @@ void simplify_algebra::apply(program& p) const
find_add_convs{},
find_mul_conv{},
find_mul_add{},
find_div_const{},
find_sub_const{},
find_concat_unary{},
find_concat_binary{});
dead_code_elimination{}.apply(p);
......
......@@ -491,4 +491,44 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis)
EXPECT(p1 == p2);
}
TEST_CASE(simplify_div_const)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto two = p1.add_literal(2);
p1.add_instruction(migraphx::op::div{}, x, two);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto two = p2.add_literal(2);
auto recip = p2.insert_instruction(std::next(two), migraphx::op::recip{}, two);
p2.add_instruction(migraphx::op::mul{}, x, recip);
}
EXPECT(p1 == p2);
}
TEST_CASE(simplify_sub_const)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto two = p1.add_literal(2);
p1.add_instruction(migraphx::op::sub{}, x, two);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto two = p2.add_literal(2);
auto neg = p2.insert_instruction(std::next(two), migraphx::op::neg{}, two);
p2.add_instruction(migraphx::op::add{}, x, neg);
}
EXPECT(p1 == p2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment