Commit dcb15c2b authored by Paul's avatar Paul
Browse files

Fix transpose issues in simplify_reshapes

parent 1f0167bb
......@@ -2,6 +2,7 @@
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/as_shape.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <unordered_set>
......@@ -42,6 +43,33 @@ instruction_ref find_transpose_input(instruction_ref ins)
return ins;
}
auto get_transpose_dims(instruction_ref ins)
{
return any_cast<const op::transpose&>(ins->get_operator()).dims;
}
std::vector<int64_t> reorder_dims(std::vector<int64_t> dims, std::vector<int64_t> permutation)
{
std::vector<int64_t> result(dims.size());
assert(dims.size() == permutation.size());
for(std::size_t i = 0;i <dims.size();i++)
{
result[i] = dims[permutation[i]];
}
return result;
}
bool is_no_transpose(const std::vector<int64_t>& dims)
{
if (dims.empty())
return true;
if (dims.front() != 0)
return false;
return std::adjacent_find(dims.begin(), dims.end(), [](auto x, auto y) {
return (y - x) != 1;
}) == dims.end();
}
void simplify_reshapes::apply(program& p) const
{
auto end = std::prev(p.end());
......@@ -89,14 +117,24 @@ void simplify_reshapes::apply(program& p) const
continue;
auto x = ins;
auto t = ins;
std::vector<std::int64_t> dims(ins->get_shape().lens().size());
std::iota(dims.begin(), dims.end(), 0);
do
{
dims = reorder_dims(get_transpose_dims(t), dims);
x = t;
t = find_transpose_input(x);
} while(x != t and t->name() == "transpose");
if(t == ins or t->name() != "transpose")
continue;
p.replace_instruction(ins, t->inputs().front());
if (is_no_transpose(dims))
{
p.replace_instruction(ins, t->inputs().front());
}
else
{
p.replace_instruction(ins, op::transpose{{dims}}, t->inputs().front());
}
}
}
}
......
......@@ -165,4 +165,52 @@ TEST_CASE(transpose_double_contiguous)
EXPECT(p.has_instruction(t));
}
TEST_CASE(transpose_partial1)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
p.add_instruction(pass_op{}, t2);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}
TEST_CASE(transpose_partial2)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2);
p.add_instruction(pass_op{}, t3);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 2);
}
TEST_CASE(transpose_partial3)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2);
auto t4 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t3);
p.add_instruction(pass_op{}, t4);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 3);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment