Commit 9ea01307 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

remove the keep dim attribute from argmax and argmin operators

parent cf984059
...@@ -19,12 +19,11 @@ namespace op { ...@@ -19,12 +19,11 @@ namespace op {
struct argmax struct argmax
{ {
int axis = 0; int axis = 0;
int keep_dims = 1;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.axis, "axis"), f(self.keep_dims, "keep_dims")); return pack(f(self.axis, "axis"));
} }
std::string name() const { return "argmax"; } std::string name() const { return "argmax"; }
...@@ -40,10 +39,6 @@ struct argmax ...@@ -40,10 +39,6 @@ struct argmax
} }
lens[axis] = 1; lens[axis] = 1;
if(keep_dims == 0)
{
lens.erase(lens.begin() + axis);
}
return {shape::int64_type, lens}; return {shape::int64_type, lens};
} }
......
...@@ -19,12 +19,11 @@ namespace op { ...@@ -19,12 +19,11 @@ namespace op {
struct argmin struct argmin
{ {
int axis = 0; int axis = 0;
int keep_dims = 1;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.axis, "axis"), f(self.keep_dims, "keep_dims")); return pack(f(self.axis, "axis"));
} }
std::string name() const { return "argmin"; } std::string name() const { return "argmin"; }
...@@ -40,10 +39,6 @@ struct argmin ...@@ -40,10 +39,6 @@ struct argmin
} }
lens[axis] = 1; lens[axis] = 1;
if(keep_dims == 0)
{
lens.erase(lens.begin() + axis);
}
return {shape::int64_type, lens}; return {shape::int64_type, lens};
} }
......
...@@ -284,7 +284,15 @@ struct onnx_parser ...@@ -284,7 +284,15 @@ struct onnx_parser
keep_dims = parse_value(attributes.at("keepdims")).at<int>(); keep_dims = parse_value(attributes.at("keepdims")).at<int>();
} }
return prog.add_instruction(op::argmax{axis, keep_dims}, std::move(args)); if (keep_dims == 0)
{
auto ins = prog.add_instruction(op::argmax{axis}, std::move(args));
return prog.add_instruction(op::squeeze{{static_cast<int64_t>(axis)}}, ins);
}
else
{
return prog.add_instruction(op::argmax{axis}, std::move(args));
}
} }
instruction_ref parse_argmin(const std::string&, instruction_ref parse_argmin(const std::string&,
...@@ -303,7 +311,15 @@ struct onnx_parser ...@@ -303,7 +311,15 @@ struct onnx_parser
keep_dims = parse_value(attributes.at("keepdims")).at<int>(); keep_dims = parse_value(attributes.at("keepdims")).at<int>();
} }
return prog.add_instruction(op::argmin{axis, keep_dims}, std::move(args)); if (keep_dims == 0)
{
auto ins = prog.add_instruction(op::argmin{axis}, std::move(args));
return prog.add_instruction(op::squeeze{{static_cast<int64_t>(axis)}}, ins);
}
else
{
return prog.add_instruction(op::argmin{axis}, std::move(args));
}
} }
instruction_ref instruction_ref
......
...@@ -1135,8 +1135,7 @@ TEST_CASE(logsoftmax_test_axis_3) ...@@ -1135,8 +1135,7 @@ TEST_CASE(logsoftmax_test_axis_3)
EXPECT(migraphx::verify_range(results_vector, s)); EXPECT(migraphx::verify_range(results_vector, s));
} }
template <int KeepDims> TEST_CASE(argmax_test_0)
void argmax_test_0()
{ {
migraphx::program p; migraphx::program p;
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758, std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
...@@ -1145,7 +1144,7 @@ void argmax_test_0() ...@@ -1145,7 +1144,7 @@ void argmax_test_0()
std::vector<int64_t> res_gold = {0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1}; std::vector<int64_t> res_gold = {0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}}; migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = p.add_literal(migraphx::literal{data_shape, data}); auto dl = p.add_literal(migraphx::literal{data_shape, data});
p.add_instruction(migraphx::op::argmax{0, KeepDims}, dl); p.add_instruction(migraphx::op::argmax{0}, dl);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<int64_t> result_vec; std::vector<int64_t> result_vec;
...@@ -1154,9 +1153,6 @@ void argmax_test_0() ...@@ -1154,9 +1153,6 @@ void argmax_test_0()
EXPECT(migraphx::verify_range(result_vec, res_gold)); EXPECT(migraphx::verify_range(result_vec, res_gold));
} }
TEST_CASE(argmax_test_00) { argmax_test_0<0>(); }
TEST_CASE(argmax_test_01) { argmax_test_0<1>(); }
TEST_CASE(argmax_test_1) TEST_CASE(argmax_test_1)
{ {
migraphx::program p; migraphx::program p;
...@@ -1166,7 +1162,7 @@ TEST_CASE(argmax_test_1) ...@@ -1166,7 +1162,7 @@ TEST_CASE(argmax_test_1)
std::vector<int64_t> res_gold = {0, 0, 2, 1, 2, 0, 0, 2}; std::vector<int64_t> res_gold = {0, 0, 2, 1, 2, 0, 0, 2};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}}; migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = p.add_literal(migraphx::literal{data_shape, data}); auto dl = p.add_literal(migraphx::literal{data_shape, data});
p.add_instruction(migraphx::op::argmax{1, 0}, dl); p.add_instruction(migraphx::op::argmax{1}, dl);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<int64_t> result_vec; std::vector<int64_t> result_vec;
...@@ -1184,7 +1180,7 @@ TEST_CASE(argmax_test_2) ...@@ -1184,7 +1180,7 @@ TEST_CASE(argmax_test_2)
std::vector<int64_t> res_gold = {1, 3, 2, 2, 2, 3}; std::vector<int64_t> res_gold = {1, 3, 2, 2, 2, 3};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}}; migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = p.add_literal(migraphx::literal{data_shape, data}); auto dl = p.add_literal(migraphx::literal{data_shape, data});
p.add_instruction(migraphx::op::argmax{2, 0}, dl); p.add_instruction(migraphx::op::argmax{2}, dl);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<int64_t> result_vec; std::vector<int64_t> result_vec;
...@@ -1193,8 +1189,7 @@ TEST_CASE(argmax_test_2) ...@@ -1193,8 +1189,7 @@ TEST_CASE(argmax_test_2)
EXPECT(migraphx::verify_range(result_vec, res_gold)); EXPECT(migraphx::verify_range(result_vec, res_gold));
} }
template <int KeepDims> TEST_CASE(argmin_test_0)
void argmin_test_0()
{ {
migraphx::program p; migraphx::program p;
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758, std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
...@@ -1203,7 +1198,7 @@ void argmin_test_0() ...@@ -1203,7 +1198,7 @@ void argmin_test_0()
std::vector<int64_t> res_gold = {1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0}; std::vector<int64_t> res_gold = {1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}}; migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = p.add_literal(migraphx::literal{data_shape, data}); auto dl = p.add_literal(migraphx::literal{data_shape, data});
p.add_instruction(migraphx::op::argmin{0, KeepDims}, dl); p.add_instruction(migraphx::op::argmin{0}, dl);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<int64_t> result_vec; std::vector<int64_t> result_vec;
...@@ -1212,9 +1207,6 @@ void argmin_test_0() ...@@ -1212,9 +1207,6 @@ void argmin_test_0()
EXPECT(migraphx::verify_range(result_vec, res_gold)); EXPECT(migraphx::verify_range(result_vec, res_gold));
} }
TEST_CASE(argmin_test_00) { argmin_test_0<0>(); }
TEST_CASE(argmin_test_01) { argmin_test_0<1>(); }
TEST_CASE(argmin_test_1) TEST_CASE(argmin_test_1)
{ {
migraphx::program p; migraphx::program p;
...@@ -1224,7 +1216,7 @@ TEST_CASE(argmin_test_1) ...@@ -1224,7 +1216,7 @@ TEST_CASE(argmin_test_1)
std::vector<int64_t> res_gold = {2, 2, 0, 2, 0, 1, 2, 0}; std::vector<int64_t> res_gold = {2, 2, 0, 2, 0, 1, 2, 0};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}}; migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = p.add_literal(migraphx::literal{data_shape, data}); auto dl = p.add_literal(migraphx::literal{data_shape, data});
p.add_instruction(migraphx::op::argmin{1, 0}, dl); p.add_instruction(migraphx::op::argmin{1}, dl);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<int64_t> result_vec; std::vector<int64_t> result_vec;
...@@ -1242,7 +1234,7 @@ TEST_CASE(argmin_test_2) ...@@ -1242,7 +1234,7 @@ TEST_CASE(argmin_test_2)
std::vector<int64_t> res_gold = {2, 1, 0, 3, 3, 2}; std::vector<int64_t> res_gold = {2, 1, 0, 3, 3, 2};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}}; migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4}};
auto dl = p.add_literal(migraphx::literal{data_shape, data}); auto dl = p.add_literal(migraphx::literal{data_shape, data});
p.add_instruction(migraphx::op::argmin{2, 0}, dl); p.add_instruction(migraphx::op::argmin{2}, dl);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<int64_t> result_vec; std::vector<int64_t> result_vec;
......
...@@ -611,31 +611,29 @@ template struct test_softmax<1>; ...@@ -611,31 +611,29 @@ template struct test_softmax<1>;
template struct test_softmax<2>; template struct test_softmax<2>;
template struct test_softmax<3>; template struct test_softmax<3>;
template <class T, int Axis, int KeepDims> template <class T, int Axis>
struct test_arg_ops : verify_program<test_arg_ops<T, Axis, KeepDims>> struct test_arg_ops : verify_program<test_arg_ops<T, Axis>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 1025}}; migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 1025}};
auto param = p.add_parameter("data", s); auto param = p.add_parameter("data", s);
p.add_instruction(T{Axis, KeepDims}, param); p.add_instruction(T{Axis}, param);
return p; return p;
} }
}; };
template struct test_arg_ops<migraphx::op::argmax, 0, 0>; template struct test_arg_ops<migraphx::op::argmax, 0>;
template struct test_arg_ops<migraphx::op::argmax, 0, 1>; template struct test_arg_ops<migraphx::op::argmax, 1>;
template struct test_arg_ops<migraphx::op::argmax, 1, 0>; template struct test_arg_ops<migraphx::op::argmax, 2>;
template struct test_arg_ops<migraphx::op::argmax, 2, 1>; template struct test_arg_ops<migraphx::op::argmax, 3>;
template struct test_arg_ops<migraphx::op::argmax, 3, 0>;
template struct test_arg_ops<migraphx::op::argmin, 0, 0>; template struct test_arg_ops<migraphx::op::argmin, 0>;
template struct test_arg_ops<migraphx::op::argmin, 0, 1>; template struct test_arg_ops<migraphx::op::argmin, 1>;
template struct test_arg_ops<migraphx::op::argmin, 1, 1>; template struct test_arg_ops<migraphx::op::argmin, 2>;
template struct test_arg_ops<migraphx::op::argmin, 2, 0>; template struct test_arg_ops<migraphx::op::argmin, 3>;
template struct test_arg_ops<migraphx::op::argmin, 3, 1>;
struct test_conv : verify_program<test_conv> struct test_conv : verify_program<test_conv>
{ {
......
...@@ -788,7 +788,8 @@ TEST_CASE(argmax) ...@@ -788,7 +788,8 @@ TEST_CASE(argmax)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
p.add_instruction(migraphx::op::argmax{2, 0}, l0); auto ins = p.add_instruction(migraphx::op::argmax{2}, l0);
p.add_instruction(migraphx::op::squeeze{{2}}, ins);
auto prog = migraphx::parse_onnx("argmax_test.onnx"); auto prog = migraphx::parse_onnx("argmax_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -798,7 +799,8 @@ TEST_CASE(argmin) ...@@ -798,7 +799,8 @@ TEST_CASE(argmin)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
p.add_instruction(migraphx::op::argmin{3, 0}, l0); auto ins = p.add_instruction(migraphx::op::argmin{3}, l0);
p.add_instruction(migraphx::op::squeeze{{3}}, ins);
auto prog = migraphx::parse_onnx("argmin_test.onnx"); auto prog = migraphx::parse_onnx("argmin_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
......
...@@ -385,47 +385,27 @@ void test_argop_var() ...@@ -385,47 +385,27 @@ void test_argop_var()
{ {
{ {
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}}, T{0, 1}, input); expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}}, T{0}, input);
} }
{ {
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}}, T{1, 1}, input); expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}}, T{1}, input);
} }
{ {
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}}, T{2, 1}, input); expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}}, T{2}, input);
} }
{ {
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}}, T{3, 1}, input); expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}}, T{3}, input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {3, 4, 5}}, T{0, 0}, input);
}
{
migraphx::shape input{migraphx::shape::int64_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 4, 5}}, T{1, 0}, input);
}
{
migraphx::shape input{migraphx::shape::int64_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 5}}, T{2, 0}, input);
}
{
migraphx::shape input{migraphx::shape::int64_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4}}, T{3, 0}, input);
} }
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(T{4, 1}, input); throws_shape(T{4}, input);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment