Commit 45bdaf27 authored by Paul's avatar Paul
Browse files

Format

parent 2262efe0
...@@ -33,42 +33,45 @@ namespace { ...@@ -33,42 +33,45 @@ namespace {
MIGRAPHX_PRED_MATCHER(col_matrix, instruction_ref ins) MIGRAPHX_PRED_MATCHER(col_matrix, instruction_ref ins)
{ {
if (not ins->get_shape().transposed()) if(not ins->get_shape().transposed())
return false; return false;
if (ins->get_shape().ndim() < 2) if(ins->get_shape().ndim() < 2)
return false; return false;
auto perm = find_permutation(ins->get_shape()); auto perm = find_permutation(ins->get_shape());
auto n = perm.size() - 1; auto n = perm.size() - 1;
return perm[n] == n - 1 and perm[n-1] == n; return perm[n] == n - 1 and perm[n - 1] == n;
} }
MIGRAPHX_PRED_MATCHER(broadcast_matrix_dims, instruction_ref ins) MIGRAPHX_PRED_MATCHER(broadcast_matrix_dims, instruction_ref ins)
{ {
if (not ins->get_shape().broadcasted()) if(not ins->get_shape().broadcasted())
return false; return false;
if (ins->get_shape().ndim() < 2) if(ins->get_shape().ndim() < 2)
return false; return false;
return std::any_of(ins->get_shape().lens().rbegin(), ins->get_shape().lens().rend()+2, [](auto i) { return std::any_of(ins->get_shape().lens().rbegin(),
return i == 0; ins->get_shape().lens().rend() + 2,
}); [](auto i) { return i == 0; });
} }
struct find_dot_const struct find_dot_const
{ {
auto matcher() const auto matcher() const
{ {
return match::name("dot")(match::arg(1)(match::is_constant(), match::none_of(col_matrix(), broadcast_matrix_dims()), match::skip_broadcasts(match::any().bind("w"))))(match::none_of(match::is_constant())); return match::name("dot")(match::arg(1)(
match::is_constant(),
match::none_of(col_matrix(), broadcast_matrix_dims()),
match::skip_broadcasts(match::any().bind("w"))))(match::none_of(match::is_constant()));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto w = r.instructions["w"]; auto w = r.instructions["w"];
if (w->get_shape().ndim() < 2) if(w->get_shape().ndim() < 2)
return; return;
auto perm = find_permutation(w->get_shape()); auto perm = find_permutation(w->get_shape());
auto n = perm.size() - 1; auto n = perm.size() - 1;
std::swap(perm[n], perm[n-1]); std::swap(perm[n], perm[n - 1]);
auto wl = m.insert_instruction(std::next(w), make_op("layout", {{"permutation", perm}}), w); auto wl = m.insert_instruction(std::next(w), make_op("layout", {{"permutation", perm}}), w);
m.replace_instruction(w, wl); m.replace_instruction(w, wl);
} }
...@@ -76,10 +79,7 @@ struct find_dot_const ...@@ -76,10 +79,7 @@ struct find_dot_const
} // namespace } // namespace
void rewrite_ops::apply(module& m) const void rewrite_ops::apply(module& m) const { match::find_matches(m, find_dot_const{}); }
{
match::find_matches(m, find_dot_const{});
}
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment