Commit 45bdaf27 authored by Paul's avatar Paul
Browse files

Format

parent 2262efe0
......@@ -33,42 +33,45 @@ namespace {
MIGRAPHX_PRED_MATCHER(col_matrix, instruction_ref ins)
{
if (not ins->get_shape().transposed())
if(not ins->get_shape().transposed())
return false;
if (ins->get_shape().ndim() < 2)
if(ins->get_shape().ndim() < 2)
return false;
auto perm = find_permutation(ins->get_shape());
auto n = perm.size() - 1;
return perm[n] == n - 1 and perm[n-1] == n;
return perm[n] == n - 1 and perm[n - 1] == n;
}
MIGRAPHX_PRED_MATCHER(broadcast_matrix_dims, instruction_ref ins)
{
if (not ins->get_shape().broadcasted())
if(not ins->get_shape().broadcasted())
return false;
if (ins->get_shape().ndim() < 2)
if(ins->get_shape().ndim() < 2)
return false;
return std::any_of(ins->get_shape().lens().rbegin(), ins->get_shape().lens().rend()+2, [](auto i) {
return i == 0;
});
return std::any_of(ins->get_shape().lens().rbegin(),
ins->get_shape().lens().rend() + 2,
[](auto i) { return i == 0; });
}
struct find_dot_const
{
auto matcher() const
{
return match::name("dot")(match::arg(1)(match::is_constant(), match::none_of(col_matrix(), broadcast_matrix_dims()), match::skip_broadcasts(match::any().bind("w"))))(match::none_of(match::is_constant()));
return match::name("dot")(match::arg(1)(
match::is_constant(),
match::none_of(col_matrix(), broadcast_matrix_dims()),
match::skip_broadcasts(match::any().bind("w"))))(match::none_of(match::is_constant()));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto w = r.instructions["w"];
if (w->get_shape().ndim() < 2)
if(w->get_shape().ndim() < 2)
return;
auto perm = find_permutation(w->get_shape());
auto n = perm.size() - 1;
std::swap(perm[n], perm[n-1]);
std::swap(perm[n], perm[n - 1]);
auto wl = m.insert_instruction(std::next(w), make_op("layout", {{"permutation", perm}}), w);
m.replace_instruction(w, wl);
}
......@@ -76,10 +79,7 @@ struct find_dot_const
} // namespace
void rewrite_ops::apply(module& m) const
{
match::find_matches(m, find_dot_const{});
}
void rewrite_ops::apply(module& m) const { match::find_matches(m, find_dot_const{}); }
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment