Commit 88f4aad8 authored by Paul's avatar Paul
Browse files

Merge branch 'simplify-reshapes' into stage

parents fd75cf5f 851ad0e4
...@@ -358,6 +358,17 @@ struct contiguous ...@@ -358,6 +358,17 @@ struct contiguous
auto t = inputs.at(0).type(); auto t = inputs.at(0).type();
return {t, lens}; return {t, lens};
} }
argument compute(const shape& output_shape, std::vector<argument> args) const
{
assert(output_shape.standard());
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = input(idx.begin(), idx.end());
});
});
return result;
}
}; };
struct concat struct concat
......
...@@ -9,8 +9,7 @@ ...@@ -9,8 +9,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
// Reshapers that can't handle nonstandard input shapes bool is_reshaper(instruction_ref ins)
bool is_nonstandard_reshaper(instruction_ref ins)
{ {
// clang-format off // clang-format off
static const std::unordered_set<std::string> names = { static const std::unordered_set<std::string> names = {
...@@ -18,31 +17,39 @@ bool is_nonstandard_reshaper(instruction_ref ins) ...@@ -18,31 +17,39 @@ bool is_nonstandard_reshaper(instruction_ref ins)
"contiguous" "contiguous"
}; };
// clang-format on // clang-format on
return contains(names, ins->name()) and ins->inputs().front()->name() == "contiguous"; return contains(names, ins->name());
} }
bool is_reshaper(instruction_ref ins) bool is_transpose_output(instruction_ref ins)
{ {
// clang-format off if(ins->outputs().size() != 1)
static const std::unordered_set<std::string> names = { return false;
"reshape", if(ins->outputs().front()->name() == "contiguous")
"transpose", return is_transpose_output(ins->outputs().front());
// "broadcast", return ins->outputs().front()->name() == "transpose";
"contiguous" }
};
// clang-format on instruction_ref find_transpose_input(instruction_ref ins)
return contains(names, ins->name()) and not is_nonstandard_reshaper(ins); {
if(ins->inputs().size() != 1)
return ins;
if(ins->inputs().front()->name() == "contiguous")
return find_transpose_input(ins->inputs().front());
if(ins->inputs().front()->name() == "transpose")
return ins->inputs().front();
return ins;
} }
void simplify_reshapes::apply(program& p) const void simplify_reshapes::apply(program& p) const
{ {
auto end = std::prev(p.end());
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
if(not is_reshaper(ins)) if(ins->outputs().empty() and ins != end)
continue;
if(ins->outputs().size() != 1)
continue; continue;
if(is_reshaper(ins->outputs().front())) if(is_reshaper(ins))
{
if(std::any_of(ins->outputs().begin(), ins->outputs().end(), &is_reshaper))
continue; continue;
// Gather reshapes // Gather reshapes
std::vector<instruction_ref> reshapes{ins}; std::vector<instruction_ref> reshapes{ins};
...@@ -71,6 +78,22 @@ void simplify_reshapes::apply(program& p) const ...@@ -71,6 +78,22 @@ void simplify_reshapes::apply(program& p) const
p.replace_instruction(r.first, r.second); p.replace_instruction(r.first, r.second);
} }
} }
else if(ins->name() == "transpose")
{
if(is_transpose_output(ins))
continue;
auto x = ins;
auto t = ins;
do
{
x = t;
t = find_transpose_input(x);
} while(x != t and t->name() == "transpose");
if(t == ins or t->name() != "transpose")
continue;
p.replace_instruction(ins, t->inputs().front());
}
}
// Replace all reshapes with as_shape // Replace all reshapes with as_shape
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
......
...@@ -287,14 +287,7 @@ struct cpu_contiguous ...@@ -287,14 +287,7 @@ struct cpu_contiguous
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); } shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
assert(output_shape.standard()); return op.compute(output_shape, std::move(args));
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = input(idx.begin(), idx.end());
});
});
return result;
} }
}; };
......
...@@ -27,9 +27,9 @@ TEST_CASE(double_contig) ...@@ -27,9 +27,9 @@ TEST_CASE(double_contig)
p.compile(simplify_reshapes_target{}); p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_shape().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 2); EXPECT(std::distance(p.begin(), p.end()) == 4);
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result == get_2x2()); EXPECT(result != get_2x2());
} }
TEST_CASE(double_transpose) TEST_CASE(double_transpose)
...@@ -95,7 +95,6 @@ TEST_CASE(double_transpose_sin_pass) ...@@ -95,7 +95,6 @@ TEST_CASE(double_transpose_sin_pass)
p.compile(simplify_reshapes_target{}); p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_shape().transposed());
// std::cout << p << std::endl;
// TODO: Fix this // TODO: Fix this
// EXPECT(std::distance(p.begin(), p.end()) == 1); // EXPECT(std::distance(p.begin(), p.end()) == 1);
auto result = p.eval({}); auto result = p.eval({});
...@@ -134,4 +133,36 @@ TEST_CASE(reshape_transpose) ...@@ -134,4 +133,36 @@ TEST_CASE(reshape_transpose)
EXPECT(std::distance(p.begin(), p.end()) == n); EXPECT(std::distance(p.begin(), p.end()) == n);
} }
TEST_CASE(transpose_contiguous)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {4, 4}};
auto x = p.add_parameter("x", s);
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, x);
auto c1 = p.add_instruction(migraphx::op::contiguous{}, t);
p.add_instruction(pass_op{}, c1);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n);
}
TEST_CASE(transpose_double_contiguous)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {4, 4}};
auto x = p.add_parameter("x", s);
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, x);
auto c1 = p.add_instruction(migraphx::op::contiguous{}, t);
auto c2 = p.add_instruction(migraphx::op::contiguous{}, c1);
p.add_instruction(pass_op{}, c2);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
EXPECT(p.has_instruction(t));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment