Unverified Commit ba328767 authored by mvermeulen's avatar mvermeulen Committed by GitHub
Browse files

Add horizontal fusion for pointwise operators (#447)



* Horizantal fusion of unary and binary ops

* Formatting

* Fix bugs in matcher

* Add tests

* Formatting

* Add the op generically

* Formatting

* Rename test for more detail
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 532513cd
...@@ -27,6 +27,15 @@ auto conv_const_weights() ...@@ -27,6 +27,15 @@ auto conv_const_weights()
match::args(match::any(), match::is_constant().bind("w"))); match::args(match::any(), match::is_constant().bind("w")));
} }
MIGRAPHX_PRED_MATCHER(args_has_same_ops, instruction_ref ins)
{
if(ins->inputs().empty())
return true;
return std::all_of(ins->inputs().begin(), ins->inputs().end(), [&](auto j) {
return j->get_operator() == ins->inputs().front()->get_operator();
});
}
struct find_mul_conv struct find_mul_conv
{ {
auto matcher() const auto matcher() const
...@@ -167,6 +176,73 @@ struct find_inner_broadcast ...@@ -167,6 +176,73 @@ struct find_inner_broadcast
} }
}; };
struct find_concat_unary
{
auto matcher() const
{
return match::name("concat")(args_has_same_ops(),
match::arg(0)(match::nargs(1),
match::name("relu", "broadcast").bind("x"),
match::used_once()));
}
void apply(program& p, match::matcher_result r) const
{
auto ins = r.result;
auto x = r.instructions["x"];
auto op = x->get_operator();
auto axis = any_cast<op::concat>(ins->get_operator()).axis;
// Adjust broadcast lens
if(op.name() == "broadcast")
{
auto b = any_cast<op::broadcast>(op);
if(b.axis != axis)
return;
b.broadcast_lens = ins->get_shape().lens();
op = b;
axis = 0;
}
auto inputs = ins->inputs();
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto i) {
return i->inputs().front();
});
auto concat = p.insert_instruction(ins, op::concat{axis}, inputs);
p.replace_instruction(ins, op, concat);
}
};
struct find_concat_binary
{
auto matcher() const
{
return match::name("concat")(args_has_same_ops(),
match::arg(0)(match::nargs(2),
match::name("add", "multiply").bind("x"),
match::used_once()));
}
void apply(program& p, match::matcher_result r) const
{
auto ins = r.result;
auto x = r.instructions["x"];
auto op = x->get_operator();
auto concat_op = ins->get_operator();
auto xinputs = ins->inputs();
std::transform(xinputs.begin(), xinputs.end(), xinputs.begin(), [&](auto i) {
return i->inputs().front();
});
auto yinputs = ins->inputs();
std::transform(yinputs.begin(), yinputs.end(), yinputs.begin(), [&](auto i) {
return i->inputs().back();
});
auto xconcat = p.insert_instruction(ins, concat_op, xinputs);
auto yconcat = p.insert_instruction(ins, concat_op, yinputs);
p.replace_instruction(ins, op, xconcat, yconcat);
}
};
bool axis_equal(const std::vector<std::size_t>& x, bool axis_equal(const std::vector<std::size_t>& x,
const std::vector<std::size_t>& y, const std::vector<std::size_t>& y,
std::size_t axis) std::size_t axis)
...@@ -281,7 +357,9 @@ void simplify_algebra::apply(program& p) const ...@@ -281,7 +357,9 @@ void simplify_algebra::apply(program& p) const
find_add_lit_broadcast{}, find_add_lit_broadcast{},
find_add_convs{}, find_add_convs{},
find_mul_conv{}, find_mul_conv{},
find_mul_add{}); find_mul_add{},
find_concat_unary{},
find_concat_binary{});
dead_code_elimination{}.apply(p); dead_code_elimination{}.apply(p);
} }
} }
......
...@@ -381,4 +381,114 @@ TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides2) ...@@ -381,4 +381,114 @@ TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides2)
p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 2); p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 2);
} }
TEST_CASE(simplify_concat_add_relu)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {1}};
migraphx::program p1;
{
auto x = p1.add_parameter("x", s);
auto y = p1.add_parameter("y", s);
auto one = p1.add_literal({s, {1}});
auto two = p1.add_literal({s, {2}});
auto sum1 = p1.add_instruction(migraphx::op::add{}, x, one);
auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, y, two);
auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2);
auto concat = p1.add_instruction(migraphx::op::concat{0}, relu1, relu2);
p1.add_instruction(pass_op{}, concat);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", s);
auto y = p2.add_parameter("y", s);
auto one = p2.add_literal({s, {1}});
auto two = p2.add_literal({s, {2}});
auto concat1 = p2.add_instruction(migraphx::op::concat{0}, x, y);
auto concat2 = p2.add_instruction(migraphx::op::concat{0}, one, two);
auto sum = p2.add_instruction(migraphx::op::add{}, concat1, concat2);
auto relu = p2.add_instruction(migraphx::op::relu{}, sum);
p2.add_instruction(pass_op{}, relu);
}
EXPECT(p1 == p2);
}
TEST_CASE(simplify_concat_add_relu_broadcast_different_axis)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
migraphx::program p1;
{
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
auto x = p1.add_parameter("x", s);
auto y = p1.add_parameter("y", s);
auto one = p1.add_literal(1);
auto oneb = p1.add_instruction(b, one);
auto two = p1.add_literal(2);
auto twob = p1.add_instruction(b, two);
auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2);
auto concat = p1.add_instruction(migraphx::op::concat{1}, relu1, relu2);
p1.add_instruction(pass_op{}, concat);
}
run_pass(p1);
migraphx::program p2;
{
auto b = migraphx::op::broadcast{1, {2, 2, 4, 5}};
auto x = p2.add_parameter("x", s);
auto y = p2.add_parameter("y", s);
auto one = p2.add_literal(1);
auto two = p2.add_literal(2);
auto concat1 = p2.add_instruction(migraphx::op::concat{1}, x, y);
auto concat2 = p2.add_instruction(migraphx::op::concat{0}, one, two);
auto concat2b = p2.add_instruction(b, concat2);
auto sum = p2.add_instruction(migraphx::op::add{}, concat1, concat2b);
auto relu = p2.add_instruction(migraphx::op::relu{}, sum);
p2.add_instruction(pass_op{}, relu);
}
EXPECT(p1 == p2);
}
TEST_CASE(simplify_concat_add_relu_broadcast_same_axis)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
migraphx::program p1;
{
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
auto x = p1.add_parameter("x", s);
auto y = p1.add_parameter("y", s);
auto one = p1.add_literal(1);
auto oneb = p1.add_instruction(b, one);
auto two = p1.add_literal(2);
auto twob = p1.add_instruction(b, two);
auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2);
auto concat = p1.add_instruction(migraphx::op::concat{0}, relu1, relu2);
p1.add_instruction(pass_op{}, concat);
}
run_pass(p1);
migraphx::program p2;
{
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
auto x = p2.add_parameter("x", s);
auto y = p2.add_parameter("y", s);
auto one = p2.add_literal(1);
auto oneb = p2.add_instruction(b, one);
auto two = p2.add_literal(2);
auto twob = p2.add_instruction(b, two);
auto concat1 = p2.add_instruction(migraphx::op::concat{0}, x, y);
auto concat2 = p2.add_instruction(migraphx::op::concat{0}, oneb, twob);
auto sum = p2.add_instruction(migraphx::op::add{}, concat1, concat2);
auto relu = p2.add_instruction(migraphx::op::relu{}, sum);
p2.add_instruction(pass_op{}, relu);
}
EXPECT(p1 == p2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment