Commit 6f8d1427 authored by Paul's avatar Paul
Browse files

Add find_dot_transpose

parent f2531606
...@@ -664,6 +664,56 @@ struct find_slice_transpose ...@@ -664,6 +664,56 @@ struct find_slice_transpose
} }
}; };
struct find_dot_transpose
{
auto matcher() const
{
return match::name("transpose")(match::args(match::name("dot")));
}
template<class Vector>
static bool is_swapped(const Vector& perm, std::size_t i, std::size_t j)
{
if (i >= perm.size() or j >= perm.size())
return false;
auto perm2 = perm;
std::iota(perm2.begin(), perm2.end(), 0);
std::swap(perm2[i], perm2[j]);
return perm2 == perm;
}
template<class Vector>
static std::size_t get_batch_elements(const Vector& v)
{
return std::accumulate(v.begin(), v.end() - 2, 1, std::multiplies<>{});
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto dot = ins->inputs().front();
auto am = ins->inputs().front();
auto bm = ins->inputs().front();
auto transpose = any_cast<op::transpose>(ins->get_operator());
auto perm = transpose.dims;
auto last = perm.size() - 1;
// Row/column swapped
if (is_swapped(perm, last - 1, last))
{
// Parameters are transposed and flipped
auto am_t = m.insert_instruction(ins, transpose, bm);
auto bm_t = m.insert_instruction(ins, transpose, am);
auto new_dot = m.insert_instruction(ins, dot->get_operator(), am_t, bm_t);
m.replace_instruction(dot, new_dot);
}
else if (is_swapped(perm, last - 1, last - 2))
{
if (get_batch_elements(ins->get_shape().lens()) != ins->get_shape().lens()[last - 2])
return;
}
}
};
void simplify_reshapes::apply(module& m) const void simplify_reshapes::apply(module& m) const
{ {
for(int i = 0; i < 2; i++) for(int i = 0; i < 2; i++)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment