Commit 1dd8a1fd authored by Paul's avatar Paul
Browse files

Format

parent 797a213c
......@@ -854,10 +854,12 @@ struct find_add_dots
{
auto matcher() const
{
auto dot_const_weights = match::name("dot")(match::used_once(), match::arg(1)(match::is_constant()));
auto dot_const_inputs = match::name("dot")(match::used_once(), match::arg(0)(match::is_constant()));
return match::name("add")(match::any_of(
match::args(dot_const_weights.bind("a"), dot_const_weights.bind("b")),
auto dot_const_weights =
match::name("dot")(match::used_once(), match::arg(1)(match::is_constant()));
auto dot_const_inputs =
match::name("dot")(match::used_once(), match::arg(0)(match::is_constant()));
return match::name("add")(
match::any_of(match::args(dot_const_weights.bind("a"), dot_const_weights.bind("b")),
match::args(dot_const_inputs.bind("a"), dot_const_inputs.bind("b"))));
}
......@@ -869,8 +871,10 @@ struct find_add_dots
auto n = ins->get_shape().lens().size();
auto x = m.insert_instruction(ins, make_op("concat", {{"axis", (n-1)}}), a->inputs()[0], b->inputs()[0]);
auto w = m.insert_instruction(ins, make_op("concat", {{"axis", (n-2)}}), a->inputs()[1], b->inputs()[1]);
auto x = m.insert_instruction(
ins, make_op("concat", {{"axis", (n - 1)}}), a->inputs()[0], b->inputs()[0]);
auto w = m.insert_instruction(
ins, make_op("concat", {{"axis", (n - 2)}}), a->inputs()[1], b->inputs()[1]);
m.replace_instruction(ins, make_op("dot"), x, w);
}
};
......
......@@ -1551,13 +1551,17 @@ TEST_CASE(simplify_split_binary_dot)
migraphx::module m1;
{
auto x1 = m1.add_parameter("x1", {migraphx::shape::float_type, {1, 160, 4}});
auto w1 = m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 1));
auto w1b = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w1);
auto w1 =
m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 1));
auto w1b =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w1);
auto dot1 = m1.add_instruction(migraphx::make_op("dot"), x1, w1b);
auto x2 = m1.add_parameter("x2", {migraphx::shape::float_type, {1, 160, 4}});
auto w2 = m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 2));
auto w2b = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w2);
auto w2 =
m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 2));
auto w2b =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w2);
auto dot2 = m1.add_instruction(migraphx::make_op("dot"), x2, w2b);
auto slice11 = m1.add_instruction(
......@@ -1579,13 +1583,17 @@ TEST_CASE(simplify_split_binary_dot)
{
// TODO: Fuse these dot operators
auto x1 = m2.add_parameter("x1", {migraphx::shape::float_type, {1, 160, 4}});
auto w1 = m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 1));
auto w1b = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w1);
auto w1 =
m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 1));
auto w1b =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w1);
auto dot1 = m2.add_instruction(migraphx::make_op("dot"), x1, w1b);
auto x2 = m2.add_parameter("x2", {migraphx::shape::float_type, {1, 160, 4}});
auto w2 = m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 2));
auto w2b = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w2);
auto w2 =
m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 2));
auto w2b =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w2);
auto dot2 = m2.add_instruction(migraphx::make_op("dot"), x2, w2b);
auto sum = m2.add_instruction(migraphx::make_op("add"), dot1, dot2);
......@@ -1607,13 +1615,17 @@ TEST_CASE(simplify_split_binary_same_input)
migraphx::module m1;
{
auto x1 = m1.add_parameter("x1", {migraphx::shape::float_type, {1, 160, 4}});
auto w1 = m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 1));
auto w1b = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w1);
auto w1 =
m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 1));
auto w1b =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w1);
auto dot1 = m1.add_instruction(migraphx::make_op("dot"), x1, w1b);
auto x2 = m1.add_parameter("x2", {migraphx::shape::float_type, {1, 160, 4}});
auto w2 = m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 2));
auto w2b = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w2);
auto w2 =
m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 2));
auto w2b =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w2);
auto dot2 = m1.add_instruction(migraphx::make_op("dot"), x2, w2b);
auto slice11 = m1.add_instruction(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment