Commit 797a213c authored by Paul's avatar Paul
Browse files

Fuse dot adds with constants

parent a1efd676
...@@ -850,6 +850,31 @@ struct find_add_convs ...@@ -850,6 +850,31 @@ struct find_add_convs
} }
}; };
struct find_add_dots
{
auto matcher() const
{
auto dot_const_weights = match::name("dot")(match::used_once(), match::arg(1)(match::is_constant()));
auto dot_const_inputs = match::name("dot")(match::used_once(), match::arg(0)(match::is_constant()));
return match::name("add")(match::any_of(
match::args(dot_const_weights.bind("a"), dot_const_weights.bind("b")),
match::args(dot_const_inputs.bind("a"), dot_const_inputs.bind("b"))));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto a = r.instructions["a"];
auto b = r.instructions["b"];
auto n = ins->get_shape().lens().size();
auto x = m.insert_instruction(ins, make_op("concat", {{"axis", (n-1)}}), a->inputs()[0], b->inputs()[0]);
auto w = m.insert_instruction(ins, make_op("concat", {{"axis", (n-2)}}), a->inputs()[1], b->inputs()[1]);
m.replace_instruction(ins, make_op("dot"), x, w);
}
};
MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins) MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
{ {
auto pred = [&](auto name) { auto pred = [&](auto name) {
...@@ -1179,6 +1204,7 @@ void simplify_algebra::apply(module& m) const ...@@ -1179,6 +1204,7 @@ void simplify_algebra::apply(module& m) const
find_double_add_lit_broadcast{}, find_double_add_lit_broadcast{},
find_add_lit_broadcast{}, find_add_lit_broadcast{},
find_add_convs{}, find_add_convs{},
find_add_dots{},
find_conv_dot_horiz_fusion{}, find_conv_dot_horiz_fusion{},
find_mul_conv{}, find_mul_conv{},
find_mul_slice_conv{}, find_mul_slice_conv{},
......
...@@ -1545,6 +1545,97 @@ TEST_CASE(simplify_split_between_add) ...@@ -1545,6 +1545,97 @@ TEST_CASE(simplify_split_between_add)
EXPECT(m1.sort() == m2.sort()); EXPECT(m1.sort() == m2.sort());
} }
TEST_CASE(simplify_split_binary_dot)
{
migraphx::module m1;
{
auto x1 = m1.add_parameter("x1", {migraphx::shape::float_type, {1, 160, 4}});
auto w1 = m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 1));
auto w1b = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w1);
auto dot1 = m1.add_instruction(migraphx::make_op("dot"), x1, w1b);
auto x2 = m1.add_parameter("x2", {migraphx::shape::float_type, {1, 160, 4}});
auto w2 = m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 2));
auto w2b = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w2);
auto dot2 = m1.add_instruction(migraphx::make_op("dot"), x2, w2b);
auto slice11 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {32}}}), dot1);
auto slice12 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), dot1);
auto slice21 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {32}}}), dot2);
auto slice22 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), dot2);
auto sum1 = m1.add_instruction(migraphx::make_op("add"), slice11, slice21);
auto sum2 = m1.add_instruction(migraphx::make_op("add"), slice12, slice22);
auto ret = m1.add_instruction(migraphx::make_op("mul"), sum1, sum2);
m1.add_return({ret});
};
migraphx::module m2;
{
// TODO: Fuse these dot operators
auto x1 = m2.add_parameter("x1", {migraphx::shape::float_type, {1, 160, 4}});
auto w1 = m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 1));
auto w1b = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w1);
auto dot1 = m2.add_instruction(migraphx::make_op("dot"), x1, w1b);
auto x2 = m2.add_parameter("x2", {migraphx::shape::float_type, {1, 160, 4}});
auto w2 = m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 2));
auto w2b = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w2);
auto dot2 = m2.add_instruction(migraphx::make_op("dot"), x2, w2b);
auto sum = m2.add_instruction(migraphx::make_op("add"), dot1, dot2);
auto slice1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {32}}}), sum);
auto slice2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), sum);
auto ret = m2.add_instruction(migraphx::make_op("mul"), slice1, slice2);
m2.add_return({ret});
};
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(simplify_split_binary_same_input)
{
migraphx::module m1;
{
auto x1 = m1.add_parameter("x1", {migraphx::shape::float_type, {1, 160, 4}});
auto w1 = m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 1));
auto w1b = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w1);
auto dot1 = m1.add_instruction(migraphx::make_op("dot"), x1, w1b);
auto x2 = m1.add_parameter("x2", {migraphx::shape::float_type, {1, 160, 4}});
auto w2 = m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 2));
auto w2b = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w2);
auto dot2 = m1.add_instruction(migraphx::make_op("dot"), x2, w2b);
auto slice11 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {32}}}), dot1);
auto slice12 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), dot1);
auto sum1 = m1.add_instruction(migraphx::make_op("add"), slice11, slice12);
auto slice21 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {32}}}), dot2);
auto slice22 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), dot2);
auto sum2 = m1.add_instruction(migraphx::make_op("add"), slice21, slice22);
auto ret = m1.add_instruction(migraphx::make_op("mul"), sum1, sum2);
m1.add_return({ret});
};
migraphx::module m2 = m1;
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(simplify_dot_horiz) TEST_CASE(simplify_dot_horiz)
{ {
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 2}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 2}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment