Commit 2be7e299 authored by Paul's avatar Paul
Browse files

Format

parent a9d449a0
......@@ -207,19 +207,19 @@ struct find_mul_dot
{
auto matcher() const
{
auto is_dot_const_inputs = match::name("dot")(match::any_of[match::inputs()](match::is_constant()));
return match::name("mul")(
match::either_arg(0, 1)(is_dot_const_inputs.bind("dot"),
match::name("broadcast", "multibroadcast").bind("c")));
auto is_dot_const_inputs =
match::name("dot")(match::any_of[match::inputs()](match::is_constant()));
return match::name("mul")(match::either_arg(0, 1)(
is_dot_const_inputs.bind("dot"), match::name("broadcast", "multibroadcast").bind("c")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto ins = r.result;
auto dot_ins = r.instructions["dot"];
auto a_ins = dot_ins->inputs()[0];
auto b_ins = dot_ins->inputs()[1];
auto c_ins = r.instructions["c"];
auto a_ins = dot_ins->inputs()[0];
auto b_ins = dot_ins->inputs()[1];
auto c_ins = r.instructions["c"];
std::cout << "find_mul_dot" << std::endl;
m.debug_print(ins->inputs());
......@@ -229,40 +229,41 @@ struct find_mul_dot
const auto& c_strides = c_ins->get_shape().strides();
// There should only be one stride that is not zero
if (std::count_if(c_strides.begin(), c_strides.end(), [](auto s) {
return s != 0;
}) > 1)
if(std::count_if(c_strides.begin(), c_strides.end(), [](auto s) { return s != 0; }) > 1)
return;
auto add_mul_const = [&](instruction_ref x_ins) {
if (not x_ins->can_eval())
if(not x_ins->can_eval())
return m.end();
auto broadcast_v = c_ins->get_operator().to_value();
auto broadcast_v = c_ins->get_operator().to_value();
broadcast_v["out_lens"] = x_ins->get_shape().lens();
auto cb_ins = m.insert_instruction(ins, make_op(c_ins->name(), broadcast_v), c_ins->inputs());
auto cb_ins =
m.insert_instruction(ins, make_op(c_ins->name(), broadcast_v), c_ins->inputs());
return m.insert_instruction(ins, make_op("mul"), x_ins, cb_ins);
};
if (c_strides.back() == 1) {
if(c_strides.back() == 1)
{
b_ins = add_mul_const(b_ins);
}
else if (c_strides[c_strides.size() - 2] == 1) {
else if(c_strides[c_strides.size() - 2] == 1)
{
a_ins = add_mul_const(a_ins);
}
else if (c_ins->get_shape().scalar())
else if(c_ins->get_shape().scalar())
{
if (a_ins->can_eval())
if(a_ins->can_eval())
a_ins = add_mul_const(a_ins);
else
b_ins = add_mul_const(b_ins);
}
else {
else
{
return;
}
if (contains({a_ins, b_ins}, m.end()))
if(contains({a_ins, b_ins}, m.end()))
return;
m.replace_instruction(ins, make_op("dot"), a_ins, b_ins);
......@@ -274,7 +275,8 @@ struct find_dot_mul
auto matcher() const
{
auto const_broadcast = match::name("broadcast", "multibroadcast")(match::is_constant());
auto mul = match::name("mul")(match::either_arg(0, 1)(const_broadcast.bind("d"), match::none_of(match::is_constant()).bind("z")));
auto mul = match::name("mul")(match::either_arg(0, 1)(
const_broadcast.bind("d"), match::none_of(match::is_constant()).bind("z")));
return match::name("dot")(match::either_arg(0, 1)(mul, match::is_constant().bind("c")));
}
......@@ -290,32 +292,32 @@ struct find_dot_mul
const auto& d_strides = d_ins->get_shape().strides();
// There should only be one stride that is not zero
if (std::count_if(d_strides.begin(), d_strides.end(), [](auto s) {
return s != 0;
}) > 1)
if(std::count_if(d_strides.begin(), d_strides.end(), [](auto s) { return s != 0; }) > 1)
return;
if (not d_ins->get_shape().scalar()) {
if (d_strides.back() == 1 and not b_ins->can_eval())
if(not d_ins->get_shape().scalar())
{
if(d_strides.back() == 1 and not b_ins->can_eval())
return;
if (d_strides[d_strides.size() - 2] == 1 and not a_ins->can_eval())
if(d_strides[d_strides.size() - 2] == 1 and not a_ins->can_eval())
return;
}
auto broadcast_v = d_ins->get_operator().to_value();
auto broadcast_v = d_ins->get_operator().to_value();
broadcast_v["out_lens"] = c_ins->get_shape().lens();
auto db_ins = m.insert_instruction(ins, make_op(d_ins->name(), broadcast_v), d_ins->inputs());
auto db_ins =
m.insert_instruction(ins, make_op(d_ins->name(), broadcast_v), d_ins->inputs());
auto cd_ins = m.insert_instruction(ins, make_op("mul"), c_ins, db_ins);
if (c_ins == b_ins)
if(c_ins == b_ins)
{
a_ins = z_ins;
b_ins = cd_ins;
}
else
{
a_ins = cd_ins;
a_ins = cd_ins;
b_ins = z_ins;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment