Unverified Commit 015631a1 authored by mvermeulen's avatar mvermeulen Committed by GitHub
Browse files

Merge pull request #338 from ROCmSoftwarePlatform/eliminate-more-contiguous

Eliminate more contiguous
parents a1c7e7a5 f1de9bc1
......@@ -4,6 +4,8 @@
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/op/identity.hpp>
#include <utility>
namespace migraphx {
......@@ -82,6 +84,14 @@ void eliminate_contiguous::apply(program& p) const
{
instruction::replace_argument(ins, arg, prev);
}
else if(prev->can_eval())
{
auto c = op::contiguous{};
auto r = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()});
auto l = p.add_literal(r.get_shape(), r.data());
p.replace_instruction(arg, l);
}
}
}
}
......
......@@ -87,6 +87,7 @@ struct miopen_apply
void init()
{
this->last = instruction::get_output_alias(std::prev(prog->end()));
add_miopen_simple_op<miopen_abs>("abs", make_abs);
add_miopen_extend_op<miopen_leaky_relu, op::leaky_relu>("leaky_relu", make_leaky_relu);
......
......@@ -22,7 +22,7 @@ struct eliminate_contiguous_target
TEST_CASE(standard_op)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto l = p.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
p.add_instruction(pass_standard_op{}, c);
......@@ -31,18 +31,40 @@ TEST_CASE(standard_op)
EXPECT(std::distance(p.begin(), p.end()) == count);
}
TEST_CASE(non_standard_op)
TEST_CASE(standard_op_const)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
p.add_instruction(pass_standard_op{}, c);
p.compile(eliminate_contiguous_target{});
EXPECT(std::distance(p.begin(), p.end()) == 2);
}
TEST_CASE(non_standard_op)
{
migraphx::program p;
auto l = p.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
p.add_instruction(pass_op{}, c);
auto count = std::distance(p.begin(), p.end());
p.compile(eliminate_contiguous_target{});
EXPECT(std::distance(p.begin(), p.end()) == count);
}
TEST_CASE(non_standard_op_const)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
p.add_instruction(pass_op{}, c);
p.compile(eliminate_contiguous_target{});
EXPECT(std::distance(p.begin(), p.end()) == 2);
}
TEST_CASE(transpose_gemm)
{
migraphx::program p;
......@@ -59,7 +81,7 @@ TEST_CASE(transpose_gemm)
TEST_CASE(transpose_standard_op)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto l = p.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
auto sn = p.add_instruction(migraphx::op::sin{}, c);
......@@ -69,6 +91,18 @@ TEST_CASE(transpose_standard_op)
EXPECT(std::distance(p.begin(), p.end()) == count);
}
TEST_CASE(transpose_standard_op_const)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
auto sn = p.add_instruction(migraphx::op::sin{}, c);
p.add_instruction(pass_standard_op{}, sn);
p.compile(eliminate_contiguous_target{});
EXPECT(std::distance(p.begin(), p.end()) == 3);
}
TEST_CASE(no_packed_unary_op)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment