Commit f1c131de authored by Paul's avatar Paul
Browse files

Format

parent 6f8d1427
......@@ -666,15 +666,12 @@ struct find_slice_transpose
struct find_dot_transpose
{
auto matcher() const
{
return match::name("transpose")(match::args(match::name("dot")));
}
auto matcher() const { return match::name("transpose")(match::args(match::name("dot"))); }
template<class Vector>
template <class Vector>
static bool is_swapped(const Vector& perm, std::size_t i, std::size_t j)
{
if (i >= perm.size() or j >= perm.size())
if(i >= perm.size() or j >= perm.size())
return false;
auto perm2 = perm;
std::iota(perm2.begin(), perm2.end(), 0);
......@@ -682,7 +679,7 @@ struct find_dot_transpose
return perm2 == perm;
}
template<class Vector>
template <class Vector>
static std::size_t get_batch_elements(const Vector& v)
{
return std::accumulate(v.begin(), v.end() - 2, 1, std::multiplies<>{});
......@@ -698,7 +695,7 @@ struct find_dot_transpose
auto perm = transpose.dims;
auto last = perm.size() - 1;
// Row/column swapped
if (is_swapped(perm, last - 1, last))
if(is_swapped(perm, last - 1, last))
{
// Parameters are transposed and flipped
auto am_t = m.insert_instruction(ins, transpose, bm);
......@@ -706,9 +703,9 @@ struct find_dot_transpose
auto new_dot = m.insert_instruction(ins, dot->get_operator(), am_t, bm_t);
m.replace_instruction(dot, new_dot);
}
else if (is_swapped(perm, last - 1, last - 2))
else if(is_swapped(perm, last - 1, last - 2))
{
if (get_batch_elements(ins->get_shape().lens()) != ins->get_shape().lens()[last - 2])
if(get_batch_elements(ins->get_shape().lens()) != ins->get_shape().lens()[last - 2])
return;
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment