Commit f1c131de authored by Paul's avatar Paul
Browse files

Format

parent 6f8d1427
......@@ -666,15 +666,12 @@ struct find_slice_transpose
struct find_dot_transpose
{
auto matcher() const
{
return match::name("transpose")(match::args(match::name("dot")));
}
auto matcher() const { return match::name("transpose")(match::args(match::name("dot"))); }
template<class Vector>
template <class Vector>
static bool is_swapped(const Vector& perm, std::size_t i, std::size_t j)
{
if (i >= perm.size() or j >= perm.size())
if(i >= perm.size() or j >= perm.size())
return false;
auto perm2 = perm;
std::iota(perm2.begin(), perm2.end(), 0);
......@@ -682,7 +679,7 @@ struct find_dot_transpose
return perm2 == perm;
}
template<class Vector>
template <class Vector>
static std::size_t get_batch_elements(const Vector& v)
{
return std::accumulate(v.begin(), v.end() - 2, 1, std::multiplies<>{});
......@@ -690,25 +687,25 @@ struct find_dot_transpose
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto dot = ins->inputs().front();
auto am = ins->inputs().front();
auto bm = ins->inputs().front();
auto ins = r.result;
auto dot = ins->inputs().front();
auto am = ins->inputs().front();
auto bm = ins->inputs().front();
auto transpose = any_cast<op::transpose>(ins->get_operator());
auto perm = transpose.dims;
auto last = perm.size() - 1;
auto perm = transpose.dims;
auto last = perm.size() - 1;
// Row/column swapped
if (is_swapped(perm, last - 1, last))
if(is_swapped(perm, last - 1, last))
{
// Parameters are transposed and flipped
auto am_t = m.insert_instruction(ins, transpose, bm);
auto bm_t = m.insert_instruction(ins, transpose, am);
auto am_t = m.insert_instruction(ins, transpose, bm);
auto bm_t = m.insert_instruction(ins, transpose, am);
auto new_dot = m.insert_instruction(ins, dot->get_operator(), am_t, bm_t);
m.replace_instruction(dot, new_dot);
}
else if (is_swapped(perm, last - 1, last - 2))
else if(is_swapped(perm, last - 1, last - 2))
{
if (get_batch_elements(ins->get_shape().lens()) != ins->get_shape().lens()[last - 2])
if(get_batch_elements(ins->get_shape().lens()) != ins->get_shape().lens()[last - 2])
return;
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment