"vscode:/vscode.git/clone" did not exist on "a03570a0200f4356079fdf23beae2c717810accc"
Commit 1dd8a1fd authored by Paul's avatar Paul
Browse files

Format

parent 797a213c
......@@ -854,23 +854,27 @@ struct find_add_dots
{
auto matcher() const
{
auto dot_const_weights = match::name("dot")(match::used_once(), match::arg(1)(match::is_constant()));
auto dot_const_inputs = match::name("dot")(match::used_once(), match::arg(0)(match::is_constant()));
return match::name("add")(match::any_of(
match::args(dot_const_weights.bind("a"), dot_const_weights.bind("b")),
match::args(dot_const_inputs.bind("a"), dot_const_inputs.bind("b"))));
auto dot_const_weights =
match::name("dot")(match::used_once(), match::arg(1)(match::is_constant()));
auto dot_const_inputs =
match::name("dot")(match::used_once(), match::arg(0)(match::is_constant()));
return match::name("add")(
match::any_of(match::args(dot_const_weights.bind("a"), dot_const_weights.bind("b")),
match::args(dot_const_inputs.bind("a"), dot_const_inputs.bind("b"))));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto a = r.instructions["a"];
auto b = r.instructions["b"];
auto a = r.instructions["a"];
auto b = r.instructions["b"];
auto n = ins->get_shape().lens().size();
auto x = m.insert_instruction(ins, make_op("concat", {{"axis", (n-1)}}), a->inputs()[0], b->inputs()[0]);
auto w = m.insert_instruction(ins, make_op("concat", {{"axis", (n-2)}}), a->inputs()[1], b->inputs()[1]);
auto x = m.insert_instruction(
ins, make_op("concat", {{"axis", (n - 1)}}), a->inputs()[0], b->inputs()[0]);
auto w = m.insert_instruction(
ins, make_op("concat", {{"axis", (n - 2)}}), a->inputs()[1], b->inputs()[1]);
m.replace_instruction(ins, make_op("dot"), x, w);
}
};
......
......@@ -1551,27 +1551,31 @@ TEST_CASE(simplify_split_binary_dot)
migraphx::module m1;
{
auto x1 = m1.add_parameter("x1", {migraphx::shape::float_type, {1, 160, 4}});
auto w1 = m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 1));
auto w1b = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w1);
auto w1 =
m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 1));
auto w1b =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w1);
auto dot1 = m1.add_instruction(migraphx::make_op("dot"), x1, w1b);
auto x2 = m1.add_parameter("x2", {migraphx::shape::float_type, {1, 160, 4}});
auto w2 = m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 2));
auto w2b = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w2);
auto w2 =
m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 2));
auto w2b =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w2);
auto dot2 = m1.add_instruction(migraphx::make_op("dot"), x2, w2b);
auto slice11 = m1.add_instruction(
auto slice11 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {32}}}), dot1);
auto slice12 = m1.add_instruction(
auto slice12 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), dot1);
auto slice21 = m1.add_instruction(
auto slice21 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {32}}}), dot2);
auto slice22 = m1.add_instruction(
auto slice22 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), dot2);
auto sum1 = m1.add_instruction(migraphx::make_op("add"), slice11, slice21);
auto sum2 = m1.add_instruction(migraphx::make_op("add"), slice12, slice22);
auto ret = m1.add_instruction(migraphx::make_op("mul"), sum1, sum2);
auto ret = m1.add_instruction(migraphx::make_op("mul"), sum1, sum2);
m1.add_return({ret});
};
......@@ -1579,19 +1583,23 @@ TEST_CASE(simplify_split_binary_dot)
{
// TODO: Fuse these dot operators
auto x1 = m2.add_parameter("x1", {migraphx::shape::float_type, {1, 160, 4}});
auto w1 = m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 1));
auto w1b = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w1);
auto w1 =
m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 1));
auto w1b =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w1);
auto dot1 = m2.add_instruction(migraphx::make_op("dot"), x1, w1b);
auto x2 = m2.add_parameter("x2", {migraphx::shape::float_type, {1, 160, 4}});
auto w2 = m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 2));
auto w2b = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w2);
auto w2 =
m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 2));
auto w2b =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w2);
auto dot2 = m2.add_instruction(migraphx::make_op("dot"), x2, w2b);
auto sum = m2.add_instruction(migraphx::make_op("add"), dot1, dot2);
auto slice1 = m2.add_instruction(
auto sum = m2.add_instruction(migraphx::make_op("add"), dot1, dot2);
auto slice1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {32}}}), sum);
auto slice2 = m2.add_instruction(
auto slice2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), sum);
auto ret = m2.add_instruction(migraphx::make_op("mul"), slice1, slice2);
......@@ -1607,27 +1615,31 @@ TEST_CASE(simplify_split_binary_same_input)
migraphx::module m1;
{
auto x1 = m1.add_parameter("x1", {migraphx::shape::float_type, {1, 160, 4}});
auto w1 = m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 1));
auto w1b = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w1);
auto w1 =
m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 1));
auto w1b =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w1);
auto dot1 = m1.add_instruction(migraphx::make_op("dot"), x1, w1b);
auto x2 = m1.add_parameter("x2", {migraphx::shape::float_type, {1, 160, 4}});
auto w2 = m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 2));
auto w2b = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w2);
auto w2 =
m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 2));
auto w2b =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w2);
auto dot2 = m1.add_instruction(migraphx::make_op("dot"), x2, w2b);
auto slice11 = m1.add_instruction(
auto slice11 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {32}}}), dot1);
auto slice12 = m1.add_instruction(
auto slice12 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), dot1);
auto sum1 = m1.add_instruction(migraphx::make_op("add"), slice11, slice12);
auto slice21 = m1.add_instruction(
auto slice21 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {32}}}), dot2);
auto slice22 = m1.add_instruction(
auto slice22 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), dot2);
auto sum2 = m1.add_instruction(migraphx::make_op("add"), slice21, slice22);
auto ret = m1.add_instruction(migraphx::make_op("mul"), sum1, sum2);
auto ret = m1.add_instruction(migraphx::make_op("mul"), sum1, sum2);
m1.add_return({ret});
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment