Commit 5f7826da authored by Paul's avatar Paul
Browse files

Formatting

parent dcb15c2b
...@@ -52,22 +52,21 @@ std::vector<int64_t> reorder_dims(std::vector<int64_t> dims, std::vector<int64_t ...@@ -52,22 +52,21 @@ std::vector<int64_t> reorder_dims(std::vector<int64_t> dims, std::vector<int64_t
{ {
std::vector<int64_t> result(dims.size()); std::vector<int64_t> result(dims.size());
assert(dims.size() == permutation.size()); assert(dims.size() == permutation.size());
for(std::size_t i = 0;i <dims.size();i++) for(std::size_t i = 0; i < dims.size(); i++)
{ {
result[i] = dims[permutation[i]]; result[i] = dims[permutation[i]];
} }
return result; return result;
} }
bool is_no_transpose(const std::vector<int64_t>& dims) bool is_no_transpose(const std::vector<int64_t>& dims)
{ {
if (dims.empty()) if(dims.empty())
return true; return true;
if (dims.front() != 0) if(dims.front() != 0)
return false; return false;
return std::adjacent_find(dims.begin(), dims.end(), [](auto x, auto y) { return std::adjacent_find(
return (y - x) != 1; dims.begin(), dims.end(), [](auto x, auto y) { return (y - x) != 1; }) == dims.end();
}) == dims.end();
} }
void simplify_reshapes::apply(program& p) const void simplify_reshapes::apply(program& p) const
...@@ -122,12 +121,12 @@ void simplify_reshapes::apply(program& p) const ...@@ -122,12 +121,12 @@ void simplify_reshapes::apply(program& p) const
do do
{ {
dims = reorder_dims(get_transpose_dims(t), dims); dims = reorder_dims(get_transpose_dims(t), dims);
x = t; x = t;
t = find_transpose_input(x); t = find_transpose_input(x);
} while(x != t and t->name() == "transpose"); } while(x != t and t->name() == "transpose");
if(t == ins or t->name() != "transpose") if(t == ins or t->name() != "transpose")
continue; continue;
if (is_no_transpose(dims)) if(is_no_transpose(dims))
{ {
p.replace_instruction(ins, t->inputs().front()); p.replace_instruction(ins, t->inputs().front());
} }
......
...@@ -170,8 +170,8 @@ TEST_CASE(transpose_partial1) ...@@ -170,8 +170,8 @@ TEST_CASE(transpose_partial1)
migraphx::program p; migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x); auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1); auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
p.add_instruction(pass_op{}, t2); p.add_instruction(pass_op{}, t2);
auto out_shape = p.get_shape(); auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
...@@ -185,9 +185,9 @@ TEST_CASE(transpose_partial2) ...@@ -185,9 +185,9 @@ TEST_CASE(transpose_partial2)
migraphx::program p; migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x); auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1); auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2); auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2);
p.add_instruction(pass_op{}, t3); p.add_instruction(pass_op{}, t3);
auto out_shape = p.get_shape(); auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
...@@ -201,10 +201,10 @@ TEST_CASE(transpose_partial3) ...@@ -201,10 +201,10 @@ TEST_CASE(transpose_partial3)
migraphx::program p; migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x); auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1); auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2); auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2);
auto t4 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t3); auto t4 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t3);
p.add_instruction(pass_op{}, t4); p.add_instruction(pass_op{}, t4);
auto out_shape = p.get_shape(); auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment