Commit 49dc6d12 authored by Paul
Browse files

Match mul_add reshapes

parent ce2423ce
......@@ -587,13 +587,14 @@ struct find_reshape_cont
};
// match sequence of transpose --> contiguous --> reshaper_op
auto match_transpose_contiguous_reshaper()
template<class... Ms>
auto match_transpose_contiguous_reshaper(Ms... ms)
{
return match::name({"reshape", "squeeze", "unsqueeze"})(
match::used_once(),
match::args(
match::name("contiguous")(
match::used_once(), match::args(match::transpose_shape().bind("trans_ins")))
match::used_once(), match::args(match::transpose_shape(ms...).bind("trans_ins")))
.bind("cont_ins")))
.bind("reshaper_ins");
};
......@@ -626,6 +627,37 @@ struct find_transpose_contiguous_reshaper_unary
}
};
// Moves a constant "mul"/"add" that sits in front of a
// transpose-shape --> contiguous --> reshaper chain feeding a "dot" so that the
// pointwise op is applied after the chain instead:
//
//     dot(reshaper(contiguous(pw(x, c))), constant)
//         ==> dot(pw(chain(x), chain(c)), constant)
//
// where `pw` is "mul" or "add", `c` is its constant argument and `chain` is
// the sequence of ops re-created by `insert_reshapes` below.
struct find_mul_add_transpose_contiguous_reshaper_gemm
{
    auto matcher() const
    {
        // Pointwise op with one constant argument ("c") and one arbitrary
        // argument ("x"), in either order.
        auto pw = match::name("mul", "add")(
            match::used_once(),
            match::either_arg(0, 1)(match::is_constant().bind("c"), match::any().bind("x")));
        // A dot where one argument is the reshaper chain rooted at `pw` and the
        // other argument is constant.
        return match::name("dot")(match::either_arg(0, 1)(
            match_transpose_contiguous_reshaper(pw.bind("pointwise")), match::is_constant()));
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins          = r.result;
        auto reshaper_ins = r.instructions["reshaper_ins"];
        auto trans_ins    = r.instructions["trans_ins"];
        auto x_ins        = r.instructions["x"];
        auto c_ins        = r.instructions["c"];
        auto pw_ins       = r.instructions["pointwise"];

        // Re-create the matched chain on top of an arbitrary input. All new
        // instructions are inserted before the dot.
        // NOTE(review): "trans_ins" and "pointwise" are bound on the same
        // transpose_shape(...) match expression; if they resolve to the same
        // instruction, trans_ins->get_operator() is the mul/add operator
        // rather than a transpose -- confirm before relying on this step.
        auto insert_reshapes = [&](auto x) {
            auto t = m.insert_instruction(ins, trans_ins->get_operator(), x);
            auto c = m.insert_instruction(ins, make_op("contiguous"), t);
            return m.insert_instruction(ins, reshaper_ins->get_operator(), c);
        };

        if(x_ins->name() == "mul")
        {
            // Push the chain through an inner mul so each of its operands is
            // reshaped individually.
            x_ins = m.insert_instruction(
                ins,
                make_op("mul"),
                {insert_reshapes(x_ins->inputs()[0]), insert_reshapes(x_ins->inputs()[1])});
        }
        else
        {
            // Any other producer must still go through the chain; otherwise
            // the rebuilt pointwise op would combine an un-reshaped x with a
            // reshaped constant.
            x_ins = insert_reshapes(x_ins);
        }

        auto y_ins =
            m.insert_instruction(ins, pw_ins->get_operator(), {x_ins, insert_reshapes(c_ins)});
        // The original pointwise/contiguous/reshaper instructions become
        // unused once the reshaper is replaced.
        m.replace_instruction(reshaper_ins, y_ins);
    }
};
struct find_slice_transpose
{
auto matcher() const
......@@ -844,6 +876,7 @@ void simplify_reshapes::apply(module& m) const
find_transpose_slice{},
find_slice_transpose{},
find_transpose_contiguous_reshaper_unary{},
find_mul_add_transpose_contiguous_reshaper_gemm{},
find_reshape_gemm{});
dead_code_elimination{}.apply(m);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment