Unverified Commit 2fdf510d authored by mvermeulen's avatar mvermeulen Committed by GitHub
Browse files

Merge branch 'develop' into onnx_autopad_fix

parents d5939189 4085af9b
#include <migraphx/fwd_conv_batchnorm_rewrite.hpp> #include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/cpu/target.hpp> #include <migraphx/cpu/target.hpp>
#include <migraphx/op/convolution.hpp> #include <migraphx/op/convolution.hpp>
...@@ -56,7 +56,7 @@ TEST_CASE(fwd_conv_batchnorm_rewrite_test) ...@@ -56,7 +56,7 @@ TEST_CASE(fwd_conv_batchnorm_rewrite_test)
migraphx::program p1 = create_program(); migraphx::program p1 = create_program();
migraphx::program p2 = create_program(); migraphx::program p2 = create_program();
migraphx::fwd_conv_batchnorm_rewrite opt; migraphx::rewrite_batchnorm opt;
opt.apply(p2); opt.apply(p2);
p1.compile(migraphx::cpu::target{}); p1.compile(migraphx::cpu::target{});
p2.compile(migraphx::cpu::target{}); p2.compile(migraphx::cpu::target{});
...@@ -93,10 +93,10 @@ TEST_CASE(non_literal) ...@@ -93,10 +93,10 @@ TEST_CASE(non_literal)
migraphx::program p1 = create_program(); migraphx::program p1 = create_program();
migraphx::program p2 = create_program(); migraphx::program p2 = create_program();
migraphx::fwd_conv_batchnorm_rewrite opt; migraphx::rewrite_batchnorm opt;
opt.apply(p2); opt.apply(p2);
EXPECT(any_of(p1, &is_batch_norm)); EXPECT(any_of(p1, &is_batch_norm));
EXPECT(any_of(p2, &is_batch_norm)); EXPECT(none_of(p2, &is_batch_norm));
} }
TEST_CASE(as_literal) TEST_CASE(as_literal)
...@@ -121,7 +121,7 @@ TEST_CASE(as_literal) ...@@ -121,7 +121,7 @@ TEST_CASE(as_literal)
migraphx::program p1 = create_program(); migraphx::program p1 = create_program();
migraphx::program p2 = create_program(); migraphx::program p2 = create_program();
migraphx::fwd_conv_batchnorm_rewrite opt; migraphx::rewrite_batchnorm opt;
opt.apply(p2); opt.apply(p2);
EXPECT(any_of(p1, &is_batch_norm)); EXPECT(any_of(p1, &is_batch_norm));
EXPECT(none_of(p2, &is_batch_norm)); EXPECT(none_of(p2, &is_batch_norm));
...@@ -159,7 +159,7 @@ TEST_CASE(literal_reshape) ...@@ -159,7 +159,7 @@ TEST_CASE(literal_reshape)
migraphx::program p1 = create_program(); migraphx::program p1 = create_program();
migraphx::program p2 = create_program(); migraphx::program p2 = create_program();
migraphx::fwd_conv_batchnorm_rewrite opt; migraphx::rewrite_batchnorm opt;
opt.apply(p2); opt.apply(p2);
EXPECT(any_of(p1, &is_batch_norm)); EXPECT(any_of(p1, &is_batch_norm));
EXPECT(none_of(p2, &is_batch_norm)); EXPECT(none_of(p2, &is_batch_norm));
......
#include <migraphx/simplify_algebra.hpp> #include <migraphx/simplify_algebra.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
...@@ -91,9 +94,9 @@ TEST_CASE(simplify_add3) ...@@ -91,9 +94,9 @@ TEST_CASE(simplify_add3)
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}}); auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto one = p2.add_literal(1); auto one = p2.add_literal(1);
auto two = p2.add_literal(2); auto two = p2.add_literal(2);
auto sum1 = p2.add_instruction(migraphx::op::add{}, one, x); auto sum1 = p2.add_instruction(migraphx::op::add{}, one, two);
auto sum2 = p2.add_instruction(migraphx::op::add{}, one, two); auto sum2 = p2.add_instruction(migraphx::op::add{}, one, sum1);
auto sum3 = p2.add_instruction(migraphx::op::add{}, sum1, sum2); auto sum3 = p2.add_instruction(migraphx::op::add{}, x, sum2);
p2.add_instruction(pass_op{}, sum3); p2.add_instruction(pass_op{}, sum3);
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
...@@ -129,4 +132,73 @@ void simplify_add4() ...@@ -129,4 +132,73 @@ void simplify_add4()
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
TEST_CASE(simplify_mul_conv1)
{
migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::int32_type, {1, 128, 28, 28}});
auto w =
p.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256, 128, 3, 3}}));
auto conv = p.add_instruction(migraphx::op::convolution{{1, 1}, {2, 2}, {1, 1}}, x, w);
auto a = p.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256}}));
auto b = p.add_instruction(migraphx::op::broadcast{1, {1, 256, 14, 14}}, a);
auto mul = p.add_instruction(migraphx::op::mul{}, conv, b);
p.add_instruction(pass_op{}, mul);
EXPECT(conv->outputs().front()->name() == "mul");
p.compile(simplify_algebra_target{});
auto new_conv =
std::find_if(p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; });
EXPECT(new_conv->outputs().front()->name() != "mul");
}
TEST_CASE(simplify_mul_add)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto one = p1.add_literal(1);
auto two = p1.add_literal(2);
auto sum = p1.add_instruction(migraphx::op::add{}, one, x);
auto mul = p1.add_instruction(migraphx::op::mul{}, sum, two);
p1.add_instruction(pass_op{}, mul);
}
p1.compile(simplify_algebra_target{});
migraphx::program p2;
{
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto one = p2.add_literal(1);
auto two = p2.add_literal(2);
auto mul1 = p2.add_instruction(migraphx::op::mul{}, two, x);
auto mul2 = p2.add_instruction(migraphx::op::mul{}, two, one);
auto sum = p2.add_instruction(migraphx::op::add{}, mul1, mul2);
p2.add_instruction(pass_op{}, sum);
}
EXPECT(p1 == p2);
}
TEST_CASE(simplify_inner_broadcast)
{
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
migraphx::program p1;
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto y = p1.add_parameter("y", {migraphx::shape::int32_type, {1}});
auto xb = p1.add_instruction(b, x);
auto yb = p1.add_instruction(b, y);
auto sum = p1.add_instruction(migraphx::op::add{}, xb, yb);
p1.add_instruction(pass_op{}, sum);
}
p1.compile(simplify_algebra_target{});
migraphx::program p2;
{
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto y = p2.add_parameter("y", {migraphx::shape::int32_type, {1}});
auto sum = p2.add_instruction(migraphx::op::add{}, x, y);
auto sumb = p2.add_instruction(b, sum);
p2.add_instruction(pass_op{}, sumb);
}
EXPECT(p1 == p2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -48,6 +48,22 @@ TEST_CASE(add_bcast_test) ...@@ -48,6 +48,22 @@ TEST_CASE(add_bcast_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(assert_less_equal_test)
{
migraphx::program p;
migraphx::shape s0{migraphx::shape::float_type, {2, 3}};
auto l0 = p.add_parameter("0", s0);
auto l1 = p.add_parameter("1", s0);
migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {0, 1}};
auto l2 = p.add_literal(l);
p.add_instruction(migraphx::op::add{}, l0, l1);
auto l3 = p.add_instruction(migraphx::op::identity{}, l0, l1);
p.add_instruction(migraphx::op::identity{}, l3, l2);
auto prog = optimize_tf("assert_less_equal_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(batchmatmul_test) TEST_CASE(batchmatmul_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -100,6 +116,16 @@ TEST_CASE(biasadd_test) ...@@ -100,6 +116,16 @@ TEST_CASE(biasadd_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(cast_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_instruction(migraphx::op::convert{migraphx::shape::int32_type}, l0);
auto prog = optimize_tf("cast_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(concat_test) TEST_CASE(concat_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -118,16 +144,6 @@ TEST_CASE(concat_test) ...@@ -118,16 +144,6 @@ TEST_CASE(concat_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(cast_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_instruction(migraphx::op::convert{migraphx::shape::int32_type}, l0);
auto prog = optimize_tf("cast_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(const_test) TEST_CASE(const_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -271,9 +287,10 @@ TEST_CASE(mean_test_nhwc) ...@@ -271,9 +287,10 @@ TEST_CASE(mean_test_nhwc)
migraphx::program p; migraphx::program p;
migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}}; migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}};
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
migraphx::op::reduce_mean op{{2, 3}}; auto l1 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
auto l3 = p.add_instruction(op, l0); migraphx::op::reduce_mean op{{1, 2}};
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3); auto l2 = p.add_instruction(op, l1);
p.add_instruction(migraphx::op::squeeze{{1, 2}}, l2);
auto prog = optimize_tf("mean_test_nhwc.pb", true); auto prog = optimize_tf("mean_test_nhwc.pb", true);
EXPECT(p == prog); EXPECT(p == prog);
...@@ -291,6 +308,23 @@ TEST_CASE(mul_test) ...@@ -291,6 +308,23 @@ TEST_CASE(mul_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(onehot_test)
{
migraphx::program p;
auto l0 = p.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {5}}, {1, 1, 1, 1, 1}});
p.add_literal(2);
p.add_literal(1.0f);
p.add_literal(0.0f);
auto l1 = p.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2, 2}}, {1, 0, 0, 1}});
int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, l1, l0);
auto prog = optimize_tf("onehot_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(pack_test) TEST_CASE(pack_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -475,20 +509,44 @@ TEST_CASE(stridedslice_test) ...@@ -475,20 +509,44 @@ TEST_CASE(stridedslice_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}});
auto l1 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
std::size_t num_axes = 4; std::size_t num_axes = 4;
migraphx::op::slice op; migraphx::op::slice op;
op.starts = {0, 0, 0, 0}; op.starts = {0, 0, 0, 0};
op.ends = {1, 1, 1, 5}; op.ends = {1, 1, 1, 5};
op.axes = std::vector<int64_t>(num_axes); op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0); std::iota(op.axes.begin(), op.axes.end(), 0);
auto l1 = p.add_instruction(op, l0); auto l2 = p.add_instruction(op, l1);
auto shrink_axis = 1; auto shrink_axis = 1;
p.add_instruction(migraphx::op::squeeze{{shrink_axis}}, l1); p.add_instruction(migraphx::op::squeeze{{shrink_axis}}, l2);
auto prog = optimize_tf("stridedslice_test.pb", true); auto prog = optimize_tf("stridedslice_test.pb", true);
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(stridedslice_masks_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 3, 3}});
std::size_t num_axes = 4;
migraphx::op::slice op;
op.starts = {0, 1, 1, 0};
op.ends = {1, 3, 3, 10};
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
// add literals for starts, ends, and strides in tf (NHWC format)
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{0, 1, 1, 0});
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{0, 0, 0, 0});
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{1, 1, 1, 1});
auto l1 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
auto l2 = p.add_instruction(op, l1);
p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l2);
auto prog = migraphx::parse_tf("stridedslice_masks_test.pb", true);
EXPECT(p == prog);
}
TEST_CASE(sub_test) TEST_CASE(sub_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment