Commit 7110eb0e authored by Paul's avatar Paul
Browse files

Format

parent 49dc6d12
...@@ -587,14 +587,14 @@ struct find_reshape_cont ...@@ -587,14 +587,14 @@ struct find_reshape_cont
}; };
// match sequence of transpose --> contiguous --> reshaper_op // match sequence of transpose --> contiguous --> reshaper_op
template<class... Ms> template <class... Ms>
auto match_transpose_contiguous_reshaper(Ms... ms) auto match_transpose_contiguous_reshaper(Ms... ms)
{ {
return match::name({"reshape", "squeeze", "unsqueeze"})( return match::name({"reshape", "squeeze", "unsqueeze"})(
match::used_once(), match::used_once(),
match::args( match::args(match::name("contiguous")(
match::name("contiguous")( match::used_once(),
match::used_once(), match::args(match::transpose_shape(ms...).bind("trans_ins"))) match::args(match::transpose_shape(ms...).bind("trans_ins")))
.bind("cont_ins"))) .bind("cont_ins")))
.bind("reshaper_ins"); .bind("reshaper_ins");
}; };
...@@ -631,8 +631,11 @@ struct find_mul_add_transpose_contiguous_reshaper_gemm ...@@ -631,8 +631,11 @@ struct find_mul_add_transpose_contiguous_reshaper_gemm
{ {
auto matcher() const auto matcher() const
{ {
auto pw = match::name("mul", "add")(match::used_once(), match::either_arg(0, 1)(match::is_constant().bind("c"), match::any().bind("x"))); auto pw = match::name("mul", "add")(
return match::name("dot")(match::either_arg(0, 1)(match_transpose_contiguous_reshaper(pw.bind("pointwise")), match::is_constant())); match::used_once(),
match::either_arg(0, 1)(match::is_constant().bind("c"), match::any().bind("x")));
return match::name("dot")(match::either_arg(0, 1)(
match_transpose_contiguous_reshaper(pw.bind("pointwise")), match::is_constant()));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
...@@ -648,12 +651,16 @@ struct find_mul_add_transpose_contiguous_reshaper_gemm ...@@ -648,12 +651,16 @@ struct find_mul_add_transpose_contiguous_reshaper_gemm
auto c = m.insert_instruction(ins, make_op("contiguous"), t); auto c = m.insert_instruction(ins, make_op("contiguous"), t);
return m.insert_instruction(ins, reshaper_ins->get_operator(), c); return m.insert_instruction(ins, reshaper_ins->get_operator(), c);
}; };
if (x_ins->name() == "mul") if(x_ins->name() == "mul")
{ {
x_ins = m.insert_instruction(ins, make_op("mul"), {insert_reshapes(x_ins->inputs()[0]), insert_reshapes(x_ins->inputs()[1])}); x_ins = m.insert_instruction(
ins,
make_op("mul"),
{insert_reshapes(x_ins->inputs()[0]), insert_reshapes(x_ins->inputs()[1])});
} }
auto y_ins = m.insert_instruction(ins, pw_ins->get_operator(), {x_ins, insert_reshapes(c_ins)}); auto y_ins =
m.insert_instruction(ins, pw_ins->get_operator(), {x_ins, insert_reshapes(c_ins)});
m.replace_instruction(reshaper_ins, y_ins); m.replace_instruction(reshaper_ins, y_ins);
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment