Commit 1dd8a1fd authored by Paul's avatar Paul
Browse files

Format

parent 797a213c
...@@ -854,10 +854,12 @@ struct find_add_dots ...@@ -854,10 +854,12 @@ struct find_add_dots
{ {
auto matcher() const auto matcher() const
{ {
auto dot_const_weights = match::name("dot")(match::used_once(), match::arg(1)(match::is_constant())); auto dot_const_weights =
auto dot_const_inputs = match::name("dot")(match::used_once(), match::arg(0)(match::is_constant())); match::name("dot")(match::used_once(), match::arg(1)(match::is_constant()));
return match::name("add")(match::any_of( auto dot_const_inputs =
match::args(dot_const_weights.bind("a"), dot_const_weights.bind("b")), match::name("dot")(match::used_once(), match::arg(0)(match::is_constant()));
return match::name("add")(
match::any_of(match::args(dot_const_weights.bind("a"), dot_const_weights.bind("b")),
match::args(dot_const_inputs.bind("a"), dot_const_inputs.bind("b")))); match::args(dot_const_inputs.bind("a"), dot_const_inputs.bind("b"))));
} }
...@@ -869,8 +871,10 @@ struct find_add_dots ...@@ -869,8 +871,10 @@ struct find_add_dots
auto n = ins->get_shape().lens().size(); auto n = ins->get_shape().lens().size();
auto x = m.insert_instruction(ins, make_op("concat", {{"axis", (n-1)}}), a->inputs()[0], b->inputs()[0]); auto x = m.insert_instruction(
auto w = m.insert_instruction(ins, make_op("concat", {{"axis", (n-2)}}), a->inputs()[1], b->inputs()[1]); ins, make_op("concat", {{"axis", (n - 1)}}), a->inputs()[0], b->inputs()[0]);
auto w = m.insert_instruction(
ins, make_op("concat", {{"axis", (n - 2)}}), a->inputs()[1], b->inputs()[1]);
m.replace_instruction(ins, make_op("dot"), x, w); m.replace_instruction(ins, make_op("dot"), x, w);
} }
}; };
......
...@@ -1551,13 +1551,17 @@ TEST_CASE(simplify_split_binary_dot) ...@@ -1551,13 +1551,17 @@ TEST_CASE(simplify_split_binary_dot)
migraphx::module m1; migraphx::module m1;
{ {
auto x1 = m1.add_parameter("x1", {migraphx::shape::float_type, {1, 160, 4}}); auto x1 = m1.add_parameter("x1", {migraphx::shape::float_type, {1, 160, 4}});
auto w1 = m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 1)); auto w1 =
auto w1b = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w1); m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 1));
auto w1b =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w1);
auto dot1 = m1.add_instruction(migraphx::make_op("dot"), x1, w1b); auto dot1 = m1.add_instruction(migraphx::make_op("dot"), x1, w1b);
auto x2 = m1.add_parameter("x2", {migraphx::shape::float_type, {1, 160, 4}}); auto x2 = m1.add_parameter("x2", {migraphx::shape::float_type, {1, 160, 4}});
auto w2 = m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 2)); auto w2 =
auto w2b = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w2); m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 2));
auto w2b =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w2);
auto dot2 = m1.add_instruction(migraphx::make_op("dot"), x2, w2b); auto dot2 = m1.add_instruction(migraphx::make_op("dot"), x2, w2b);
auto slice11 = m1.add_instruction( auto slice11 = m1.add_instruction(
...@@ -1579,13 +1583,17 @@ TEST_CASE(simplify_split_binary_dot) ...@@ -1579,13 +1583,17 @@ TEST_CASE(simplify_split_binary_dot)
{ {
// TODO: Fuse these dot operators // TODO: Fuse these dot operators
auto x1 = m2.add_parameter("x1", {migraphx::shape::float_type, {1, 160, 4}}); auto x1 = m2.add_parameter("x1", {migraphx::shape::float_type, {1, 160, 4}});
auto w1 = m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 1)); auto w1 =
auto w1b = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w1); m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 1));
auto w1b =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w1);
auto dot1 = m2.add_instruction(migraphx::make_op("dot"), x1, w1b); auto dot1 = m2.add_instruction(migraphx::make_op("dot"), x1, w1b);
auto x2 = m2.add_parameter("x2", {migraphx::shape::float_type, {1, 160, 4}}); auto x2 = m2.add_parameter("x2", {migraphx::shape::float_type, {1, 160, 4}});
auto w2 = m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 2)); auto w2 =
auto w2b = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w2); m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 2));
auto w2b =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w2);
auto dot2 = m2.add_instruction(migraphx::make_op("dot"), x2, w2b); auto dot2 = m2.add_instruction(migraphx::make_op("dot"), x2, w2b);
auto sum = m2.add_instruction(migraphx::make_op("add"), dot1, dot2); auto sum = m2.add_instruction(migraphx::make_op("add"), dot1, dot2);
...@@ -1607,13 +1615,17 @@ TEST_CASE(simplify_split_binary_same_input) ...@@ -1607,13 +1615,17 @@ TEST_CASE(simplify_split_binary_same_input)
migraphx::module m1; migraphx::module m1;
{ {
auto x1 = m1.add_parameter("x1", {migraphx::shape::float_type, {1, 160, 4}}); auto x1 = m1.add_parameter("x1", {migraphx::shape::float_type, {1, 160, 4}});
auto w1 = m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 1)); auto w1 =
auto w1b = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w1); m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 1));
auto w1b =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w1);
auto dot1 = m1.add_instruction(migraphx::make_op("dot"), x1, w1b); auto dot1 = m1.add_instruction(migraphx::make_op("dot"), x1, w1b);
auto x2 = m1.add_parameter("x2", {migraphx::shape::float_type, {1, 160, 4}}); auto x2 = m1.add_parameter("x2", {migraphx::shape::float_type, {1, 160, 4}});
auto w2 = m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 2)); auto w2 =
auto w2b = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w2); m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 2));
auto w2b =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w2);
auto dot2 = m1.add_instruction(migraphx::make_op("dot"), x2, w2b); auto dot2 = m1.add_instruction(migraphx::make_op("dot"), x2, w2b);
auto slice11 = m1.add_instruction( auto slice11 = m1.add_instruction(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment