Commit 9ea01307 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

remove the keep dim attribute from argmax and argmin operators

parent cf984059
......@@ -19,12 +19,11 @@ namespace op {
struct argmax
{
int axis = 0;
int keep_dims = 1;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"), f(self.keep_dims, "keep_dims"));
return pack(f(self.axis, "axis"));
}
std::string name() const { return "argmax"; }
......@@ -40,10 +39,6 @@ struct argmax
}
lens[axis] = 1;
if(keep_dims == 0)
{
lens.erase(lens.begin() + axis);
}
return {shape::int64_type, lens};
}
......
......@@ -19,12 +19,11 @@ namespace op {
struct argmin
{
int axis = 0;
int keep_dims = 1;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"), f(self.keep_dims, "keep_dims"));
return pack(f(self.axis, "axis"));
}
std::string name() const { return "argmin"; }
......@@ -40,10 +39,6 @@ struct argmin
}
lens[axis] = 1;
if(keep_dims == 0)
{
lens.erase(lens.begin() + axis);
}
return {shape::int64_type, lens};
}
......
......@@ -284,7 +284,15 @@ struct onnx_parser
keep_dims = parse_value(attributes.at("keepdims")).at<int>();
}
return prog.add_instruction(op::argmax{axis, keep_dims}, std::move(args));
if (keep_dims == 0)
{
auto ins = prog.add_instruction(op::argmax{axis}, std::move(args));
return prog.add_instruction(op::squeeze{{static_cast<int64_t>(axis)}}, ins);
}
else
{
return prog.add_instruction(op::argmax{axis}, std::move(args));
}
}
instruction_ref parse_argmin(const std::string&,
......@@ -303,7 +311,15 @@ struct onnx_parser
keep_dims = parse_value(attributes.at("keepdims")).at<int>();
}
return prog.add_instruction(op::argmin{axis, keep_dims}, std::move(args));
if (keep_dims == 0)
{
auto ins = prog.add_instruction(op::argmin{axis}, std::move(args));
return prog.add_instruction(op::squeeze{{static_cast<int64_t>(axis)}}, ins);
}
else
{
return prog.add_instruction(op::argmin{axis}, std::move(args));
}
}
instruction_ref
......
......@@ -1135,8 +1135,7 @@ TEST_CASE(logsoftmax_test_axis_3)
EXPECT(migraphx::verify_range(results_vector, s));
}
template <int KeepDims>
void argmax_test_0()
TEST_CASE(argmax_test_0)
{
migraphx::program p;
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
......@@ -1145,7 +1144,7 @@ void argmax_test_0()
std::vector<int64_t> res_gold = {0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = p.add_literal(migraphx::literal{data_shape, data});
p.add_instruction(migraphx::op::argmax{0, KeepDims}, dl);
p.add_instruction(migraphx::op::argmax{0}, dl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int64_t> result_vec;
......@@ -1154,9 +1153,6 @@ void argmax_test_0()
EXPECT(migraphx::verify_range(result_vec, res_gold));
}
TEST_CASE(argmax_test_00) { argmax_test_0<0>(); }
TEST_CASE(argmax_test_01) { argmax_test_0<1>(); }
TEST_CASE(argmax_test_1)
{
migraphx::program p;
......@@ -1166,7 +1162,7 @@ TEST_CASE(argmax_test_1)
std::vector<int64_t> res_gold = {0, 0, 2, 1, 2, 0, 0, 2};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = p.add_literal(migraphx::literal{data_shape, data});
p.add_instruction(migraphx::op::argmax{1, 0}, dl);
p.add_instruction(migraphx::op::argmax{1}, dl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int64_t> result_vec;
......@@ -1184,7 +1180,7 @@ TEST_CASE(argmax_test_2)
std::vector<int64_t> res_gold = {1, 3, 2, 2, 2, 3};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = p.add_literal(migraphx::literal{data_shape, data});
p.add_instruction(migraphx::op::argmax{2, 0}, dl);
p.add_instruction(migraphx::op::argmax{2}, dl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int64_t> result_vec;
......@@ -1193,8 +1189,7 @@ TEST_CASE(argmax_test_2)
EXPECT(migraphx::verify_range(result_vec, res_gold));
}
template <int KeepDims>
void argmin_test_0()
TEST_CASE(argmin_test_0)
{
migraphx::program p;
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
......@@ -1203,7 +1198,7 @@ void argmin_test_0()
std::vector<int64_t> res_gold = {1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = p.add_literal(migraphx::literal{data_shape, data});
p.add_instruction(migraphx::op::argmin{0, KeepDims}, dl);
p.add_instruction(migraphx::op::argmin{0}, dl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int64_t> result_vec;
......@@ -1212,9 +1207,6 @@ void argmin_test_0()
EXPECT(migraphx::verify_range(result_vec, res_gold));
}
TEST_CASE(argmin_test_00) { argmin_test_0<0>(); }
TEST_CASE(argmin_test_01) { argmin_test_0<1>(); }
TEST_CASE(argmin_test_1)
{
migraphx::program p;
......@@ -1224,7 +1216,7 @@ TEST_CASE(argmin_test_1)
std::vector<int64_t> res_gold = {2, 2, 0, 2, 0, 1, 2, 0};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = p.add_literal(migraphx::literal{data_shape, data});
p.add_instruction(migraphx::op::argmin{1, 0}, dl);
p.add_instruction(migraphx::op::argmin{1}, dl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int64_t> result_vec;
......@@ -1242,7 +1234,7 @@ TEST_CASE(argmin_test_2)
std::vector<int64_t> res_gold = {2, 1, 0, 3, 3, 2};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = p.add_literal(migraphx::literal{data_shape, data});
p.add_instruction(migraphx::op::argmin{2, 0}, dl);
p.add_instruction(migraphx::op::argmin{2}, dl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int64_t> result_vec;
......
......@@ -611,31 +611,29 @@ template struct test_softmax<1>;
template struct test_softmax<2>;
template struct test_softmax<3>;
template <class T, int Axis, int KeepDims>
struct test_arg_ops : verify_program<test_arg_ops<T, Axis, KeepDims>>
template <class T, int Axis>
struct test_arg_ops : verify_program<test_arg_ops<T, Axis>>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 1025}};
auto param = p.add_parameter("data", s);
p.add_instruction(T{Axis, KeepDims}, param);
p.add_instruction(T{Axis}, param);
return p;
}
};
template struct test_arg_ops<migraphx::op::argmax, 0, 0>;
template struct test_arg_ops<migraphx::op::argmax, 0, 1>;
template struct test_arg_ops<migraphx::op::argmax, 1, 0>;
template struct test_arg_ops<migraphx::op::argmax, 2, 1>;
template struct test_arg_ops<migraphx::op::argmax, 3, 0>;
template struct test_arg_ops<migraphx::op::argmax, 0>;
template struct test_arg_ops<migraphx::op::argmax, 1>;
template struct test_arg_ops<migraphx::op::argmax, 2>;
template struct test_arg_ops<migraphx::op::argmax, 3>;
template struct test_arg_ops<migraphx::op::argmin, 0, 0>;
template struct test_arg_ops<migraphx::op::argmin, 0, 1>;
template struct test_arg_ops<migraphx::op::argmin, 1, 1>;
template struct test_arg_ops<migraphx::op::argmin, 2, 0>;
template struct test_arg_ops<migraphx::op::argmin, 3, 1>;
template struct test_arg_ops<migraphx::op::argmin, 0>;
template struct test_arg_ops<migraphx::op::argmin, 1>;
template struct test_arg_ops<migraphx::op::argmin, 2>;
template struct test_arg_ops<migraphx::op::argmin, 3>;
struct test_conv : verify_program<test_conv>
{
......
......@@ -788,7 +788,8 @@ TEST_CASE(argmax)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
p.add_instruction(migraphx::op::argmax{2, 0}, l0);
auto ins = p.add_instruction(migraphx::op::argmax{2}, l0);
p.add_instruction(migraphx::op::squeeze{{2}}, ins);
auto prog = migraphx::parse_onnx("argmax_test.onnx");
EXPECT(p == prog);
......@@ -798,7 +799,8 @@ TEST_CASE(argmin)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
p.add_instruction(migraphx::op::argmin{3, 0}, l0);
auto ins = p.add_instruction(migraphx::op::argmin{3}, l0);
p.add_instruction(migraphx::op::squeeze{{3}}, ins);
auto prog = migraphx::parse_onnx("argmin_test.onnx");
EXPECT(p == prog);
......
......@@ -385,47 +385,27 @@ void test_argop_var()
{
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}}, T{0, 1}, input);
expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}}, T{0}, input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}}, T{1, 1}, input);
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}}, T{1}, input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}}, T{2, 1}, input);
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}}, T{2}, input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}}, T{3, 1}, input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {3, 4, 5}}, T{0, 0}, input);
}
{
migraphx::shape input{migraphx::shape::int64_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 4, 5}}, T{1, 0}, input);
}
{
migraphx::shape input{migraphx::shape::int64_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 5}}, T{2, 0}, input);
}
{
migraphx::shape input{migraphx::shape::int64_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4}}, T{3, 0}, input);
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}}, T{3}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(T{4, 1}, input);
throws_shape(T{4}, input);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment