"...targets/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "0e5682cd675d31bcbd3e93dd87b7837e398f83ef"
Commit 97fd3c3b authored by Paul's avatar Paul
Browse files

Formatting

parent 0c212157
...@@ -369,10 +369,7 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins) ...@@ -369,10 +369,7 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
return ctx.not_found(); return ctx.not_found();
} }
MIGRAPHX_PRED_MATCHER(is_constant, instruction_ref ins) MIGRAPHX_PRED_MATCHER(is_constant, instruction_ref ins) { return ins->can_eval(); }
{
return ins->can_eval();
}
template <class... Ms> template <class... Ms>
auto skip_output(Ms... ms) auto skip_output(Ms... ms)
......
...@@ -9,10 +9,7 @@ ...@@ -9,10 +9,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
auto lit_broadcast() auto lit_broadcast() { return match::any_of(match::name("@literal"), match::name("broadcast")); }
{
return match::any_of(match::name("@literal"), match::name("broadcast"));
}
auto not_lit_broadcast() auto not_lit_broadcast()
{ {
return match::none_of(match::name("@literal"), match::name("broadcast")); return match::none_of(match::name("@literal"), match::name("broadcast"));
...@@ -27,8 +24,11 @@ struct find_mul_conv ...@@ -27,8 +24,11 @@ struct find_mul_conv
{ {
auto matcher() const auto matcher() const
{ {
return match::name("mul")( return match::name("mul")(match::either_arg(0, 1)(
match::either_arg(0, 1)(match::name("conv")(match::used_once(), match::args(match::any(), match::is_constant().bind("w"))).bind("conv"), match::name("broadcast").bind("a"))); match::name("conv")(match::used_once(),
match::args(match::any(), match::is_constant().bind("w")))
.bind("conv"),
match::name("broadcast").bind("a")));
} }
void apply(program& p, match::matcher_result r) const void apply(program& p, match::matcher_result r) const
...@@ -39,12 +39,14 @@ struct find_mul_conv ...@@ -39,12 +39,14 @@ struct find_mul_conv
auto w_ins = r.instructions["w"]; auto w_ins = r.instructions["w"];
auto broadcast_op = any_cast<op::broadcast>(a_ins->get_operator()); auto broadcast_op = any_cast<op::broadcast>(a_ins->get_operator());
if (broadcast_op.axis != 1) if(broadcast_op.axis != 1)
return; return;
auto new_a = p.insert_instruction(ins, op::broadcast{0, w_ins->get_shape().lens()}, a_ins->inputs().front()); auto new_a = p.insert_instruction(
ins, op::broadcast{0, w_ins->get_shape().lens()}, a_ins->inputs().front());
auto new_mul = p.insert_instruction(ins, op::mul{}, new_a, w_ins); auto new_mul = p.insert_instruction(ins, op::mul{}, new_a, w_ins);
auto new_conv = p.insert_instruction(ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_mul); auto new_conv = p.insert_instruction(
ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_mul);
p.replace_instruction(ins, new_conv); p.replace_instruction(ins, new_conv);
} }
}; };
...@@ -88,7 +90,10 @@ struct find_add_lit_broadcast ...@@ -88,7 +90,10 @@ struct find_add_lit_broadcast
} }
}; };
void simplify_algebra::apply(program& p) const { match::find_matches(p, find_add_lit_broadcast{}, find_mul_conv{}); } void simplify_algebra::apply(program& p) const
{
match::find_matches(p, find_add_lit_broadcast{}, find_mul_conv{});
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -106,8 +106,9 @@ TEST_CASE(simplify_mul_conv1) ...@@ -106,8 +106,9 @@ TEST_CASE(simplify_mul_conv1)
{ {
migraphx::program p; migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::int32_type, {1, 128, 28, 28}}); auto x = p.add_parameter("x", {migraphx::shape::int32_type, {1, 128, 28, 28}});
auto w = p.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256, 128, 3, 3}})); auto w =
auto conv = p.add_instruction(migraphx::op::convolution{{1, 1},{2, 2},{1, 1}}, x, w); p.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256, 128, 3, 3}}));
auto conv = p.add_instruction(migraphx::op::convolution{{1, 1}, {2, 2}, {1, 1}}, x, w);
auto a = p.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256}})); auto a = p.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256}}));
auto b = p.add_instruction(migraphx::op::broadcast{1, {1, 256, 14, 14}}, a); auto b = p.add_instruction(migraphx::op::broadcast{1, {1, 256, 14, 14}}, a);
auto mul = p.add_instruction(migraphx::op::mul{}, conv, b); auto mul = p.add_instruction(migraphx::op::mul{}, conv, b);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment