Unverified Commit e66968a2 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Optimize multiply across slices (#568)



* Add initial optimization when using a mul over a sliced convolution

* Formatting

* Add more tests

* Formatting

* Convert to an assert

* Check if used once

* Formatting

* Add test with horiz fusion

* Formatting

* Optimize nested slice

* Formatting

* Fix test

* Add const refs

* Remove unnecessary assert
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 70ba8213
...@@ -63,6 +63,90 @@ struct find_mul_conv ...@@ -63,6 +63,90 @@ struct find_mul_conv
} }
}; };
struct find_mul_slice_conv
{
static auto conv()
{
return match::name("convolution")(
match::all_of[match::outputs()](match::name("slice")),
match::args(match::any(), match::is_constant().bind("w")));
}
auto matcher() const
{
return match::name("mul")(match::either_arg(0, 1)(
match::name("slice")(match::used_once(), match::arg(0)(conv().bind("conv")))
.bind("slice"),
match::name("broadcast")(match::is_constant()).bind("a")));
}
void apply(program& p, match::matcher_result r) const
{
auto ins = r.result;
auto slice_ins = r.instructions["slice"];
auto conv_ins = r.instructions["conv"];
auto a_ins = r.instructions["a"];
auto w_ins = r.instructions["w"];
auto broadcast_op = any_cast<op::broadcast>(a_ins->get_operator());
if(broadcast_op.axis != 1)
return;
auto slice_op = any_cast<op::slice>(slice_ins->get_operator());
if(slice_op.axes.size() != 1)
return;
if(slice_op.axes.front() != 1)
return;
auto slice_idx = std::distance(conv_ins, slice_ins);
if(std::any_of(conv_ins->outputs().begin(), conv_ins->outputs().end(), [&](auto i) {
if(i == slice_ins)
return false;
if(std::distance(conv_ins, i) < slice_idx)
return true;
auto sop = any_cast<op::slice>(i->get_operator());
if(sop.axes != slice_op.axes)
return true;
if(std::max(sop.starts.front(), slice_op.starts.front()) <
std::min(sop.ends.front(), slice_op.ends.front()))
return true;
return false;
}))
return;
auto w_slice_op = slice_op;
w_slice_op.axes = {0};
auto slice_w_ins = p.insert_instruction(ins, w_slice_op, w_ins);
auto new_a = p.insert_instruction(
ins, op::broadcast{0, slice_w_ins->get_shape().lens()}, a_ins->inputs().front());
auto new_mul = p.insert_instruction(ins, op::mul{}, new_a, slice_w_ins);
std::vector<instruction_ref> sliced_weights;
if(slice_op.starts.front() != 0)
sliced_weights.push_back(
p.insert_instruction(ins, op::slice{{0}, {0}, slice_op.starts}, w_ins));
sliced_weights.push_back(new_mul);
int64_t end_axis = w_ins->get_shape().lens().at(0);
if(slice_op.ends.front() != end_axis)
sliced_weights.push_back(
p.insert_instruction(ins, op::slice{{0}, {slice_op.ends}, {end_axis}}, w_ins));
auto new_weights = p.insert_instruction(ins, op::concat{0}, sliced_weights);
auto new_conv = p.insert_instruction(
ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_weights);
assert(conv_ins->get_shape() == new_conv->get_shape());
auto slice1 = p.insert_instruction(ins, slice_op, new_conv);
assert(ins->get_shape().lens() == slice1->get_shape().lens());
p.replace_instruction(ins, slice1);
// TODO: Check each slice doesn't overlap and that it occurs after slice_ins
for(auto output : conv_ins->outputs())
if(output != slice_ins)
instruction::replace_argument(output, conv_ins, new_conv);
}
};
// a * (x + b) => a * x + a * b // a * (x + b) => a * x + a * b
struct find_mul_add struct find_mul_add
{ {
...@@ -180,7 +264,7 @@ struct find_concat_op ...@@ -180,7 +264,7 @@ struct find_concat_op
auto matcher() const auto matcher() const
{ {
return match::name("concat")(match::any_of[match::inputs()]( return match::name("concat")(match::any_of[match::inputs()](
match::name("add", "multiply", "relu", "broadcast"), match::used_once())); match::name("add", "mul", "relu", "broadcast"), match::used_once()));
} }
template <class Iterator> template <class Iterator>
...@@ -209,7 +293,7 @@ struct find_concat_op ...@@ -209,7 +293,7 @@ struct find_concat_op
if(x->inputs().size() > 2 or x->inputs().empty() or x->outputs().size() > 1) if(x->inputs().size() > 2 or x->inputs().empty() or x->outputs().size() > 1)
return {start, last}; return {start, last};
auto&& name = x->name(); auto&& name = x->name();
if(not contains({"add", "multiply", "relu", "broadcast"}, name)) if(not contains({"add", "mul", "relu", "broadcast"}, name))
return {start, last}; return {start, last};
auto op = x->get_operator(); auto op = x->get_operator();
auto iaxis = axis; auto iaxis = axis;
...@@ -865,6 +949,7 @@ void simplify_algebra::apply(program& p) const ...@@ -865,6 +949,7 @@ void simplify_algebra::apply(program& p) const
find_add_convs{}, find_add_convs{},
find_conv_dot_horiz_fusion{}, find_conv_dot_horiz_fusion{},
find_mul_conv{}, find_mul_conv{},
find_mul_slice_conv{},
find_mul_add{}, find_mul_add{},
find_div_const{}, find_div_const{},
find_sub_const{}, find_sub_const{},
......
...@@ -4,11 +4,14 @@ ...@@ -4,11 +4,14 @@
#include <migraphx/op/as_shape.hpp> #include <migraphx/op/as_shape.hpp>
#include <migraphx/op/transpose.hpp> #include <migraphx/op/transpose.hpp>
#include <migraphx/op/concat.hpp> #include <migraphx/op/concat.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/matcher.hpp> #include <migraphx/matcher.hpp>
#include <migraphx/permutation.hpp> #include <migraphx/permutation.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <unordered_set> #include <unordered_set>
#include <map>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -150,6 +153,75 @@ struct find_transpose ...@@ -150,6 +153,75 @@ struct find_transpose
} }
}; };
struct find_nested_slice
{
auto matcher() const { return match::name("slice")(match::arg(0)(match::name("slice"))); }
using axes_map = std::map<std::size_t, std::pair<std::size_t, std::size_t>>;
static axes_map get_axes(instruction_ref ins)
{
axes_map result;
auto op = any_cast<op::slice>(ins->get_operator());
for(std::size_t i = 0; i < op.axes.size(); i++)
{
result[op.axes[i]] = std::make_pair(op.starts[i], op.ends[i]);
}
return result;
}
static axes_map merge(const axes_map& m1, const axes_map& m2)
{
axes_map result;
// Non overlapping
for(auto&& p : m1)
{
if(contains(m2, p.first))
continue;
result[p.first] = p.second;
}
for(auto&& p : m2)
{
if(contains(m1, p.first))
continue;
result[p.first] = p.second;
}
// Overlapping
for(auto&& p1 : m1)
{
if(not contains(m2, p1.first))
continue;
auto&& v1 = p1.second;
auto&& v2 = m2.at(p1.first);
auto start = v1.first + v2.first;
auto end = start + (v2.second - v2.first);
result[p1.first] = std::make_pair(start, end);
}
return result;
}
void apply(program& p, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto slice = ins->inputs().front();
auto input = slice->inputs().front();
auto a1 = get_axes(ins);
auto a2 = get_axes(slice);
auto axes = merge(a2, a1);
auto op = op::slice{};
for(auto&& pp : axes)
{
op.axes.push_back(pp.first);
op.starts.push_back(pp.second.first);
op.ends.push_back(pp.second.second);
}
p.replace_instruction(ins, op, input);
}
};
struct find_concat_transpose struct find_concat_transpose
{ {
auto matcher() const auto matcher() const
...@@ -215,22 +287,14 @@ void simplify_reshapes::apply(program& p) const ...@@ -215,22 +287,14 @@ void simplify_reshapes::apply(program& p) const
{ {
for(int i = 0; i < 2; i++) for(int i = 0; i < 2; i++)
{ {
auto end = std::prev(p.end()); match::find_matches(p,
for(auto ins : iterator_for(p)) find_nop_reshapes{},
{ find_reshaper{},
if(ins == end and ins->name() == "contiguous") find_transpose{},
continue; find_concat_transpose{},
// Skip possible dead instructions find_nested_slice{},
if(ins->outputs().empty() and ins != end) find_nested_concat{});
continue; dead_code_elimination{}.apply(p);
match::find_matches(p,
ins,
find_nop_reshapes{},
find_reshaper{},
find_transpose{},
find_concat_transpose{},
find_nested_concat{});
}
} }
} }
......
...@@ -44,7 +44,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -44,7 +44,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
decompose{}, decompose{},
dead_code_elimination{}, dead_code_elimination{},
simplify_reshapes{}, simplify_reshapes{},
dead_code_elimination{},
eliminate_identity{}, eliminate_identity{},
eliminate_pad{}, eliminate_pad{},
dead_code_elimination{}, dead_code_elimination{},
...@@ -57,10 +56,10 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -57,10 +56,10 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
eliminate_common_subexpression{}, eliminate_common_subexpression{},
dead_code_elimination{}, dead_code_elimination{},
simplify_algebra{}, simplify_algebra{},
dead_code_elimination{}, simplify_reshapes{},
simplify_algebra{},
auto_contiguous{}, auto_contiguous{},
simplify_reshapes{}, simplify_reshapes{},
dead_code_elimination{},
propagate_constant{}, propagate_constant{},
dead_code_elimination{}, dead_code_elimination{},
remap{}, remap{},
......
...@@ -207,6 +207,88 @@ TEST_CASE(simplify_mul_conv1) ...@@ -207,6 +207,88 @@ TEST_CASE(simplify_mul_conv1)
EXPECT(new_conv->outputs().front()->name() != "mul"); EXPECT(new_conv->outputs().front()->name() != "mul");
} }
TEST_CASE(simplify_mul_slice_conv1)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
auto w = p1.add_literal(
migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}}));
auto conv = p1.add_instruction(migraphx::op::convolution{}, x, w);
auto slice1 = p1.add_instruction(migraphx::op::slice{{1}, {0}, {384}}, conv);
auto a = p1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
auto b = p1.add_instruction(migraphx::op::broadcast{1, {1, 384, 17, 17}}, a);
auto mul = p1.add_instruction(migraphx::op::mul{}, slice1, b);
auto slice2 = p1.add_instruction(migraphx::op::slice{{1}, {384}, {768}}, conv);
auto add = p1.add_instruction(migraphx::op::add{}, mul, slice2);
p1.add_instruction(pass_op{}, add);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
auto w = p2.add_literal(
migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}}));
auto wslice1 = p2.add_instruction(migraphx::op::slice{{0}, {0}, {384}}, w);
auto a = p2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
auto b = p2.add_instruction(migraphx::op::broadcast{0, {384, 1024, 1, 1}}, a);
auto mul = p2.add_instruction(migraphx::op::mul{}, b, wslice1);
auto wslice2 = p2.add_instruction(migraphx::op::slice{{0}, {384}, {768}}, w);
auto concat = p2.add_instruction(migraphx::op::concat{0}, mul, wslice2);
auto conv = p2.add_instruction(migraphx::op::convolution{}, x, concat);
auto slice1 = p2.add_instruction(migraphx::op::slice{{1}, {0}, {384}}, conv);
auto slice2 = p2.add_instruction(migraphx::op::slice{{1}, {384}, {768}}, conv);
auto add = p2.add_instruction(migraphx::op::add{}, slice1, slice2);
p2.add_instruction(pass_op{}, add);
}
EXPECT(p1 == p2);
}
TEST_CASE(simplify_mul_slice_conv_overlapping_slice)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
auto w = p1.add_literal(
migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}}));
auto conv = p1.add_instruction(migraphx::op::convolution{}, x, w);
auto slice1 = p1.add_instruction(migraphx::op::slice{{1}, {0}, {384}}, conv);
auto a = p1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
auto b = p1.add_instruction(migraphx::op::broadcast{1, {1, 384, 17, 17}}, a);
auto mul = p1.add_instruction(migraphx::op::mul{}, slice1, b);
auto slice2 = p1.add_instruction(migraphx::op::slice{{1}, {383}, {767}}, conv);
auto add = p1.add_instruction(migraphx::op::add{}, mul, slice2);
p1.add_instruction(pass_op{}, add);
}
migraphx::program p2 = p1;
run_pass(p1);
EXPECT(p1 == p2);
}
TEST_CASE(simplify_mul_slice_conv_not_all_slice)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
auto w = p1.add_literal(
migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}}));
auto conv = p1.add_instruction(migraphx::op::convolution{}, x, w);
auto slice1 = p1.add_instruction(migraphx::op::slice{{1}, {0}, {384}}, conv);
auto a = p1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
auto b = p1.add_instruction(migraphx::op::broadcast{1, {1, 384, 17, 17}}, a);
auto mul = p1.add_instruction(migraphx::op::mul{}, slice1, b);
auto c = p1.add_literal(
migraphx::generate_literal({migraphx::shape::int32_type, {1, 768, 17, 17}}));
auto add = p1.add_instruction(migraphx::op::add{}, conv, c);
auto concat = p1.add_instruction(migraphx::op::concat{1}, mul, add);
p1.add_instruction(pass_op{}, concat);
}
migraphx::program p2 = p1;
run_pass(p1);
EXPECT(p1 == p2);
}
TEST_CASE(simplify_mul_add) TEST_CASE(simplify_mul_add)
{ {
migraphx::program p1; migraphx::program p1;
...@@ -1383,6 +1465,58 @@ TEST_CASE(simplify_conv_horiz_grouped_extra2) ...@@ -1383,6 +1465,58 @@ TEST_CASE(simplify_conv_horiz_grouped_extra2)
EXPECT(p1.sort() == p2.sort()); EXPECT(p1.sort() == p2.sort());
} }
TEST_CASE(simplify_mul_slice_conv_horiz_fusion)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
auto w = p1.add_literal(
migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}}));
auto conv = p1.add_instruction(migraphx::op::convolution{}, x, w);
auto slice1 = p1.add_instruction(migraphx::op::slice{{1}, {0}, {384}}, conv);
auto a1 =
p1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 1));
auto b1 = p1.add_instruction(migraphx::op::broadcast{1, {1, 384, 17, 17}}, a1);
auto mul = p1.add_instruction(migraphx::op::mul{}, slice1, b1);
auto a2 =
p1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 2));
auto b2 = p1.add_instruction(migraphx::op::broadcast{1, {1, 384, 17, 17}}, a2);
auto add1 = p1.add_instruction(migraphx::op::add{}, mul, b2);
auto a3 =
p1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 3));
auto b3 = p1.add_instruction(migraphx::op::broadcast{1, {1, 384, 17, 17}}, a3);
auto slice2 = p1.add_instruction(migraphx::op::slice{{1}, {384}, {768}}, conv);
auto add2 = p1.add_instruction(migraphx::op::add{}, slice2, b3);
p1.add_instruction(pass_op{}, add1, add2);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
auto w = p2.add_literal(
migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}}));
auto wslice1 = p2.add_instruction(migraphx::op::slice{{0}, {0}, {384}}, w);
auto a1 =
p2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 1));
auto b1 = p2.add_instruction(migraphx::op::broadcast{0, {384, 1024, 1, 1}}, a1);
auto mul = p2.add_instruction(migraphx::op::mul{}, b1, wslice1);
auto wslice2 = p2.add_instruction(migraphx::op::slice{{0}, {384}, {768}}, w);
auto concat1 = p2.add_instruction(migraphx::op::concat{0}, mul, wslice2);
auto conv = p2.add_instruction(migraphx::op::convolution{}, x, concat1);
auto a2 =
p2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 2));
auto a3 =
p2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 3));
auto concat2 = p2.add_instruction(migraphx::op::concat{}, a2, a3);
auto b4 = p2.add_instruction(migraphx::op::broadcast{1, {1, 768, 17, 17}}, concat2);
auto add = p2.add_instruction(migraphx::op::add{}, conv, b4);
auto slice1 = p2.add_instruction(migraphx::op::slice{{1}, {0}, {384}}, add);
auto slice2 = p2.add_instruction(migraphx::op::slice{{1}, {384}, {768}}, add);
p2.add_instruction(pass_op{}, slice1, slice2);
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(reorder_reshape_slice) TEST_CASE(reorder_reshape_slice)
{ {
std::vector<int64_t> perm0 = {0, 2, 1, 3}; std::vector<int64_t> perm0 = {0, 2, 1, 3};
......
...@@ -376,4 +376,64 @@ TEST_CASE(multibroadcast_simplify) ...@@ -376,4 +376,64 @@ TEST_CASE(multibroadcast_simplify)
EXPECT(std::distance(p.begin(), p.end()) == n - 1); EXPECT(std::distance(p.begin(), p.end()) == n - 1);
} }
TEST_CASE(double_slice1)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {256}});
auto slice1 = p1.add_instruction(migraphx::op::slice{{0}, {32}, {256}}, x);
auto slice2 = p1.add_instruction(migraphx::op::slice{{0}, {32}, {64}}, slice1);
p1.add_instruction(pass_op{}, slice2);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {256}});
auto slice = p2.add_instruction(migraphx::op::slice{{0}, {64}, {96}}, x);
p2.add_instruction(pass_op{}, slice);
}
EXPECT(p1 == p2);
}
TEST_CASE(double_slice2)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {256}});
auto slice1 = p1.add_instruction(migraphx::op::slice{{0}, {32}, {128}}, x);
auto slice2 = p1.add_instruction(migraphx::op::slice{{0}, {0}, {32}}, slice1);
p1.add_instruction(pass_op{}, slice2);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {256}});
auto slice = p2.add_instruction(migraphx::op::slice{{0}, {32}, {64}}, x);
p2.add_instruction(pass_op{}, slice);
}
EXPECT(p1 == p2);
}
TEST_CASE(double_slice_multi_axes)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {256, 128}});
auto slice1 = p1.add_instruction(migraphx::op::slice{{0}, {32}, {128}}, x);
auto slice2 = p1.add_instruction(migraphx::op::slice{{1}, {0}, {32}}, slice1);
p1.add_instruction(pass_op{}, slice2);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {256, 128}});
auto slice = p2.add_instruction(migraphx::op::slice{{0, 1}, {32, 0}, {128, 32}}, x);
p2.add_instruction(pass_op{}, slice);
}
EXPECT(p1 == p2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment