Commit f1c131de authored by Paul's avatar Paul
Browse files

Format

parent 6f8d1427
...@@ -666,15 +666,12 @@ struct find_slice_transpose ...@@ -666,15 +666,12 @@ struct find_slice_transpose
struct find_dot_transpose struct find_dot_transpose
{ {
auto matcher() const auto matcher() const { return match::name("transpose")(match::args(match::name("dot"))); }
{
return match::name("transpose")(match::args(match::name("dot")));
}
template<class Vector> template <class Vector>
static bool is_swapped(const Vector& perm, std::size_t i, std::size_t j) static bool is_swapped(const Vector& perm, std::size_t i, std::size_t j)
{ {
if (i >= perm.size() or j >= perm.size()) if(i >= perm.size() or j >= perm.size())
return false; return false;
auto perm2 = perm; auto perm2 = perm;
std::iota(perm2.begin(), perm2.end(), 0); std::iota(perm2.begin(), perm2.end(), 0);
...@@ -682,7 +679,7 @@ struct find_dot_transpose ...@@ -682,7 +679,7 @@ struct find_dot_transpose
return perm2 == perm; return perm2 == perm;
} }
template<class Vector> template <class Vector>
static std::size_t get_batch_elements(const Vector& v) static std::size_t get_batch_elements(const Vector& v)
{ {
return std::accumulate(v.begin(), v.end() - 2, 1, std::multiplies<>{}); return std::accumulate(v.begin(), v.end() - 2, 1, std::multiplies<>{});
...@@ -698,7 +695,7 @@ struct find_dot_transpose ...@@ -698,7 +695,7 @@ struct find_dot_transpose
auto perm = transpose.dims; auto perm = transpose.dims;
auto last = perm.size() - 1; auto last = perm.size() - 1;
// Row/column swapped // Row/column swapped
if (is_swapped(perm, last - 1, last)) if(is_swapped(perm, last - 1, last))
{ {
// Parameters are transposed and flipped // Parameters are transposed and flipped
auto am_t = m.insert_instruction(ins, transpose, bm); auto am_t = m.insert_instruction(ins, transpose, bm);
...@@ -706,9 +703,9 @@ struct find_dot_transpose ...@@ -706,9 +703,9 @@ struct find_dot_transpose
auto new_dot = m.insert_instruction(ins, dot->get_operator(), am_t, bm_t); auto new_dot = m.insert_instruction(ins, dot->get_operator(), am_t, bm_t);
m.replace_instruction(dot, new_dot); m.replace_instruction(dot, new_dot);
} }
else if (is_swapped(perm, last - 1, last - 2)) else if(is_swapped(perm, last - 1, last - 2))
{ {
if (get_batch_elements(ins->get_shape().lens()) != ins->get_shape().lens()[last - 2]) if(get_batch_elements(ins->get_shape().lens()) != ins->get_shape().lens()[last - 2])
return; return;
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment