Commit 49dc6d12 authored by Paul's avatar Paul
Browse files

Match mul_add reshapes

parent ce2423ce
...@@ -587,13 +587,14 @@ struct find_reshape_cont ...@@ -587,13 +587,14 @@ struct find_reshape_cont
}; };
// match sequence of transpose --> contiguous --> reshaper_op // match sequence of transpose --> contiguous --> reshaper_op
auto match_transpose_contiguous_reshaper() template<class... Ms>
auto match_transpose_contiguous_reshaper(Ms... ms)
{ {
return match::name({"reshape", "squeeze", "unsqueeze"})( return match::name({"reshape", "squeeze", "unsqueeze"})(
match::used_once(), match::used_once(),
match::args( match::args(
match::name("contiguous")( match::name("contiguous")(
match::used_once(), match::args(match::transpose_shape().bind("trans_ins"))) match::used_once(), match::args(match::transpose_shape(ms...).bind("trans_ins")))
.bind("cont_ins"))) .bind("cont_ins")))
.bind("reshaper_ins"); .bind("reshaper_ins");
}; };
...@@ -626,6 +627,37 @@ struct find_transpose_contiguous_reshaper_unary ...@@ -626,6 +627,37 @@ struct find_transpose_contiguous_reshaper_unary
} }
}; };
struct find_mul_add_transpose_contiguous_reshaper_gemm
{
auto matcher() const
{
auto pw = match::name("mul", "add")(match::used_once(), match::either_arg(0, 1)(match::is_constant().bind("c"), match::any().bind("x")));
return match::name("dot")(match::either_arg(0, 1)(match_transpose_contiguous_reshaper(pw.bind("pointwise")), match::is_constant()));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto reshaper_ins = r.instructions["reshaper_ins"];
auto trans_ins = r.instructions["trans_ins"];
auto x_ins = r.instructions["x"];
auto c_ins = r.instructions["c"];
auto pw_ins = r.instructions["pointwise"];
auto insert_reshapes = [&](auto x) {
auto t = m.insert_instruction(ins, trans_ins->get_operator(), x);
auto c = m.insert_instruction(ins, make_op("contiguous"), t);
return m.insert_instruction(ins, reshaper_ins->get_operator(), c);
};
if (x_ins->name() == "mul")
{
x_ins = m.insert_instruction(ins, make_op("mul"), {insert_reshapes(x_ins->inputs()[0]), insert_reshapes(x_ins->inputs()[1])});
}
auto y_ins = m.insert_instruction(ins, pw_ins->get_operator(), {x_ins, insert_reshapes(c_ins)});
m.replace_instruction(reshaper_ins, y_ins);
}
};
struct find_slice_transpose struct find_slice_transpose
{ {
auto matcher() const auto matcher() const
...@@ -844,6 +876,7 @@ void simplify_reshapes::apply(module& m) const ...@@ -844,6 +876,7 @@ void simplify_reshapes::apply(module& m) const
find_transpose_slice{}, find_transpose_slice{},
find_slice_transpose{}, find_slice_transpose{},
find_transpose_contiguous_reshaper_unary{}, find_transpose_contiguous_reshaper_unary{},
find_mul_add_transpose_contiguous_reshaper_gemm{},
find_reshape_gemm{}); find_reshape_gemm{});
dead_code_elimination{}.apply(m); dead_code_elimination{}.apply(m);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment