Commit 992666e6 authored by Shucai Xiao's avatar Shucai Xiao Committed by mvermeulen
Browse files

Improve operators for onnxruntime (#405)



* improve unsqueeze to support negative axis and parsing scalar

* clang format

* add a test example for the negative axis of unsqueeze

* improve the squeeze operator to support negative axis

* clang format

* fixed a small bug in the lrn implementation

* clang format

* support negative axis in argmax and argmin

* clang format

* improve flatten to support negative axis

* clang format

* change softmax/logsoftmax to support negative axis

* clang format

* improve transpose by adding default perm

* clang format

* add one more dimens for tensor size

* add one more dimens for tensor size

* disable conv ops fusion for non-symmetric cases

* clang format

* fixed review comments

* move computing axis from the device function to the compute function

* clang format

* move computing axis from device function to the operator computing function

* clang format
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 2ee0f9e8
......@@ -14,7 +14,9 @@ shape hip_softmax::compute_shape(const std::vector<shape>& inputs) const
argument hip_softmax::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::softmax(ctx.get_stream().get(), args.back(), args.front(), op.axis);
auto n_dim = args.front().get_shape().lens().size();
auto tuned_axis = (op.axis < 0) ? op.axis + n_dim : op.axis;
device::softmax(ctx.get_stream().get(), args.back(), args.front(), tuned_axis);
return args.back();
}
......
......@@ -324,6 +324,7 @@ TEST_CASE(squeeze_test)
auto result = p.eval({});
EXPECT(result.get_shape() == s2);
}
{
migraphx::program p;
std::vector<float> data(4 * 3 * 3);
......@@ -1321,6 +1322,24 @@ TEST_CASE(argmax_test_1)
EXPECT(migraphx::verify_range(result_vec, res_gold));
}
TEST_CASE(argmax_test_neg_2)
{
migraphx::program p;
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
std::vector<int64_t> res_gold = {0, 0, 2, 1, 2, 0, 0, 2};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = p.add_literal(migraphx::literal{data_shape, data});
p.add_instruction(migraphx::op::argmax{-2}, dl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold));
}
TEST_CASE(argmax_test_2)
{
migraphx::program p;
......@@ -1393,6 +1412,24 @@ TEST_CASE(argmin_test_2)
EXPECT(migraphx::verify_range(result_vec, res_gold));
}
TEST_CASE(argmin_test_neg_1)
{
migraphx::program p;
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
std::vector<int64_t> res_gold = {2, 1, 0, 3, 3, 2};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = p.add_literal(migraphx::literal{data_shape, data});
p.add_instruction(migraphx::op::argmin{-1}, dl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold));
}
TEST_CASE(conv2d_test)
{
migraphx::program p;
......
......@@ -744,11 +744,15 @@ template struct test_arg_ops<migraphx::op::argmax, 0>;
template struct test_arg_ops<migraphx::op::argmax, 1>;
template struct test_arg_ops<migraphx::op::argmax, 2>;
template struct test_arg_ops<migraphx::op::argmax, 3>;
template struct test_arg_ops<migraphx::op::argmax, -1>;
template struct test_arg_ops<migraphx::op::argmax, -2>;
template struct test_arg_ops<migraphx::op::argmin, 0>;
template struct test_arg_ops<migraphx::op::argmin, 1>;
template struct test_arg_ops<migraphx::op::argmin, 2>;
template struct test_arg_ops<migraphx::op::argmin, 3>;
template struct test_arg_ops<migraphx::op::argmin, -3>;
template struct test_arg_ops<migraphx::op::argmin, -4>;
struct test_conv : verify_program<test_conv>
{
......
......@@ -50,9 +50,8 @@ TEST_CASE(add_scalar_test)
migraphx::program p;
auto l1 = p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1}});
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto m0 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::add{}, m0, m1);
p.add_instruction(migraphx::op::add{}, l0, m1);
auto prog = migraphx::parse_onnx("add_scalar_test.onnx");
EXPECT(p == prog);
......@@ -439,14 +438,15 @@ TEST_CASE(gather_test)
TEST_CASE(gemm_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 7}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {11, 5}});
p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {}});
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 7}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {11, 5}});
auto l2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type});
auto t0 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l0);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto bl2 = p.add_instruction(migraphx::op::multibroadcast{{7, 11}}, l2);
auto alpha = 2.f;
auto beta = 2.0f;
p.add_instruction(migraphx::op::dot{alpha, beta}, t0, t1);
p.add_instruction(migraphx::op::dot{alpha, beta}, t0, t1, bl2);
auto prog = migraphx::parse_onnx("gemm_test.onnx");
EXPECT(p == prog);
......@@ -548,9 +548,8 @@ TEST_CASE(implicit_add_bcast_test)
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4, 1}});
auto l2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::add{}, l2, l3);
p.add_instruction(migraphx::op::add{}, l0, l3);
auto prog = migraphx::parse_onnx("implicit_add_bcast_test.onnx");
......@@ -562,9 +561,8 @@ TEST_CASE(implicit_pow_bcast_test)
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4, 1}});
auto l2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::pow{}, l2, l3);
p.add_instruction(migraphx::op::pow{}, l0, l3);
auto prog = migraphx::parse_onnx("implicit_pow_bcast_test.onnx");
......@@ -576,9 +574,8 @@ TEST_CASE(implicit_sub_bcast_test)
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 5}});
auto l2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::sub{}, l2, l3);
p.add_instruction(migraphx::op::sub{}, l0, l3);
auto prog = migraphx::parse_onnx("implicit_sub_bcast_test.onnx");
......@@ -1028,9 +1025,8 @@ TEST_CASE(sub_scalar_test)
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 =
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}}, {1}});
auto m0 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::sub{}, m0, m1);
p.add_instruction(migraphx::op::sub{}, l0, m1);
auto prog = migraphx::parse_onnx("sub_scalar_test.onnx");
EXPECT(p == prog);
......
......@@ -102,6 +102,7 @@ TEST_CASE(transpose_shape)
migraphx::shape output{migraphx::shape::float_type, {2, 2}, {1, 2}};
expect_shape(input, migraphx::op::transpose{{0, 1}}, input);
expect_shape(output, migraphx::op::transpose{{1, 0}}, input);
expect_shape(output, migraphx::op::transpose{}, input);
throws_shape(migraphx::op::transpose{{1, 2}}, input);
}
......@@ -157,9 +158,15 @@ TEST_CASE(flatten_shape)
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 2 * 4 * 6 * 8}},
migraphx::op::flatten{0},
input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 2 * 4 * 6 * 8}},
migraphx::op::flatten{-4},
input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 4 * 6 * 8}},
migraphx::op::flatten{1},
input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 4 * 6 * 8}},
migraphx::op::flatten{-3},
input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2 * 4, 6 * 8}},
migraphx::op::flatten{2},
input);
......@@ -170,6 +177,7 @@ TEST_CASE(flatten_shape)
migraphx::op::flatten{4},
input);
throws_shape(migraphx::op::flatten{5}, input);
throws_shape(migraphx::op::flatten{-5}, input);
}
TEST_CASE(slice_shape)
......@@ -482,6 +490,21 @@ TEST_CASE(test_argmin)
}
}
TEST_CASE(test_squeeze)
{
{
migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}};
expect_shape(s2, migraphx::op::squeeze{{-2}}, s1);
}
{
migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}};
expect_shape(s2, migraphx::op::unsqueeze{{-2}}, s1);
}
}
template <class T>
void test_reduce_ops()
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment