Commit 0c212157 authored by Paul's avatar Paul
Browse files

Add pass to simplify mul conv

parent 3987066f
......@@ -369,6 +369,11 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
return ctx.not_found();
}
MIGRAPHX_PRED_MATCHER(is_constant, instruction_ref ins)
{
return ins->can_eval();
}
template <class... Ms>
auto skip_output(Ms... ms)
{
......
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/program.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct find_add_lit_broadcast
auto lit_broadcast()
{
auto lit_broadcast() const
{
return match::any_of(match::name("@literal"), match::name("broadcast"));
}
auto not_lit_broadcast() const
return match::any_of(match::name("@literal"), match::name("broadcast"));
}
auto not_lit_broadcast()
{
return match::none_of(match::name("@literal"), match::name("broadcast"));
}
auto op_lit_broadcast(std::string op, std::string x, std::string y)
{
return match::name(op)(match::either_arg(0, 1)(lit_broadcast().bind(std::move(x)),
not_lit_broadcast().bind(std::move(y))));
}
struct find_mul_conv
{
auto matcher() const
{
return match::none_of(match::name("@literal"), match::name("broadcast"));
return match::name("mul")(
match::either_arg(0, 1)(match::name("conv")(match::used_once(), match::args(match::any(), match::is_constant().bind("w"))).bind("conv"), match::name("broadcast").bind("a")));
}
auto add_lit_broadcast(std::string x, std::string y) const
void apply(program& p, match::matcher_result r) const
{
return match::name("add")(match::either_arg(0, 1)(lit_broadcast().bind(std::move(x)),
not_lit_broadcast().bind(std::move(y))));
auto ins = r.result;
auto conv_ins = r.instructions["conv"];
auto a_ins = r.instructions["a"];
auto w_ins = r.instructions["w"];
auto broadcast_op = any_cast<op::broadcast>(a_ins->get_operator());
if (broadcast_op.axis != 1)
return;
auto new_a = p.insert_instruction(ins, op::broadcast{0, w_ins->get_shape().lens()}, a_ins->inputs().front());
auto new_mul = p.insert_instruction(ins, op::mul{}, new_a, w_ins);
auto new_conv = p.insert_instruction(ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_mul);
p.replace_instruction(ins, new_conv);
}
};
struct find_add_lit_broadcast
{
auto matcher() const
{
return match::name("add")(
match::args(add_lit_broadcast("a", "x"), add_lit_broadcast("b", "y")));
match::args(op_lit_broadcast("add", "a", "x"), op_lit_broadcast("add", "b", "y")));
}
void apply(program& p, match::matcher_result r) const
......@@ -59,7 +88,7 @@ struct find_add_lit_broadcast
}
};
void simplify_algebra::apply(program& p) const { match::find_matches(p, find_add_lit_broadcast{}); }
void simplify_algebra::apply(program& p) const { match::find_matches(p, find_add_lit_broadcast{}, find_mul_conv{}); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
......@@ -99,6 +102,21 @@ TEST_CASE(simplify_add3)
EXPECT(p1 == p2);
}
TEST_CASE(simplify_mul_conv1)
{
migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::int32_type, {1, 128, 28, 28}});
auto w = p.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256, 128, 3, 3}}));
auto conv = p.add_instruction(migraphx::op::convolution{{1, 1},{2, 2},{1, 1}}, x, w);
auto a = p.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256}}));
auto b = p.add_instruction(migraphx::op::broadcast{1, {1, 256, 14, 14}}, a);
auto mul = p.add_instruction(migraphx::op::mul{}, conv, b);
p.add_instruction(pass_op{}, mul);
EXPECT(conv->outputs().front()->name() == "mul");
p.compile(simplify_algebra_target{});
EXPECT(conv->outputs().front()->name() != "mul");
}
// TODO: Add test case
void simplify_add4()
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment