"profiler/vscode:/vscode.git/clone" did not exist on "dc0bae32cb51ef9dde4d6b621e579d6b3f1ad067"
Commit ed49d0c4 authored by Paul's avatar Paul
Browse files

Simplify algebra

parent 300582fc
...@@ -11,6 +11,7 @@ add_library(migraph ...@@ -11,6 +11,7 @@ add_library(migraph
instruction.cpp instruction.cpp
program.cpp program.cpp
shape.cpp shape.cpp
simplify_algebra.cpp
simplify_reshapes.cpp simplify_reshapes.cpp
opt/memory_coloring.cpp opt/memory_coloring.cpp
opt/memory_coloring_impl.cpp opt/memory_coloring_impl.cpp
......
...@@ -264,6 +264,8 @@ auto any_of(Ts... ms) ...@@ -264,6 +264,8 @@ auto any_of(Ts... ms)
}); });
} }
MIGRAPH_PRED_MATCHER(any, instruction_ref) { return true; }
MIGRAPH_PRED_MATCHER(none, instruction_ref) { return false; }
MIGRAPH_PRED_MATCHER(standard_shape, instruction_ref ins) { return ins->get_shape().standard(); } MIGRAPH_PRED_MATCHER(standard_shape, instruction_ref ins) { return ins->get_shape().standard(); }
MIGRAPH_PRED_MATCHER(broadcast_shape, instruction_ref ins) MIGRAPH_PRED_MATCHER(broadcast_shape, instruction_ref ins)
{ {
......
#ifndef MIGRAPH_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP
#define MIGRAPH_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP
#include <string>
namespace migraph {
struct program;
struct simplify_algebra
{
std::string name() const { return "simplify_algebra"; }
void apply(program& p) const;
};
} // namespace migraph
#endif
#include <migraph/simplify_algebra.hpp>
#include <migraph/program.hpp>
#include <migraph/operators.hpp>
#include <migraph/matcher.hpp>
#include <migraph/literal.hpp>
namespace migraph {
struct find_add_lit_broadcast
{
auto lit_broadcast() const
{
return match::any_of(match::name("@literal"), match::name("broadcast"));
}
auto not_lit_broadcast() const
{
return match::none_of(match::name("@literal"), match::name("broadcast"));
}
auto add_lit_broadcast(std::string x, std::string y) const
{
return match::name("add")(match::either_arg(0, 1)(lit_broadcast().bind(x), not_lit_broadcast().bind(y)));
}
auto matcher() const
{
return match::name("add")(match::args(add_lit_broadcast("a", "x"), add_lit_broadcast("b", "y")));
}
void apply(program& p, match::matcher_result r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
auto y_ins = r.instructions["y"];
auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"];
if(a_ins->name() != b_ins->name())
return;
instruction_ref sumab;
if(a_ins->name() == "broadcast")
{
if(a_ins->inputs().at(0)->get_shape() != b_ins->inputs().at(0)->get_shape())
return;
auto op = a_ins->get_operator();
auto presum = p.insert_instruction(ins, op::add{}, a_ins->inputs().at(0), b_ins->inputs().at(0));
sumab = p.insert_instruction(ins, op, presum);
}
else
{
sumab = p.insert_instruction(ins, op::add{}, a_ins, b_ins);
}
auto sumxy = p.insert_instruction(ins, op::add{}, x_ins, y_ins);
p.replace_instruction(ins, op::add{}, sumxy, sumab);
}
};
void simplify_algebra::apply(program& p) const { match::find_matches(p, find_add_lit_broadcast{}); }
} // namespace migraph
...@@ -10,6 +10,8 @@ ...@@ -10,6 +10,8 @@
#include <migraph/auto_contiguous.hpp> #include <migraph/auto_contiguous.hpp>
#include <migraph/dead_code_elimination.hpp> #include <migraph/dead_code_elimination.hpp>
#include <migraph/simplify_reshapes.hpp> #include <migraph/simplify_reshapes.hpp>
#include <migraph/simplify_algebra.hpp>
#include <migraph/constant_propagate.hpp>
#include <migraph/eliminate_contiguous.hpp> #include <migraph/eliminate_contiguous.hpp>
#include <migraph/fwd_conv_batchnorm_rewrite.hpp> #include <migraph/fwd_conv_batchnorm_rewrite.hpp>
...@@ -25,6 +27,10 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const ...@@ -25,6 +27,10 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
dead_code_elimination{}, dead_code_elimination{},
fwd_conv_batchnorm_rewrite{}, fwd_conv_batchnorm_rewrite{},
dead_code_elimination{}, dead_code_elimination{},
simplify_algebra{},
dead_code_elimination{},
constant_propagate{},
dead_code_elimination{},
auto_contiguous{}, auto_contiguous{},
simplify_reshapes{}, simplify_reshapes{},
dead_code_elimination{}, dead_code_elimination{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment