Commit 2be7e299 authored by Paul's avatar Paul
Browse files

Format

parent a9d449a0
...@@ -207,10 +207,10 @@ struct find_mul_dot ...@@ -207,10 +207,10 @@ struct find_mul_dot
{ {
auto matcher() const auto matcher() const
{ {
auto is_dot_const_inputs = match::name("dot")(match::any_of[match::inputs()](match::is_constant())); auto is_dot_const_inputs =
return match::name("mul")( match::name("dot")(match::any_of[match::inputs()](match::is_constant()));
match::either_arg(0, 1)(is_dot_const_inputs.bind("dot"), return match::name("mul")(match::either_arg(0, 1)(
match::name("broadcast", "multibroadcast").bind("c"))); is_dot_const_inputs.bind("dot"), match::name("broadcast", "multibroadcast").bind("c")));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
...@@ -229,40 +229,41 @@ struct find_mul_dot ...@@ -229,40 +229,41 @@ struct find_mul_dot
const auto& c_strides = c_ins->get_shape().strides(); const auto& c_strides = c_ins->get_shape().strides();
// There should only be one stride that is not zero // There should only be one stride that is not zero
if (std::count_if(c_strides.begin(), c_strides.end(), [](auto s) { if(std::count_if(c_strides.begin(), c_strides.end(), [](auto s) { return s != 0; }) > 1)
return s != 0;
}) > 1)
return; return;
auto add_mul_const = [&](instruction_ref x_ins) { auto add_mul_const = [&](instruction_ref x_ins) {
if (not x_ins->can_eval()) if(not x_ins->can_eval())
return m.end(); return m.end();
auto broadcast_v = c_ins->get_operator().to_value(); auto broadcast_v = c_ins->get_operator().to_value();
broadcast_v["out_lens"] = x_ins->get_shape().lens(); broadcast_v["out_lens"] = x_ins->get_shape().lens();
auto cb_ins = m.insert_instruction(ins, make_op(c_ins->name(), broadcast_v), c_ins->inputs()); auto cb_ins =
m.insert_instruction(ins, make_op(c_ins->name(), broadcast_v), c_ins->inputs());
return m.insert_instruction(ins, make_op("mul"), x_ins, cb_ins); return m.insert_instruction(ins, make_op("mul"), x_ins, cb_ins);
}; };
if (c_strides.back() == 1) { if(c_strides.back() == 1)
{
b_ins = add_mul_const(b_ins); b_ins = add_mul_const(b_ins);
} }
else if (c_strides[c_strides.size() - 2] == 1) { else if(c_strides[c_strides.size() - 2] == 1)
{
a_ins = add_mul_const(a_ins); a_ins = add_mul_const(a_ins);
} }
else if (c_ins->get_shape().scalar()) else if(c_ins->get_shape().scalar())
{ {
if (a_ins->can_eval()) if(a_ins->can_eval())
a_ins = add_mul_const(a_ins); a_ins = add_mul_const(a_ins);
else else
b_ins = add_mul_const(b_ins); b_ins = add_mul_const(b_ins);
} }
else { else
{
return; return;
} }
if (contains({a_ins, b_ins}, m.end())) if(contains({a_ins, b_ins}, m.end()))
return; return;
m.replace_instruction(ins, make_op("dot"), a_ins, b_ins); m.replace_instruction(ins, make_op("dot"), a_ins, b_ins);
...@@ -274,7 +275,8 @@ struct find_dot_mul ...@@ -274,7 +275,8 @@ struct find_dot_mul
auto matcher() const auto matcher() const
{ {
auto const_broadcast = match::name("broadcast", "multibroadcast")(match::is_constant()); auto const_broadcast = match::name("broadcast", "multibroadcast")(match::is_constant());
auto mul = match::name("mul")(match::either_arg(0, 1)(const_broadcast.bind("d"), match::none_of(match::is_constant()).bind("z"))); auto mul = match::name("mul")(match::either_arg(0, 1)(
const_broadcast.bind("d"), match::none_of(match::is_constant()).bind("z")));
return match::name("dot")(match::either_arg(0, 1)(mul, match::is_constant().bind("c"))); return match::name("dot")(match::either_arg(0, 1)(mul, match::is_constant().bind("c")));
} }
...@@ -290,25 +292,25 @@ struct find_dot_mul ...@@ -290,25 +292,25 @@ struct find_dot_mul
const auto& d_strides = d_ins->get_shape().strides(); const auto& d_strides = d_ins->get_shape().strides();
// There should only be one stride that is not zero // There should only be one stride that is not zero
if (std::count_if(d_strides.begin(), d_strides.end(), [](auto s) { if(std::count_if(d_strides.begin(), d_strides.end(), [](auto s) { return s != 0; }) > 1)
return s != 0;
}) > 1)
return; return;
if (not d_ins->get_shape().scalar()) { if(not d_ins->get_shape().scalar())
if (d_strides.back() == 1 and not b_ins->can_eval()) {
if(d_strides.back() == 1 and not b_ins->can_eval())
return; return;
if (d_strides[d_strides.size() - 2] == 1 and not a_ins->can_eval()) if(d_strides[d_strides.size() - 2] == 1 and not a_ins->can_eval())
return; return;
} }
auto broadcast_v = d_ins->get_operator().to_value(); auto broadcast_v = d_ins->get_operator().to_value();
broadcast_v["out_lens"] = c_ins->get_shape().lens(); broadcast_v["out_lens"] = c_ins->get_shape().lens();
auto db_ins = m.insert_instruction(ins, make_op(d_ins->name(), broadcast_v), d_ins->inputs()); auto db_ins =
m.insert_instruction(ins, make_op(d_ins->name(), broadcast_v), d_ins->inputs());
auto cd_ins = m.insert_instruction(ins, make_op("mul"), c_ins, db_ins); auto cd_ins = m.insert_instruction(ins, make_op("mul"), c_ins, db_ins);
if (c_ins == b_ins) if(c_ins == b_ins)
{ {
a_ins = z_ins; a_ins = z_ins;
b_ins = cd_ins; b_ins = cd_ins;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment