Commit 7110eb0e authored by Paul's avatar Paul
Browse files

Format

parent 49dc6d12
......@@ -587,15 +587,15 @@ struct find_reshape_cont
};
// match sequence of transpose --> contiguous --> reshaper_op
template<class... Ms>
template <class... Ms>
auto match_transpose_contiguous_reshaper(Ms... ms)
{
return match::name({"reshape", "squeeze", "unsqueeze"})(
match::used_once(),
match::args(
match::name("contiguous")(
match::used_once(), match::args(match::transpose_shape(ms...).bind("trans_ins")))
.bind("cont_ins")))
match::args(match::name("contiguous")(
match::used_once(),
match::args(match::transpose_shape(ms...).bind("trans_ins")))
.bind("cont_ins")))
.bind("reshaper_ins");
};
......@@ -631,29 +631,36 @@ struct find_mul_add_transpose_contiguous_reshaper_gemm
{
auto matcher() const
{
auto pw = match::name("mul", "add")(match::used_once(), match::either_arg(0, 1)(match::is_constant().bind("c"), match::any().bind("x")));
return match::name("dot")(match::either_arg(0, 1)(match_transpose_contiguous_reshaper(pw.bind("pointwise")), match::is_constant()));
auto pw = match::name("mul", "add")(
match::used_once(),
match::either_arg(0, 1)(match::is_constant().bind("c"), match::any().bind("x")));
return match::name("dot")(match::either_arg(0, 1)(
match_transpose_contiguous_reshaper(pw.bind("pointwise")), match::is_constant()));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto reshaper_ins = r.instructions["reshaper_ins"];
auto trans_ins = r.instructions["trans_ins"];
auto x_ins = r.instructions["x"];
auto c_ins = r.instructions["c"];
auto pw_ins = r.instructions["pointwise"];
auto ins = r.result;
auto reshaper_ins = r.instructions["reshaper_ins"];
auto trans_ins = r.instructions["trans_ins"];
auto x_ins = r.instructions["x"];
auto c_ins = r.instructions["c"];
auto pw_ins = r.instructions["pointwise"];
auto insert_reshapes = [&](auto x) {
auto t = m.insert_instruction(ins, trans_ins->get_operator(), x);
auto c = m.insert_instruction(ins, make_op("contiguous"), t);
return m.insert_instruction(ins, reshaper_ins->get_operator(), c);
};
if (x_ins->name() == "mul")
if(x_ins->name() == "mul")
{
x_ins = m.insert_instruction(ins, make_op("mul"), {insert_reshapes(x_ins->inputs()[0]), insert_reshapes(x_ins->inputs()[1])});
x_ins = m.insert_instruction(
ins,
make_op("mul"),
{insert_reshapes(x_ins->inputs()[0]), insert_reshapes(x_ins->inputs()[1])});
}
auto y_ins = m.insert_instruction(ins, pw_ins->get_operator(), {x_ins, insert_reshapes(c_ins)});
auto y_ins =
m.insert_instruction(ins, pw_ins->get_operator(), {x_ins, insert_reshapes(c_ins)});
m.replace_instruction(reshaper_ins, y_ins);
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment