"test/git@developer.sourcefind.cn:change/sglang.git" did not exist on "9116b2896fb9b6ae8510f48ce5f12f012b64483f"
Commit 1dd8a1fd authored by Paul's avatar Paul
Browse files

Format

parent 797a213c
...@@ -854,23 +854,27 @@ struct find_add_dots ...@@ -854,23 +854,27 @@ struct find_add_dots
{ {
auto matcher() const auto matcher() const
{ {
auto dot_const_weights = match::name("dot")(match::used_once(), match::arg(1)(match::is_constant())); auto dot_const_weights =
auto dot_const_inputs = match::name("dot")(match::used_once(), match::arg(0)(match::is_constant())); match::name("dot")(match::used_once(), match::arg(1)(match::is_constant()));
return match::name("add")(match::any_of( auto dot_const_inputs =
match::args(dot_const_weights.bind("a"), dot_const_weights.bind("b")), match::name("dot")(match::used_once(), match::arg(0)(match::is_constant()));
match::args(dot_const_inputs.bind("a"), dot_const_inputs.bind("b")))); return match::name("add")(
match::any_of(match::args(dot_const_weights.bind("a"), dot_const_weights.bind("b")),
match::args(dot_const_inputs.bind("a"), dot_const_inputs.bind("b"))));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto a = r.instructions["a"]; auto a = r.instructions["a"];
auto b = r.instructions["b"]; auto b = r.instructions["b"];
auto n = ins->get_shape().lens().size(); auto n = ins->get_shape().lens().size();
auto x = m.insert_instruction(ins, make_op("concat", {{"axis", (n-1)}}), a->inputs()[0], b->inputs()[0]); auto x = m.insert_instruction(
auto w = m.insert_instruction(ins, make_op("concat", {{"axis", (n-2)}}), a->inputs()[1], b->inputs()[1]); ins, make_op("concat", {{"axis", (n - 1)}}), a->inputs()[0], b->inputs()[0]);
auto w = m.insert_instruction(
ins, make_op("concat", {{"axis", (n - 2)}}), a->inputs()[1], b->inputs()[1]);
m.replace_instruction(ins, make_op("dot"), x, w); m.replace_instruction(ins, make_op("dot"), x, w);
} }
}; };
......
...@@ -1551,27 +1551,31 @@ TEST_CASE(simplify_split_binary_dot) ...@@ -1551,27 +1551,31 @@ TEST_CASE(simplify_split_binary_dot)
migraphx::module m1; migraphx::module m1;
{ {
auto x1 = m1.add_parameter("x1", {migraphx::shape::float_type, {1, 160, 4}}); auto x1 = m1.add_parameter("x1", {migraphx::shape::float_type, {1, 160, 4}});
auto w1 = m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 1)); auto w1 =
auto w1b = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w1); m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 1));
auto w1b =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w1);
auto dot1 = m1.add_instruction(migraphx::make_op("dot"), x1, w1b); auto dot1 = m1.add_instruction(migraphx::make_op("dot"), x1, w1b);
auto x2 = m1.add_parameter("x2", {migraphx::shape::float_type, {1, 160, 4}}); auto x2 = m1.add_parameter("x2", {migraphx::shape::float_type, {1, 160, 4}});
auto w2 = m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 2)); auto w2 =
auto w2b = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w2); m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 2));
auto w2b =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w2);
auto dot2 = m1.add_instruction(migraphx::make_op("dot"), x2, w2b); auto dot2 = m1.add_instruction(migraphx::make_op("dot"), x2, w2b);
auto slice11 = m1.add_instruction( auto slice11 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {32}}}), dot1); migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {32}}}), dot1);
auto slice12 = m1.add_instruction( auto slice12 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), dot1); migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), dot1);
auto slice21 = m1.add_instruction( auto slice21 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {32}}}), dot2); migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {32}}}), dot2);
auto slice22 = m1.add_instruction( auto slice22 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), dot2); migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), dot2);
auto sum1 = m1.add_instruction(migraphx::make_op("add"), slice11, slice21); auto sum1 = m1.add_instruction(migraphx::make_op("add"), slice11, slice21);
auto sum2 = m1.add_instruction(migraphx::make_op("add"), slice12, slice22); auto sum2 = m1.add_instruction(migraphx::make_op("add"), slice12, slice22);
auto ret = m1.add_instruction(migraphx::make_op("mul"), sum1, sum2); auto ret = m1.add_instruction(migraphx::make_op("mul"), sum1, sum2);
m1.add_return({ret}); m1.add_return({ret});
}; };
...@@ -1579,19 +1583,23 @@ TEST_CASE(simplify_split_binary_dot) ...@@ -1579,19 +1583,23 @@ TEST_CASE(simplify_split_binary_dot)
{ {
// TODO: Fuse these dot operators // TODO: Fuse these dot operators
auto x1 = m2.add_parameter("x1", {migraphx::shape::float_type, {1, 160, 4}}); auto x1 = m2.add_parameter("x1", {migraphx::shape::float_type, {1, 160, 4}});
auto w1 = m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 1)); auto w1 =
auto w1b = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w1); m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 1));
auto w1b =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w1);
auto dot1 = m2.add_instruction(migraphx::make_op("dot"), x1, w1b); auto dot1 = m2.add_instruction(migraphx::make_op("dot"), x1, w1b);
auto x2 = m2.add_parameter("x2", {migraphx::shape::float_type, {1, 160, 4}}); auto x2 = m2.add_parameter("x2", {migraphx::shape::float_type, {1, 160, 4}});
auto w2 = m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 2)); auto w2 =
auto w2b = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w2); m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 2));
auto w2b =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w2);
auto dot2 = m2.add_instruction(migraphx::make_op("dot"), x2, w2b); auto dot2 = m2.add_instruction(migraphx::make_op("dot"), x2, w2b);
auto sum = m2.add_instruction(migraphx::make_op("add"), dot1, dot2); auto sum = m2.add_instruction(migraphx::make_op("add"), dot1, dot2);
auto slice1 = m2.add_instruction( auto slice1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {32}}}), sum); migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {32}}}), sum);
auto slice2 = m2.add_instruction( auto slice2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), sum); migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), sum);
auto ret = m2.add_instruction(migraphx::make_op("mul"), slice1, slice2); auto ret = m2.add_instruction(migraphx::make_op("mul"), slice1, slice2);
...@@ -1607,27 +1615,31 @@ TEST_CASE(simplify_split_binary_same_input) ...@@ -1607,27 +1615,31 @@ TEST_CASE(simplify_split_binary_same_input)
migraphx::module m1; migraphx::module m1;
{ {
auto x1 = m1.add_parameter("x1", {migraphx::shape::float_type, {1, 160, 4}}); auto x1 = m1.add_parameter("x1", {migraphx::shape::float_type, {1, 160, 4}});
auto w1 = m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 1)); auto w1 =
auto w1b = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w1); m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 1));
auto w1b =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w1);
auto dot1 = m1.add_instruction(migraphx::make_op("dot"), x1, w1b); auto dot1 = m1.add_instruction(migraphx::make_op("dot"), x1, w1b);
auto x2 = m1.add_parameter("x2", {migraphx::shape::float_type, {1, 160, 4}}); auto x2 = m1.add_parameter("x2", {migraphx::shape::float_type, {1, 160, 4}});
auto w2 = m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 2)); auto w2 =
auto w2b = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w2); m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4, 64}}, 2));
auto w2b =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 64}}}), w2);
auto dot2 = m1.add_instruction(migraphx::make_op("dot"), x2, w2b); auto dot2 = m1.add_instruction(migraphx::make_op("dot"), x2, w2b);
auto slice11 = m1.add_instruction( auto slice11 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {32}}}), dot1); migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {32}}}), dot1);
auto slice12 = m1.add_instruction( auto slice12 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), dot1); migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), dot1);
auto sum1 = m1.add_instruction(migraphx::make_op("add"), slice11, slice12); auto sum1 = m1.add_instruction(migraphx::make_op("add"), slice11, slice12);
auto slice21 = m1.add_instruction( auto slice21 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {32}}}), dot2); migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {32}}}), dot2);
auto slice22 = m1.add_instruction( auto slice22 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), dot2); migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), dot2);
auto sum2 = m1.add_instruction(migraphx::make_op("add"), slice21, slice22); auto sum2 = m1.add_instruction(migraphx::make_op("add"), slice21, slice22);
auto ret = m1.add_instruction(migraphx::make_op("mul"), sum1, sum2); auto ret = m1.add_instruction(migraphx::make_op("mul"), sum1, sum2);
m1.add_return({ret}); m1.add_return({ret});
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment