Unverified Commit 9a70050b authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Multibroadcast find_mul_conv (#1384)

Change find_mul_conv to also work with multibroadcast. It now checks the strides instead of the broadcast axis.
parent 97a1ed2d
...@@ -57,12 +57,14 @@ auto conv_const_weights() ...@@ -57,12 +57,14 @@ auto conv_const_weights()
auto reduction() { return match::name_contains("reduce"); } auto reduction() { return match::name_contains("reduce"); }
// conv(x, w) * a => conv(x, a * w)
struct find_mul_conv struct find_mul_conv
{ {
auto matcher() const auto matcher() const
{ {
return match::name("mul")(match::either_arg(0, 1)(conv_const_weights().bind("conv"), return match::name("mul")(
match::name("broadcast").bind("a"))); match::either_arg(0, 1)(conv_const_weights().bind("conv"),
match::name("broadcast", "multibroadcast").bind("a")));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
...@@ -72,14 +74,35 @@ struct find_mul_conv ...@@ -72,14 +74,35 @@ struct find_mul_conv
auto a_ins = r.instructions["a"]; auto a_ins = r.instructions["a"];
auto w_ins = r.instructions["w"]; auto w_ins = r.instructions["w"];
auto broadcast_op = any_cast<op::broadcast>(a_ins->get_operator()); const auto& a_input_lens = a_ins->inputs().front()->get_shape().lens();
if(broadcast_op.axis != 1)
std::size_t num_not_one_dims = std::count_if(
a_input_lens.cbegin(), a_input_lens.cend(), [](auto dim) { return dim != 1; });
if(num_not_one_dims > 1)
return;
// check broadcasted along channels
const auto& a_lens = a_ins->get_shape().lens();
const auto& a_strides = a_ins->get_shape().strides();
auto is_broadcasted_axis = [](auto len, auto stride) { return len == 1 or stride == 0; };
if(a_strides.at(1) != 1)
return; return;
if(not is_broadcasted_axis(a_lens.front(), a_strides.front()))
return;
if(not std::equal(a_lens.begin() + 2,
a_lens.end(),
a_strides.begin() + 2,
a_strides.end(),
is_broadcasted_axis))
return;
auto sq = m.insert_instruction(ins, make_op("squeeze"), a_ins->inputs().front());
auto new_a = m.insert_instruction( auto new_a = m.insert_instruction(
ins, ins, make_op("broadcast", {{"axis", 0}, {"out_lens", w_ins->get_shape().lens()}}), sq);
make_op("broadcast", {{"axis", 0}, {"out_lens", w_ins->get_shape().lens()}}),
a_ins->inputs().front());
auto new_mul = m.insert_instruction(ins, make_op("mul"), new_a, w_ins); auto new_mul = m.insert_instruction(ins, make_op("mul"), new_a, w_ins);
auto new_conv = m.insert_instruction( auto new_conv = m.insert_instruction(
ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_mul); ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_mul);
......
...@@ -236,6 +236,105 @@ TEST_CASE(simplify_mul_conv1) ...@@ -236,6 +236,105 @@ TEST_CASE(simplify_mul_conv1)
EXPECT(new_conv->outputs().front()->name() != "mul"); EXPECT(new_conv->outputs().front()->name() != "mul");
} }
// The multiplier is a {256} literal that is unsqueezed on axes {1, 2} and then
// multibroadcast to {1, 256, 14, 14}. After the pass under test runs, the first
// consumer of the convolution is expected to no longer be a "mul".
TEST_CASE(simplify_mul_conv2)
{
    const migraphx::shape input_shape{migraphx::shape::int32_type, {1, 128, 28, 28}};
    const migraphx::shape weight_shape{migraphx::shape::int32_type, {256, 128, 3, 3}};
    const migraphx::shape scale_shape{migraphx::shape::int32_type, {256}};

    migraphx::module m;
    auto input   = m.add_parameter("x", input_shape);
    auto weights = m.add_literal(migraphx::generate_literal(weight_shape));

    const auto conv_op = migraphx::make_op(
        "convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}});
    auto conv = m.add_instruction(conv_op, input, weights);

    // Build the per-channel multiplier: literal -> unsqueeze -> multibroadcast.
    auto scale = m.add_literal(migraphx::generate_literal(scale_shape));
    auto scale_unsqueezed =
        m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale);
    auto scale_bcast = m.add_instruction(
        migraphx::make_op("multibroadcast", {{"out_lens", {1, 256, 14, 14}}}), scale_unsqueezed);

    auto product = m.add_instruction(migraphx::make_op("mul"), conv, scale_bcast);
    m.add_instruction(pass_op{}, product);

    // Sanity check: before the pass the convolution feeds the mul.
    EXPECT(conv->outputs().front()->name() == "mul");

    run_pass(m);

    const auto is_convolution = [](auto&& ins) { return ins.name() == "convolution"; };
    auto conv_after           = std::find_if(m.begin(), m.end(), is_convolution);
    EXPECT(conv_after->outputs().front()->name() != "mul");
}
// len = 1 case: the multiplier literal has lens {256, 1, 1} (strides {1, 18, 1})
// and is multibroadcast directly to {1, 256, 14, 14}. After the pass under test
// runs, the first consumer of the convolution is expected to no longer be a "mul".
TEST_CASE(simplify_mul_conv3)
{
    const migraphx::shape input_shape{migraphx::shape::int32_type, {1, 128, 28, 28}};
    const migraphx::shape weight_shape{migraphx::shape::int32_type, {256, 128, 3, 3}};
    const migraphx::shape scale_shape{migraphx::shape::int32_type, {256, 1, 1}, {1, 18, 1}};

    migraphx::module m;
    auto input   = m.add_parameter("x", input_shape);
    auto weights = m.add_literal(migraphx::generate_literal(weight_shape));

    const auto conv_op = migraphx::make_op(
        "convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}});
    auto conv = m.add_instruction(conv_op, input, weights);

    auto scale       = m.add_literal(migraphx::generate_literal(scale_shape));
    auto scale_bcast = m.add_instruction(
        migraphx::make_op("multibroadcast", {{"out_lens", {1, 256, 14, 14}}}), scale);

    auto product = m.add_instruction(migraphx::make_op("mul"), conv, scale_bcast);
    m.add_instruction(pass_op{}, product);

    // Sanity check: before the pass the convolution feeds the mul.
    EXPECT(conv->outputs().front()->name() == "mul");

    run_pass(m);

    const auto is_convolution = [](auto&& ins) { return ins.name() == "convolution"; };
    auto conv_after           = std::find_if(m.begin(), m.end(), is_convolution);
    EXPECT(conv_after->outputs().front()->name() != "mul");
}
// Previously broadcasted literal case, should skip: the multiplier literal has
// lens {256, 14, 14} with strides {1, 0, 0} and goes through a "broadcast" on
// axis 1. The convolution is expected to still feed the "mul" after the pass.
TEST_CASE(simplify_mul_conv_skip1)
{
    const migraphx::shape input_shape{migraphx::shape::int32_type, {1, 128, 28, 28}};
    const migraphx::shape weight_shape{migraphx::shape::int32_type, {256, 128, 3, 3}};
    const migraphx::shape scale_shape{migraphx::shape::int32_type, {256, 14, 14}, {1, 0, 0}};

    migraphx::module m;
    auto input   = m.add_parameter("x", input_shape);
    auto weights = m.add_literal(migraphx::generate_literal(weight_shape));

    const auto conv_op = migraphx::make_op(
        "convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}});
    auto conv = m.add_instruction(conv_op, input, weights);

    auto scale       = m.add_literal(migraphx::generate_literal(scale_shape));
    auto scale_bcast = m.add_instruction(
        migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 256, 14, 14}}}), scale);

    auto product = m.add_instruction(migraphx::make_op("mul"), conv, scale_bcast);
    m.add_instruction(pass_op{}, product);

    // Sanity check: before the pass the convolution feeds the mul.
    EXPECT(conv->outputs().front()->name() == "mul");

    run_pass(m);

    // The rewrite must not fire, so the mul is still the convolution's consumer.
    const auto is_convolution = [](auto&& ins) { return ins.name() == "convolution"; };
    auto conv_after           = std::find_if(m.begin(), m.end(), is_convolution);
    EXPECT(conv_after->outputs().front()->name() == "mul");
}
// Another previously broadcasted literal case, should skip: the multiplier
// literal has lens {256, 14, 14} with strides {1, 0, 0} and goes through a
// "multibroadcast". The convolution is expected to still feed the "mul" after
// the pass.
TEST_CASE(simplify_mul_conv_skip2)
{
    const migraphx::shape input_shape{migraphx::shape::int32_type, {1, 128, 28, 28}};
    const migraphx::shape weight_shape{migraphx::shape::int32_type, {256, 128, 3, 3}};
    const migraphx::shape scale_shape{migraphx::shape::int32_type, {256, 14, 14}, {1, 0, 0}};

    migraphx::module m;
    auto input   = m.add_parameter("x", input_shape);
    auto weights = m.add_literal(migraphx::generate_literal(weight_shape));

    const auto conv_op = migraphx::make_op(
        "convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}});
    auto conv = m.add_instruction(conv_op, input, weights);

    auto scale       = m.add_literal(migraphx::generate_literal(scale_shape));
    auto scale_bcast = m.add_instruction(
        migraphx::make_op("multibroadcast", {{"out_lens", {1, 256, 14, 14}}}), scale);

    auto product = m.add_instruction(migraphx::make_op("mul"), conv, scale_bcast);
    m.add_instruction(pass_op{}, product);

    // Sanity check: before the pass the convolution feeds the mul.
    EXPECT(conv->outputs().front()->name() == "mul");

    run_pass(m);

    // The rewrite must not fire, so the mul is still the convolution's consumer.
    const auto is_convolution = [](auto&& ins) { return ins.name() == "convolution"; };
    auto conv_after           = std::find_if(m.begin(), m.end(), is_convolution);
    EXPECT(conv_after->outputs().front()->name() == "mul");
}
TEST_CASE(simplify_mul_slice_conv1) TEST_CASE(simplify_mul_slice_conv1)
{ {
migraphx::module m1; migraphx::module m1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment