Commit 65ea9194 authored by Paul Fultz II's avatar Paul Fultz II Committed by mvermeulen
Browse files

Fuse the add of two convolutions (#386)

* Fuse convolution adds

* Formatting

* Fuse more 1x1 convs

* Add some tests

* Formatting

* Add test for 1x1

* Add verification for add-conv fusions

* Fix stride calculation

* Formatting

* Add more tests

* Rename tests
parent 5f2767aa
......@@ -29,7 +29,7 @@ struct as_shape
// Validates the single standard-layout input and returns the target shape `s`.
// Note: this span contained diff residue — both the old `==` assert and the
// new relaxed `>=` assert; only the relaxed form belongs in the file. The
// input may have MORE elements than `s` so that a strided (subsampled) view
// of a larger buffer can be reinterpreted as the smaller shape.
shape compute_shape(const std::vector<shape>& inputs) const
{
    check_shapes{inputs, *this}.has(1).standard();
    assert(inputs.front().elements() >= s.elements());
    return s;
}
argument compute(shape output_shape, std::vector<argument> args) const
......
......@@ -3,6 +3,9 @@
#include <migraphx/program.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/as_shape.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
......@@ -164,6 +167,109 @@ struct find_inner_broadcast
}
};
// True when x and y have the same rank, index `axis` is in range, and every
// dimension other than `axis` matches exactly.
bool axis_equal(const std::vector<std::size_t>& x,
                const std::vector<std::size_t>& y,
                std::size_t axis)
{
    if(x.size() != y.size() or axis >= x.size())
        return false;
    for(std::size_t i = 0; i < x.size(); i++)
    {
        if(i == axis)
            continue;
        if(x[i] != y[i])
            return false;
    }
    return true;
}
// Shapes agree on every dimension except `axis` (by lengths only).
// TODO: Check strides
bool axis_shape_equal(const shape& x, const shape& y, std::size_t axis)
{
    const auto& xlens = x.lens();
    const auto& ylens = y.lens();
    return axis_equal(xlens, ylens, axis);
}
// Fuses add(conv(x, w), conv(y, v)) — where both convolutions have constant
// weights — into a single convolution over channel-concatenated inputs and
// weights. This is valid because convolution sums over input channels, so
// conv(concat(x, y), concat(w, v)) along the channel axis equals the sum of
// the two original convolutions.
struct find_add_convs
{
    auto matcher() const
    {
        // Match an add whose two arguments are convolutions with constant
        // weights (conv_const_weights is defined elsewhere in this file).
        return match::name("add")(
            match::args(conv_const_weights().bind("a"), conv_const_weights().bind("b")));
    }
    // True when both spatial strides are equal (square stride).
    static bool symmetrical_strides(const op::convolution& op)
    {
        return op.stride[0] == op.stride[1];
    }
    // Ratio x.stride / y.stride when both strides are symmetrical and the
    // ratio is integral; 0 signals that no valid factor exists (caller bails).
    static std::size_t compute_stride_factor(const op::convolution& x, const op::convolution& y)
    {
        if(not symmetrical_strides(x))
            return 0;
        if(not symmetrical_strides(y))
            return 0;
        if((x.stride[0] % y.stride[0]) != 0)
            return 0;
        return x.stride[0] / y.stride[0];
    }
    // Builds a strided view of `input` subsampled spatially by `n`: the two
    // spatial lengths are divided by n while their strides are multiplied by
    // n, so no data is copied (used together with op::as_shape below).
    static shape compute_stride_shape(const shape& input, std::size_t n)
    {
        return {input.type(),
                {input.lens()[0], input.lens()[1], input.lens()[2] / n, input.lens()[3] / n},
                {input.strides()[0],
                 input.strides()[1],
                 input.strides()[2] * n,
                 input.strides()[3] * n}};
    }
    void apply(program& p, match::matcher_result r) const
    {
        auto ins       = r.result; // the matched add instruction
        auto a_conv    = r.instructions["a"];
        auto a_input   = a_conv->inputs().at(0);
        auto a_weights = a_conv->inputs().at(1);
        auto b_conv    = r.instructions["b"];
        auto b_input   = b_conv->inputs().at(0);
        auto b_weights = b_conv->inputs().at(1);
        // Weights must match on every dimension except axis 1 (the input-
        // channel axis), since that is the axis they get concatenated on.
        if(not axis_shape_equal(a_weights->get_shape(), b_weights->get_shape(), 1))
            return;
        auto a_op  = any_cast<op::convolution>(a_conv->get_operator());
        auto b_op  = any_cast<op::convolution>(b_conv->get_operator());
        auto new_op = a_op;
        if(a_op != b_op)
        {
            // Differing ops are only fusible when padding/dilation/group
            // agree and the kernel is 1x1: then the conv with the larger
            // stride can be replaced by the smaller-stride conv applied to a
            // spatially subsampled (strided) view of its input.
            if(std::tie(a_op.padding, a_op.dilation, a_op.group) ==
                   std::tie(b_op.padding, b_op.dilation, b_op.group) and
               a_weights->get_shape().lens()[2] == 1 and a_weights->get_shape().lens()[3] == 1)
            {
                if(a_op.stride < b_op.stride)
                {
                    // b has the larger stride: view b's input subsampled by
                    // n = b.stride / a.stride, keep a's op. n == 0 means the
                    // ratio is non-integral or strides are asymmetrical.
                    auto n = compute_stride_factor(b_op, a_op);
                    if(n == 0)
                        return;
                    new_op  = a_op; // NOTE(review): redundant — new_op already equals a_op
                    b_input = p.insert_instruction(
                        ins, op::as_shape{compute_stride_shape(b_input->get_shape(), n)}, b_input);
                }
                else if(b_op.stride < a_op.stride)
                {
                    // Mirror case: subsample a's input, keep b's op.
                    auto n = compute_stride_factor(a_op, b_op);
                    if(n == 0)
                        return;
                    new_op  = b_op;
                    a_input = p.insert_instruction(
                        ins, op::as_shape{compute_stride_shape(a_input->get_shape(), n)}, a_input);
                }
                else
                    return;
            }
            else
                return;
        }
        // Concatenate inputs and weights along the channel axis (1) and
        // replace the add with one convolution.
        auto concat_input   = p.insert_instruction(ins, op::concat{1}, a_input, b_input);
        auto concat_weights = p.insert_instruction(ins, op::concat{1}, a_weights, b_weights);
        p.replace_instruction(ins, new_op, concat_input, concat_weights);
    }
};
void simplify_algebra::apply(program& p) const
{
// Run simplifications multiple times
......@@ -173,6 +279,7 @@ void simplify_algebra::apply(program& p) const
find_inner_broadcast{},
find_double_add_lit_broadcast{},
find_add_lit_broadcast{},
find_add_convs{},
find_mul_conv{},
find_mul_add{});
dead_code_elimination{}.apply(p);
......
......@@ -843,6 +843,44 @@ struct test_conv_bias_clipped_relu : verify_program<test_conv_bias_clipped_relu>
}
};
// Verifies the sum of two same-shaped convolutions (3x3 kernels), followed
// by exp so the fused result is consumed.
struct test_conv_add : verify_program<test_conv_add>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        migraphx::shape input_shape{migraphx::shape::float_type, {1, 8, 4, 4}};
        migraphx::shape weight_shape{migraphx::shape::float_type, {2, 8, 3, 3}};
        auto in1  = p.add_parameter("x", input_shape);
        auto wt1  = p.add_literal(migraphx::generate_literal(weight_shape));
        auto in2  = p.add_parameter("y", input_shape);
        auto wt2  = p.add_literal(migraphx::generate_literal(weight_shape));
        auto c1   = p.add_instruction(migraphx::op::convolution{}, in1, wt1);
        auto c2   = p.add_instruction(migraphx::op::convolution{}, in2, wt2);
        auto total = p.add_instruction(migraphx::op::add{}, c1, c2);
        p.add_instruction(migraphx::op::exp{}, total);
        return p;
    }
};
// Verifies the sum of two 1x1 convolutions with different strides: the
// second conv uses stride {2, 2} on a 4x4 input so both outputs are 2x2.
struct test_conv_add_1x1_diff_strides : verify_program<test_conv_add_1x1_diff_strides>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        migraphx::shape weight_shape{migraphx::shape::float_type, {2, 8, 1, 1}};
        auto in1  = p.add_parameter("x", {migraphx::shape::float_type, {1, 8, 2, 2}});
        auto wt1  = p.add_literal(migraphx::generate_literal(weight_shape));
        auto in2  = p.add_parameter("y", {migraphx::shape::float_type, {1, 8, 4, 4}});
        auto wt2  = p.add_literal(migraphx::generate_literal(weight_shape));
        auto c1   = p.add_instruction(migraphx::op::convolution{}, in1, wt1);
        auto c2   = p.add_instruction(migraphx::op::convolution{{0, 0}, {2, 2}}, in2, wt2);
        auto total = p.add_instruction(migraphx::op::add{}, c1, c2);
        p.add_instruction(migraphx::op::exp{}, total);
        return p;
    }
};
struct test_add_relu : verify_program<test_add_relu>
{
migraphx::program create_program() const
......
......@@ -262,4 +262,127 @@ TEST_CASE(simplify_inner_broadcast)
EXPECT(p1 == p2);
}
// Two identical 3x3 convolutions feeding an add should fuse into ONE
// convolution, with the output shape preserved.
TEST_CASE(simplify_add_conv1)
{
    migraphx::program p;
    migraphx::shape input_shape{migraphx::shape::float_type, {1, 128, 28, 28}};
    migraphx::shape weight_shape{migraphx::shape::float_type, {256, 128, 3, 3}};
    auto in1  = p.add_parameter("x", input_shape);
    auto wt1  = p.add_literal(migraphx::generate_literal(weight_shape));
    auto in2  = p.add_parameter("y", input_shape);
    auto wt2  = p.add_literal(migraphx::generate_literal(weight_shape));
    auto c1   = p.add_instruction(migraphx::op::convolution{}, in1, wt1);
    auto c2   = p.add_instruction(migraphx::op::convolution{}, in2, wt2);
    auto total = p.add_instruction(migraphx::op::add{}, c1, c2);
    p.add_instruction(pass_op{}, total);
    auto before = p.get_shape();
    p.compile(simplify_algebra_target{});
    EXPECT(before == p.get_shape());
    auto is_conv = [](auto&& ins) { return ins.name() == "convolution"; };
    EXPECT(std::count_if(p.begin(), p.end(), is_conv) == 1);
}
// Different strides with a 7x7 (non-1x1) kernel cannot be fused: both
// convolutions must remain after simplification.
TEST_CASE(simplify_add_conv_no_fusion_7x7_diff_strides)
{
    migraphx::program p;
    migraphx::shape weight_shape{migraphx::shape::float_type, {256, 128, 7, 7}};
    auto in1  = p.add_parameter("x", {migraphx::shape::float_type, {1, 128, 14, 14}});
    auto wt1  = p.add_literal(migraphx::generate_literal(weight_shape));
    auto in2  = p.add_parameter("y", {migraphx::shape::float_type, {1, 128, 28, 28}});
    auto wt2  = p.add_literal(migraphx::generate_literal(weight_shape));
    auto c1   = p.add_instruction(migraphx::op::convolution{}, in1, wt1);
    auto c2   = p.add_instruction(migraphx::op::convolution{{0, 0}, {3, 3}}, in2, wt2);
    auto total = p.add_instruction(migraphx::op::add{}, c1, c2);
    p.add_instruction(pass_op{}, total);
    auto before = p.get_shape();
    p.compile(simplify_algebra_target{});
    EXPECT(before == p.get_shape());
    // No fusion
    auto is_conv = [](auto&& ins) { return ins.name() == "convolution"; };
    EXPECT(std::count_if(p.begin(), p.end(), is_conv) == 2);
}
// 1x1 kernels with differing strides (stride on the second conv) CAN fuse
// into a single convolution.
TEST_CASE(simplify_add_conv_1x1_diff_strides1)
{
    migraphx::program p;
    migraphx::shape weight_shape{migraphx::shape::float_type, {256, 128, 1, 1}};
    auto in1  = p.add_parameter("x", {migraphx::shape::float_type, {1, 128, 14, 14}});
    auto wt1  = p.add_literal(migraphx::generate_literal(weight_shape));
    auto in2  = p.add_parameter("y", {migraphx::shape::float_type, {1, 128, 28, 28}});
    auto wt2  = p.add_literal(migraphx::generate_literal(weight_shape));
    auto c1   = p.add_instruction(migraphx::op::convolution{}, in1, wt1);
    auto c2   = p.add_instruction(migraphx::op::convolution{{0, 0}, {2, 2}}, in2, wt2);
    auto total = p.add_instruction(migraphx::op::add{}, c1, c2);
    p.add_instruction(pass_op{}, total);
    auto before = p.get_shape();
    p.compile(simplify_algebra_target{});
    EXPECT(before == p.get_shape());
    auto is_conv = [](auto&& ins) { return ins.name() == "convolution"; };
    EXPECT(std::count_if(p.begin(), p.end(), is_conv) == 1);
}
// Mirror of the previous case: the stride is on the FIRST conv; fusion
// should still produce a single convolution.
TEST_CASE(simplify_add_conv_1x1_diff_strides2)
{
    migraphx::program p;
    migraphx::shape weight_shape{migraphx::shape::float_type, {256, 128, 1, 1}};
    auto in1  = p.add_parameter("x", {migraphx::shape::float_type, {1, 128, 28, 28}});
    auto wt1  = p.add_literal(migraphx::generate_literal(weight_shape));
    auto in2  = p.add_parameter("y", {migraphx::shape::float_type, {1, 128, 14, 14}});
    auto wt2  = p.add_literal(migraphx::generate_literal(weight_shape));
    auto c1   = p.add_instruction(migraphx::op::convolution{{0, 0}, {2, 2}}, in1, wt1);
    auto c2   = p.add_instruction(migraphx::op::convolution{}, in2, wt2);
    auto total = p.add_instruction(migraphx::op::add{}, c1, c2);
    p.add_instruction(pass_op{}, total);
    auto before = p.get_shape();
    p.compile(simplify_algebra_target{});
    EXPECT(before == p.get_shape());
    auto is_conv = [](auto&& ins) { return ins.name() == "convolution"; };
    EXPECT(std::count_if(p.begin(), p.end(), is_conv) == 1);
}
// Asymmetrical stride {2, 1} on the first conv blocks fusion: both
// convolutions must survive. (Name keeps the original "asymetrical"
// spelling to preserve the registered test identifier.)
TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides1)
{
    migraphx::program p;
    migraphx::shape weight_shape{migraphx::shape::float_type, {256, 128, 1, 1}};
    auto in1  = p.add_parameter("x", {migraphx::shape::float_type, {1, 128, 28, 14}});
    auto wt1  = p.add_literal(migraphx::generate_literal(weight_shape));
    auto in2  = p.add_parameter("y", {migraphx::shape::float_type, {1, 128, 14, 14}});
    auto wt2  = p.add_literal(migraphx::generate_literal(weight_shape));
    auto c1   = p.add_instruction(migraphx::op::convolution{{0, 0}, {2, 1}}, in1, wt1);
    auto c2   = p.add_instruction(migraphx::op::convolution{}, in2, wt2);
    auto total = p.add_instruction(migraphx::op::add{}, c1, c2);
    p.add_instruction(pass_op{}, total);
    auto before = p.get_shape();
    p.compile(simplify_algebra_target{});
    EXPECT(before == p.get_shape());
    // No fusion
    auto is_conv = [](auto&& ins) { return ins.name() == "convolution"; };
    EXPECT(std::count_if(p.begin(), p.end(), is_conv) == 2);
}
// Mirror case: asymmetrical stride {2, 1} on the SECOND conv also blocks
// fusion. (Name keeps the original "asymetrical" spelling to preserve the
// registered test identifier.)
TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides2)
{
    migraphx::program p;
    migraphx::shape weight_shape{migraphx::shape::float_type, {256, 128, 1, 1}};
    auto in1  = p.add_parameter("x", {migraphx::shape::float_type, {1, 128, 14, 14}});
    auto wt1  = p.add_literal(migraphx::generate_literal(weight_shape));
    auto in2  = p.add_parameter("y", {migraphx::shape::float_type, {1, 128, 28, 14}});
    auto wt2  = p.add_literal(migraphx::generate_literal(weight_shape));
    auto c1   = p.add_instruction(migraphx::op::convolution{}, in1, wt1);
    auto c2   = p.add_instruction(migraphx::op::convolution{{0, 0}, {2, 1}}, in2, wt2);
    auto total = p.add_instruction(migraphx::op::add{}, c1, c2);
    p.add_instruction(pass_op{}, total);
    auto before = p.get_shape();
    p.compile(simplify_algebra_target{});
    EXPECT(before == p.get_shape());
    // No fusion
    auto is_conv = [](auto&& ins) { return ins.name() == "convolution"; };
    EXPECT(std::count_if(p.begin(), p.end(), is_conv) == 2);
}
// Entry point: runs every registered TEST_CASE via the test framework.
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment