Commit 3c45f2ed authored by Paul's avatar Paul
Browse files

Formatting

parent 9b8d62d1
...@@ -22,18 +22,18 @@ bool is_reshaper(instruction_ref ins) ...@@ -22,18 +22,18 @@ bool is_reshaper(instruction_ref ins)
bool is_transpose_output(instruction_ref ins) bool is_transpose_output(instruction_ref ins)
{ {
if (ins->outputs().size() != 1) if(ins->outputs().size() != 1)
return false; return false;
if (ins->outputs().front()->name() == "contiguous") if(ins->outputs().front()->name() == "contiguous")
return is_transpose_output(ins->outputs().front()); return is_transpose_output(ins->outputs().front());
return ins->outputs().front()->name() == "transpose"; return ins->outputs().front()->name() == "transpose";
} }
instruction_ref find_transpose_input(instruction_ref ins) instruction_ref find_transpose_input(instruction_ref ins)
{ {
if (ins->inputs().size() != 1) if(ins->inputs().size() != 1)
return ins; return ins;
if (ins->inputs().front()->name() == "contiguous") if(ins->inputs().front()->name() == "contiguous")
return find_transpose_input(ins->inputs().front()); return find_transpose_input(ins->inputs().front());
if(ins->inputs().front()->name() == "transpose") if(ins->inputs().front()->name() == "transpose")
return ins->inputs().front(); return ins->inputs().front();
...@@ -47,7 +47,7 @@ void simplify_reshapes::apply(program& p) const ...@@ -47,7 +47,7 @@ void simplify_reshapes::apply(program& p) const
{ {
if(ins->outputs().empty() and ins != end) if(ins->outputs().empty() and ins != end)
continue; continue;
if(is_reshaper(ins)) if(is_reshaper(ins))
{ {
if(std::any_of(ins->outputs().begin(), ins->outputs().end(), &is_reshaper)) if(std::any_of(ins->outputs().begin(), ins->outputs().end(), &is_reshaper))
continue; continue;
...@@ -78,9 +78,9 @@ void simplify_reshapes::apply(program& p) const ...@@ -78,9 +78,9 @@ void simplify_reshapes::apply(program& p) const
p.replace_instruction(r.first, r.second); p.replace_instruction(r.first, r.second);
} }
} }
else if (ins->name() == "transpose") else if(ins->name() == "transpose")
{ {
if (is_transpose_output(ins)) if(is_transpose_output(ins))
continue; continue;
auto x = ins; auto x = ins;
auto t = ins; auto t = ins;
...@@ -89,7 +89,7 @@ void simplify_reshapes::apply(program& p) const ...@@ -89,7 +89,7 @@ void simplify_reshapes::apply(program& p) const
x = t; x = t;
t = find_transpose_input(x); t = find_transpose_input(x);
} while(x != t and t->name() == "transpose"); } while(x != t and t->name() == "transpose");
if (t == ins or t->name() != "transpose") if(t == ins or t->name() != "transpose")
continue; continue;
p.replace_instruction(ins, t->inputs().front()); p.replace_instruction(ins, t->inputs().front());
} }
......
...@@ -142,7 +142,7 @@ TEST_CASE(transpose_contiguous) ...@@ -142,7 +142,7 @@ TEST_CASE(transpose_contiguous)
auto c1 = p.add_instruction(migraphx::op::contiguous{}, t); auto c1 = p.add_instruction(migraphx::op::contiguous{}, t);
p.add_instruction(pass_op{}, c1); p.add_instruction(pass_op{}, c1);
auto out_shape = p.get_shape(); auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{}); p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape); EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n); EXPECT(std::distance(p.begin(), p.end()) == n);
...@@ -158,7 +158,7 @@ TEST_CASE(transpose_double_contiguous) ...@@ -158,7 +158,7 @@ TEST_CASE(transpose_double_contiguous)
auto c2 = p.add_instruction(migraphx::op::contiguous{}, c1); auto c2 = p.add_instruction(migraphx::op::contiguous{}, c1);
p.add_instruction(pass_op{}, c2); p.add_instruction(pass_op{}, c2);
auto out_shape = p.get_shape(); auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{}); p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape); EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1); EXPECT(std::distance(p.begin(), p.end()) == n - 1);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment