Unverified commit 905ebd5e, authored by Paul Fultz II and committed by GitHub
Browse files

Fuse operators that partially match in a concat operator (#539)



* Fuse operators that partially match in a concat operator

* Formatting

* Remove unused matcher

* Properly calculate the output_lens

* Formatting

* Formatting

* Formatting

* Fix tidy issue
Co-authored-by: Shucai Xiao <shucai.xiao@amd.com>
parent d7b8164c
......@@ -12,7 +12,18 @@ void group_by(Iterator start, Iterator last, Output out, Predicate pred)
{
while(start != last)
{
auto it = std::partition(start, last, [&](auto x) { return pred(x, *start); });
auto it = std::partition(start, last, [&](auto&& x) { return pred(x, *start); });
out(start, it);
start = it;
}
}
// Walks [start, last) front to back and splits it into consecutive runs,
// calling out(run_begin, run_end) once per run.
//
// A run begins at `start` and extends up to (not including) the first element
// `x` for which pred(*start, x) is false. Every element is compared against the
// FIRST element of its run, not against its immediate neighbour. Unlike
// group_by, the range is never reordered.
//
// NOTE: the search begins at `start` itself, so pred(*start, *start) must be
// true; otherwise `it == start` and the loop never advances.
template <class Iterator, class Output, class Predicate>
void group_unique(Iterator start, Iterator last, Output out, Predicate pred)
{
while(start != last)
{
// First element that does not belong with *start (or `last` if none).
auto it = std::find_if(start, last, [&](auto&& x) { return not pred(*start, x); });
out(start, it);
// The next run begins where this one stopped.
start = it;
}
......
......@@ -33,15 +33,6 @@ auto conv_const_weights()
match::args(match::any(), match::is_constant().bind("w")));
}
// Predicate matcher: holds when every input of `ins` carries the same operator
// as the first input. An instruction with no inputs matches trivially.
MIGRAPHX_PRED_MATCHER(args_has_same_ops, instruction_ref ins)
{
    const auto& args = ins->inputs();
    if(args.empty())
        return true;
    // Operator of the first input is the reference every other input must equal.
    const auto& expected = args.front()->get_operator();
    return std::all_of(args.begin(), args.end(), [&](const auto& arg) {
        return arg->get_operator() == expected;
    });
}
struct find_mul_conv
{
auto matcher() const
......@@ -182,70 +173,85 @@ struct find_inner_broadcast
}
};
struct find_concat_unary
struct find_concat_op
{
auto matcher() const
{
return match::name("concat")(args_has_same_ops(),
match::arg(0)(match::nargs(1),
match::name("relu", "broadcast").bind("x"),
match::used_once()));
return match::name("concat")(match::any_of[match::inputs()](
match::name("add", "multiply", "relu", "broadcast"), match::used_once()));
}
void apply(program& p, match::matcher_result r) const
template <class Iterator>
static std::vector<std::size_t> get_output_lens(Iterator start, Iterator last, std::size_t axis)
{
auto ins = r.result;
auto x = r.instructions["x"];
auto op = x->get_operator();
auto axis = any_cast<op::concat>(ins->get_operator()).axis;
// Adjust broadcast lens
if(op.name() == "broadcast")
assert(start != last);
std::size_t dim = 0;
for(auto ins : range(start, last))
{
auto b = any_cast<op::broadcast>(op);
if(b.axis != axis)
return;
b.broadcast_lens = ins->get_shape().lens();
op = b;
axis = 0;
dim += ins->get_shape().lens().at(axis);
}
auto inputs = ins->inputs();
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto i) {
return i->inputs().front();
});
auto concat = p.insert_instruction(ins, op::concat{axis}, inputs);
p.replace_instruction(ins, op, concat);
auto lens = (*start)->get_shape().lens();
lens[axis] = dim;
return lens;
}
};
struct find_concat_binary
{
auto matcher() const
void apply(program& p, const match::matcher_result& r) const
{
return match::name("concat")(args_has_same_ops(),
match::arg(0)(match::nargs(2),
match::name("add", "multiply").bind("x"),
match::used_once()));
}
auto ins = r.result;
auto axis = any_cast<op::concat>(ins->get_operator()).axis;
void apply(program& p, match::matcher_result r) const
{
auto ins = r.result;
auto x = r.instructions["x"];
auto op = x->get_operator();
auto concat_op = ins->get_operator();
auto xinputs = ins->inputs();
std::transform(xinputs.begin(), xinputs.end(), xinputs.begin(), [&](auto i) {
return i->inputs().front();
});
auto yinputs = ins->inputs();
std::transform(yinputs.begin(), yinputs.end(), yinputs.begin(), [&](auto i) {
return i->inputs().back();
});
auto xconcat = p.insert_instruction(ins, concat_op, xinputs);
auto yconcat = p.insert_instruction(ins, concat_op, yinputs);
p.replace_instruction(ins, op, xconcat, yconcat);
// Tries to fuse one run [start, last) of concat inputs into a single
// instruction. On success it returns a one-element vector holding the new
// instruction; otherwise it returns the run unchanged so the caller can keep
// the original inputs.
//
// Fusion rewrites  concat(op(a0, b0), op(a1, b1), ...)  into
// op(concat(a0, a1, ...), concat(b0, b1, ...)), inserting the new
// instructions before `ins`.
auto each = [&](auto start, auto last) -> std::vector<instruction_ref> {
// A run of fewer than two instructions has nothing to fuse.
if(std::distance(start, last) < 2)
return {start, last};
// Only the first instruction of the run is inspected below; the rest of the
// run is expected to look the same (see the grouping predicate).
auto x = *start;
// Accept only one or two inputs and at most one output (user).
if(x->inputs().size() > 2 or x->inputs().empty() or x->outputs().size() > 1)
return {start, last};
auto&& name = x->name();
// Fusion is restricted to this fixed list of operators.
if(not contains({"add", "multiply", "relu", "broadcast"}, name))
return {start, last};
auto op = x->get_operator();
// Axis used for the inner concats; starts as the outer concat axis.
auto iaxis = axis;
// Adjust broadcast lens: a broadcast is only fused when it broadcasts along
// the concat axis. The fused broadcast gets the output lens computed for this
// run, and its inputs are concatenated along axis 0 instead.
// NOTE(review): presumably the broadcast input is 1-D, which is why axis 0 is
// used for the inner concat -- confirm.
if(op.name() == "broadcast")
{
auto b = any_cast<op::broadcast>(op);
if(b.axis != iaxis)
return {start, last};
b.broadcast_lens = get_output_lens(start, last, iaxis);
op = b;
iaxis = 0;
}
// One inner concat per operand position of the operator.
std::vector<instruction_ref> concats;
for(std::size_t i = 0; i < x->inputs().size(); i++)
{
// Collect the i-th input of every instruction in the run, in run order.
std::vector<instruction_ref> inputs;
std::transform(start, last, std::back_inserter(inputs), [&](auto j) {
return j->inputs().at(i);
});
auto concat = p.insert_instruction(ins, op::concat{iaxis}, inputs);
concats.push_back(concat);
}
// Apply the shared operator once to the concatenated operands.
auto y = p.insert_instruction(ins, op, concats);
return {y};
};
std::vector<instruction_ref> args;
auto update_args = [&](auto start, auto last) {
auto x = each(start, last);
args.insert(args.end(), x.begin(), x.end());
};
// Grouping predicate: two concat inputs may share a run only when they carry
// the same operator AND have the same number of inputs and of outputs.
// The size checks must compare `i` against `j`; comparing `i` with itself is
// always true and would let instructions of different arity or with a
// different number of users end up in one run.
auto pred = [](auto i, auto j) {
    return i->get_operator() == j->get_operator() and
           i->inputs().size() == j->inputs().size() and
           i->outputs().size() == j->outputs().size();
};
group_unique(ins->inputs().begin(), ins->inputs().end(), update_args, pred);
if(args.size() == 1)
p.replace_instruction(ins, args.front());
else
p.replace_instruction(ins, op::concat{axis}, args);
}
};
......@@ -724,8 +730,7 @@ void simplify_algebra::apply(program& p) const
find_div_const{},
find_sub_const{},
find_rsqrt{},
find_concat_unary{},
find_concat_binary{},
find_concat_op{},
find_split_concat{},
find_splits{});
dead_code_elimination{}.apply(p);
......
......@@ -434,6 +434,76 @@ TEST_CASE(simplify_concat_add_relu)
EXPECT(p1 == p2);
}
// Partial fusion: the concat has three inputs, but only the last two
// (relu(add(x, one)) and relu(add(y, two))) share an operator. The expected
// program fuses just those two into relu(add(concat(x, y), concat(one, two)))
// and keeps the leading add(x, y) as a separate concat input.
TEST_CASE(simplify_concat_add_relu_partial)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {1}};
migraphx::program p1;
{
auto x = p1.add_parameter("x", s);
auto y = p1.add_parameter("y", s);
auto one = p1.add_literal({s, {1}});
auto two = p1.add_literal({s, {2}});
auto sum1 = p1.add_instruction(migraphx::op::add{}, x, one);
auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, y, two);
auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2);
auto sum3 = p1.add_instruction(migraphx::op::add{}, x, y);
auto concat = p1.add_instruction(migraphx::op::concat{0}, sum3, relu1, relu2);
p1.add_instruction(pass_op{}, concat);
}
run_pass(p1);
// Expected program after the pass.
migraphx::program p2;
{
auto x = p2.add_parameter("x", s);
auto y = p2.add_parameter("y", s);
auto one = p2.add_literal({s, {1}});
auto two = p2.add_literal({s, {2}});
auto concat1 = p2.add_instruction(migraphx::op::concat{0}, x, y);
auto concat2 = p2.add_instruction(migraphx::op::concat{0}, one, two);
auto sum1 = p2.add_instruction(migraphx::op::add{}, concat1, concat2);
auto relu = p2.add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = p2.add_instruction(migraphx::op::add{}, x, y);
auto concat = p2.add_instruction(migraphx::op::concat{0}, sum2, relu);
p2.add_instruction(pass_op{}, concat);
}
// Both programs are sorted before comparison.
EXPECT(p1.sort() == p2.sort());
}
// Partial fusion with broadcasts: the concat (axis 1) has an add followed by
// two broadcasts whose broadcast axis equals the concat axis. The expected
// program merges the two broadcasts into a single broadcast of
// concat{0}(one, two) with lens {2, 2, 4, 5}, and leaves add(x, y) as a
// separate concat input.
TEST_CASE(simplify_concat_add_relu_partial_broadcast)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
migraphx::program p1;
{
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
auto x = p1.add_parameter("x", s);
auto y = p1.add_parameter("y", s);
auto one = p1.add_literal(1);
auto oneb = p1.add_instruction(b, one);
auto two = p1.add_literal(2);
auto twob = p1.add_instruction(b, two);
auto sum = p1.add_instruction(migraphx::op::add{}, x, y);
auto concat = p1.add_instruction(migraphx::op::concat{1}, sum, oneb, twob);
p1.add_instruction(pass_op{}, concat);
}
run_pass(p1);
// Expected program after the pass.
migraphx::program p2;
{
// The fused broadcast covers both original outputs along axis 1 (1 + 1 = 2).
auto b = migraphx::op::broadcast{1, {2, 2, 4, 5}};
auto x = p2.add_parameter("x", s);
auto y = p2.add_parameter("y", s);
auto one = p2.add_literal(1);
auto two = p2.add_literal(2);
auto concat1 = p2.add_instruction(migraphx::op::concat{0}, one, two);
auto concatb = p2.add_instruction(b, concat1);
auto sum = p2.add_instruction(migraphx::op::add{}, x, y);
auto concat2 = p2.add_instruction(migraphx::op::concat{1}, sum, concatb);
p2.add_instruction(pass_op{}, concat2);
}
// Both programs are sorted before comparison.
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(simplify_concat_add_relu_broadcast_different_axis)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment