"model/models/vscode:/vscode.git/clone" did not exist on "7d25b9e194f106e9c2a5289dfde40077c0838b7d"
Unverified Commit 0dc15ece authored by mvermeulen's avatar mvermeulen Committed by GitHub
Browse files

Merge branch 'develop' into parse-indices

parents bfed9064 015631a1
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/op/identity.hpp>
#include <utility> #include <utility>
namespace migraphx { namespace migraphx {
...@@ -82,6 +84,14 @@ void eliminate_contiguous::apply(program& p) const ...@@ -82,6 +84,14 @@ void eliminate_contiguous::apply(program& p) const
{ {
instruction::replace_argument(ins, arg, prev); instruction::replace_argument(ins, arg, prev);
} }
else if(prev->can_eval())
{
auto c = op::contiguous{};
auto r = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()});
auto l = p.add_literal(r.get_shape(), r.data());
p.replace_instruction(arg, l);
}
} }
} }
} }
......
...@@ -87,6 +87,7 @@ struct miopen_apply ...@@ -87,6 +87,7 @@ struct miopen_apply
void init() void init()
{ {
this->last = instruction::get_output_alias(std::prev(prog->end())); this->last = instruction::get_output_alias(std::prev(prog->end()));
add_miopen_simple_op<miopen_abs>("abs", make_abs); add_miopen_simple_op<miopen_abs>("abs", make_abs);
add_miopen_extend_op<miopen_leaky_relu, op::leaky_relu>("leaky_relu", make_leaky_relu); add_miopen_extend_op<miopen_leaky_relu, op::leaky_relu>("leaky_relu", make_leaky_relu);
......
...@@ -22,7 +22,7 @@ struct eliminate_contiguous_target ...@@ -22,7 +22,7 @@ struct eliminate_contiguous_target
TEST_CASE(standard_op) TEST_CASE(standard_op)
{ {
migraphx::program p; migraphx::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t); auto c = p.add_instruction(migraphx::op::contiguous{}, t);
p.add_instruction(pass_standard_op{}, c); p.add_instruction(pass_standard_op{}, c);
...@@ -31,18 +31,40 @@ TEST_CASE(standard_op) ...@@ -31,18 +31,40 @@ TEST_CASE(standard_op)
EXPECT(std::distance(p.begin(), p.end()) == count); EXPECT(std::distance(p.begin(), p.end()) == count);
} }
TEST_CASE(non_standard_op) TEST_CASE(standard_op_const)
{ {
migraphx::program p; migraphx::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t); auto c = p.add_instruction(migraphx::op::contiguous{}, t);
p.add_instruction(pass_standard_op{}, c);
p.compile(eliminate_contiguous_target{});
EXPECT(std::distance(p.begin(), p.end()) == 2);
}
TEST_CASE(non_standard_op)
{
migraphx::program p;
auto l = p.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
p.add_instruction(pass_op{}, c); p.add_instruction(pass_op{}, c);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(p.begin(), p.end());
p.compile(eliminate_contiguous_target{}); p.compile(eliminate_contiguous_target{});
EXPECT(std::distance(p.begin(), p.end()) == count); EXPECT(std::distance(p.begin(), p.end()) == count);
} }
TEST_CASE(non_standard_op_const)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
p.add_instruction(pass_op{}, c);
p.compile(eliminate_contiguous_target{});
EXPECT(std::distance(p.begin(), p.end()) == 2);
}
TEST_CASE(transpose_gemm) TEST_CASE(transpose_gemm)
{ {
migraphx::program p; migraphx::program p;
...@@ -59,7 +81,7 @@ TEST_CASE(transpose_gemm) ...@@ -59,7 +81,7 @@ TEST_CASE(transpose_gemm)
TEST_CASE(transpose_standard_op) TEST_CASE(transpose_standard_op)
{ {
migraphx::program p; migraphx::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t); auto c = p.add_instruction(migraphx::op::contiguous{}, t);
auto sn = p.add_instruction(migraphx::op::sin{}, c); auto sn = p.add_instruction(migraphx::op::sin{}, c);
...@@ -69,6 +91,18 @@ TEST_CASE(transpose_standard_op) ...@@ -69,6 +91,18 @@ TEST_CASE(transpose_standard_op)
EXPECT(std::distance(p.begin(), p.end()) == count); EXPECT(std::distance(p.begin(), p.end()) == count);
} }
TEST_CASE(transpose_standard_op_const)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
auto sn = p.add_instruction(migraphx::op::sin{}, c);
p.add_instruction(pass_standard_op{}, sn);
p.compile(eliminate_contiguous_target{});
EXPECT(std::distance(p.begin(), p.end()) == 3);
}
TEST_CASE(no_packed_unary_op) TEST_CASE(no_packed_unary_op)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment