"vscode:/vscode.git/clone" did not exist on "acd7d03266e1b2b1df07c608ba225eb46a57d4cf"
Unverified Commit 6ae4227a authored by Zakor Gyula's avatar Zakor Gyula Committed by GitHub
Browse files

Add support for select_last_index attribute for ArgMax & ArgMin (#2235)

parent f47e0b5b
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp> #include <migraphx/dyn_output.hpp>
#include <migraphx/float_equal.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -39,11 +40,12 @@ namespace op { ...@@ -39,11 +40,12 @@ namespace op {
struct argmax struct argmax
{ {
int64_t axis = 0; int64_t axis = 0;
bool select_last_index = false;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.axis, "axis")); return pack(f(self.axis, "axis"), f(self.select_last_index, "select_last_index"));
} }
value attributes() const value attributes() const
...@@ -87,6 +89,10 @@ struct argmax ...@@ -87,6 +89,10 @@ struct argmax
max_val = cur_val; max_val = cur_val;
max_index = i; max_index = i;
} }
else if(select_last_index and float_equal(max_val, cur_val))
{
max_index = i;
}
} }
return max_index; return max_index;
} }
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/float_equal.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -38,11 +39,12 @@ namespace op { ...@@ -38,11 +39,12 @@ namespace op {
struct argmin struct argmin
{ {
int64_t axis = 0; int64_t axis = 0;
bool select_last_index = false;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.axis, "axis")); return pack(f(self.axis, "axis"), f(self.select_last_index, "select_last_index"));
} }
value attributes() const value attributes() const
...@@ -78,6 +80,10 @@ struct argmin ...@@ -78,6 +80,10 @@ struct argmin
min_val = cur_val; min_val = cur_val;
min_index = i; min_index = i;
} }
else if(select_last_index and float_equal(min_val, cur_val))
{
min_index = i;
}
} }
return min_index; return min_index;
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -50,14 +50,25 @@ struct parse_arg_op : op_parser<parse_arg_op> ...@@ -50,14 +50,25 @@ struct parse_arg_op : op_parser<parse_arg_op>
keep_dims = parser.parse_value(info.attributes.at("keepdims")).at<int>(); keep_dims = parser.parse_value(info.attributes.at("keepdims")).at<int>();
} }
bool select_last_index = false;
if(contains(info.attributes, "select_last_index"))
{
select_last_index = static_cast<bool>(
parser.parse_value(info.attributes.at("select_last_index")).at<int>());
}
if(keep_dims == 0) if(keep_dims == 0)
{ {
auto ins = info.add_instruction(make_op(opd.op_name, {{"axis", axis}}), args); auto ins = info.add_instruction(
make_op(opd.op_name, {{"axis", axis}, {"select_last_index", select_last_index}}),
args);
return info.add_instruction(make_op("squeeze", {{"axes", {axis}}}), ins); return info.add_instruction(make_op("squeeze", {{"axes", {axis}}}), ins);
} }
else else
{ {
return info.add_instruction(make_op(opd.op_name, {{"axis", axis}}), args); return info.add_instruction(
make_op(opd.op_name, {{"axis", axis}, {"select_last_index", select_last_index}}),
args);
} }
} }
}; };
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -40,7 +40,8 @@ argument hip_argmax::compute(context& ctx, const shape&, const std::vector<argum ...@@ -40,7 +40,8 @@ argument hip_argmax::compute(context& ctx, const shape&, const std::vector<argum
{ {
auto n_dim = args.front().get_shape().lens().size(); auto n_dim = args.front().get_shape().lens().size();
int64_t tuned_axis = tune_axis(n_dim, op.axis, op.name()); int64_t tuned_axis = tune_axis(n_dim, op.axis, op.name());
device::argmax(ctx.get_stream().get(), args.back(), args.front(), tuned_axis); device::argmax(
ctx.get_stream().get(), args.back(), args.front(), tuned_axis, op.select_last_index);
return args.back(); return args.back();
} }
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -40,7 +40,8 @@ argument hip_argmin::compute(context& ctx, const shape&, const std::vector<argum ...@@ -40,7 +40,8 @@ argument hip_argmin::compute(context& ctx, const shape&, const std::vector<argum
{ {
auto n_dim = args.front().get_shape().lens().size(); auto n_dim = args.front().get_shape().lens().size();
int64_t tuned_axis = tune_axis(n_dim, op.axis, op.name()); int64_t tuned_axis = tune_axis(n_dim, op.axis, op.name());
device::argmin(ctx.get_stream().get(), args.back(), args.front(), tuned_axis); device::argmin(
ctx.get_stream().get(), args.back(), args.front(), tuned_axis, op.select_last_index);
return args.back(); return args.back();
} }
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -34,9 +34,16 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -34,9 +34,16 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void argmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis) void argmax(hipStream_t stream,
const argument& result,
const argument& arg,
int64_t axis,
bool select_last_index)
{ {
arg_op(argmax_op{}, stream, result, arg, axis); if(select_last_index)
arg_op(argmax_op_last_index{}, stream, result, arg, axis);
else
arg_op(argmax_op_first_index{}, stream, result, arg, axis);
} }
} // namespace device } // namespace device
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -34,9 +34,16 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -34,9 +34,16 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void argmin(hipStream_t stream, const argument& result, const argument& arg, int64_t axis) void argmin(hipStream_t stream,
const argument& result,
const argument& arg,
int64_t axis,
bool select_last_index)
{ {
arg_op(argmin_op{}, stream, result, arg, axis); if(select_last_index)
arg_op(argmin_op_last_index{}, stream, result, arg, axis);
else
arg_op(argmin_op_first_index{}, stream, result, arg, axis);
} }
} // namespace device } // namespace device
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -55,7 +55,7 @@ MIGRAPHX_DEVICE_CONSTEXPR val_index<T> make_val_index(T v, int64_t i) ...@@ -55,7 +55,7 @@ MIGRAPHX_DEVICE_CONSTEXPR val_index<T> make_val_index(T v, int64_t i)
return {v, i}; return {v, i};
} }
struct argmax_op struct argmax_op_first_index
{ {
template <class T> template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
...@@ -73,7 +73,25 @@ struct argmax_op ...@@ -73,7 +73,25 @@ struct argmax_op
MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return lowest(); } MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return lowest(); }
}; };
struct argmin_op struct argmax_op_last_index
{
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
{
if(x.val > y.val)
return x;
else if(x.val < y.val)
return y;
else
{
return (x.index > y.index) ? x : y;
}
}
MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return lowest(); }
};
struct argmin_op_first_index
{ {
template <class T> template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
...@@ -91,6 +109,24 @@ struct argmin_op ...@@ -91,6 +109,24 @@ struct argmin_op
MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return highest(); } MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return highest(); }
}; };
struct argmin_op_last_index
{
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
{
if(x.val < y.val)
return x;
else if(x.val > y.val)
return y;
else
{
return (x.index > y.index) ? x : y;
}
}
MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return highest(); }
};
template <class Op> template <class Op>
void arg_op(Op op, hipStream_t stream, const argument& result, const argument& arg, int64_t axis) void arg_op(Op op, hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{ {
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -36,7 +36,8 @@ namespace device { ...@@ -36,7 +36,8 @@ namespace device {
void MIGRAPHX_DEVICE_EXPORT argmax(hipStream_t stream, void MIGRAPHX_DEVICE_EXPORT argmax(hipStream_t stream,
const argument& result, const argument& result,
const argument& arg, const argument& arg,
int64_t axis); int64_t axis,
bool select_last_index);
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -36,7 +36,8 @@ namespace device { ...@@ -36,7 +36,8 @@ namespace device {
void MIGRAPHX_DEVICE_EXPORT argmin(hipStream_t stream, void MIGRAPHX_DEVICE_EXPORT argmin(hipStream_t stream,
const argument& result, const argument& result,
const argument& arg, const argument& arg,
int64_t axis); int64_t axis,
bool select_last_index);
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
...@@ -149,6 +149,21 @@ def argmax_test(): ...@@ -149,6 +149,21 @@ def argmax_test():
return ([node], [x], [y]) return ([node], [x], [y])
@onnx_test()
def argmax_select_last_index_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 6])
node = onnx.helper.make_node('ArgMax',
inputs=['x'],
outputs=['y'],
axis=2,
keepdims=0,
select_last_index=1)
return ([node], [x], [y])
@onnx_test() @onnx_test()
def argmax_dyn_test(): def argmax_dyn_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [None, 4, 5, 6]) x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [None, 4, 5, 6])
...@@ -177,6 +192,21 @@ def argmin_test(): ...@@ -177,6 +192,21 @@ def argmin_test():
return ([node], [x], [y]) return ([node], [x], [y])
@onnx_test()
def argmin_select_last_index_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 5])
node = onnx.helper.make_node('ArgMin',
inputs=['x'],
outputs=['y'],
axis=3,
keepdims=0,
select_last_index=1)
return ([node], [x], [y])
@onnx_test() @onnx_test()
def asin_test(): def asin_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10]) x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10])
......
...@@ -184,6 +184,19 @@ TEST_CASE(argmax_test) ...@@ -184,6 +184,19 @@ TEST_CASE(argmax_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(argmax_select_last_index_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto ins = mm->add_instruction(
migraphx::make_op("argmax", {{"axis", 2}, {"select_last_index", true}}), l0);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), ins);
auto prog = optimize_onnx("argmax_select_last_index_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(argmax_dyn_test) TEST_CASE(argmax_dyn_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -213,6 +226,19 @@ TEST_CASE(argmin_test) ...@@ -213,6 +226,19 @@ TEST_CASE(argmin_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(argmin_select_last_index_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto ins = mm->add_instruction(
migraphx::make_op("argmin", {{"axis", 3}, {"select_last_index", true}}), l0);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {3}}}), ins);
auto prog = optimize_onnx("argmin_select_last_index_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(asin_test) TEST_CASE(asin_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -66,16 +66,6 @@ class MIGraphXBackendTest(onnx.backend.test.BackendTest): ...@@ -66,16 +66,6 @@ class MIGraphXBackendTest(onnx.backend.test.BackendTest):
def disabled_tests_onnx_1_7_0(backend_test): def disabled_tests_onnx_1_7_0(backend_test):
# fails # fails
# from OnnxBackendNodeModelTest # from OnnxBackendNodeModelTest
backend_test.exclude(r'test_argmax_keepdims_example_select_last_index_cpu')
backend_test.exclude(
r'test_argmax_negative_axis_keepdims_example_select_last_index_cpu')
backend_test.exclude(
r'test_argmax_no_keepdims_example_select_last_index_cpu')
backend_test.exclude(r'test_argmin_keepdims_example_select_last_index_cpu')
backend_test.exclude(
r'test_argmin_negative_axis_keepdims_example_select_last_index_cpu')
backend_test.exclude(
r'test_argmin_no_keepdims_example_select_last_index_cpu')
backend_test.exclude(r'test_logsoftmax_axis_0_cpu') backend_test.exclude(r'test_logsoftmax_axis_0_cpu')
backend_test.exclude(r'test_logsoftmax_axis_1_cpu') backend_test.exclude(r'test_logsoftmax_axis_1_cpu')
backend_test.exclude(r'test_logsoftmax_default_axis_cpu') backend_test.exclude(r'test_logsoftmax_default_axis_cpu')
......
...@@ -147,3 +147,37 @@ TEST_CASE(argmax_test_nonstd_shape) ...@@ -147,3 +147,37 @@ TEST_CASE(argmax_test_nonstd_shape)
res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); }); res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(result_vec, res_gold_vec)); EXPECT(migraphx::verify::verify_rms_range(result_vec, res_gold_vec));
} }
TEST_CASE(argmax_test_select_last_index_0)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data = {2.0305, -1.853, 2.0305, -1.5706, 0.7545, 0.7545};
std::vector<int64_t> res_gold = {2, 2};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3}};
auto dl = mm->add_literal(migraphx::literal{data_shape, data});
mm->add_instruction(migraphx::make_op("argmax", {{"axis", 1}, {"select_last_index", true}}),
dl);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(result_vec, res_gold));
}
TEST_CASE(argmax_test_select_last_index_1)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data = {2.0305, -1.853, 2.0305, -1.5706, 0.7545, 0.7545};
std::vector<int64_t> res_gold = {0, 1};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3}};
auto dl = mm->add_literal(migraphx::literal{data_shape, data});
mm->add_instruction(migraphx::make_op("argmax", {{"axis", 1}, {"select_last_index", false}}),
dl);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(result_vec, res_gold));
}
...@@ -125,3 +125,37 @@ TEST_CASE(argmin_test_nonstd_shape) ...@@ -125,3 +125,37 @@ TEST_CASE(argmin_test_nonstd_shape)
res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); }); res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(result_vec, res_gold_vec)); EXPECT(migraphx::verify::verify_rms_range(result_vec, res_gold_vec));
} }
TEST_CASE(argmin_test_select_last_index_0)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data = {-2.0305, 0.853, -2.0305, 1.5706, 0.7545, 0.7545};
std::vector<int64_t> res_gold = {2, 2};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3}};
auto dl = mm->add_literal(migraphx::literal{data_shape, data});
mm->add_instruction(migraphx::make_op("argmin", {{"axis", 1}, {"select_last_index", true}}),
dl);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(result_vec, res_gold));
}
TEST_CASE(argmin_test_select_last_index_1)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data = {-2.0305, 0.853, -2.0305, 1.5706, 0.7545, 0.7545};
std::vector<int64_t> res_gold = {0, 1};
migraphx::shape data_shape{migraphx::shape::float_type, {2, 3}};
auto dl = mm->add_literal(migraphx::literal{data_shape, data});
mm->add_instruction(migraphx::make_op("argmin", {{"axis", 1}, {"select_last_index", false}}),
dl);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(result_vec, res_gold));
}
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -29,8 +29,8 @@ ...@@ -29,8 +29,8 @@
#include <migraphx/op/argmax.hpp> #include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp> #include <migraphx/op/argmin.hpp>
template <class T, int Axis, int NonStdShape> template <class T, int Axis, bool LastIndex, int NonStdShape>
struct test_arg_ops : verify_program<test_arg_ops<T, Axis, NonStdShape>> struct test_arg_ops : verify_program<test_arg_ops<T, Axis, LastIndex, NonStdShape>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -54,63 +54,111 @@ struct test_arg_ops : verify_program<test_arg_ops<T, Axis, NonStdShape>> ...@@ -54,63 +54,111 @@ struct test_arg_ops : verify_program<test_arg_ops<T, Axis, NonStdShape>>
break; break;
default: break; default: break;
} }
mm->add_instruction(T{Axis}, param); mm->add_instruction(T{Axis, LastIndex}, param);
return p; return p;
} }
}; };
// transpose argmax tests // transpose argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0, 0>; template struct test_arg_ops<migraphx::op::argmax, 0, true, 0>;
template struct test_arg_ops<migraphx::op::argmax, 1, 0>; template struct test_arg_ops<migraphx::op::argmax, 0, false, 0>;
template struct test_arg_ops<migraphx::op::argmax, 2, 0>; template struct test_arg_ops<migraphx::op::argmax, 1, true, 0>;
template struct test_arg_ops<migraphx::op::argmax, 3, 0>; template struct test_arg_ops<migraphx::op::argmax, 1, false, 0>;
template struct test_arg_ops<migraphx::op::argmax, -1, 0>; template struct test_arg_ops<migraphx::op::argmax, 2, true, 0>;
template struct test_arg_ops<migraphx::op::argmax, -2, 0>; template struct test_arg_ops<migraphx::op::argmax, 2, false, 0>;
template struct test_arg_ops<migraphx::op::argmax, 3, true, 0>;
template struct test_arg_ops<migraphx::op::argmax, 3, false, 0>;
template struct test_arg_ops<migraphx::op::argmax, -1, true, 0>;
template struct test_arg_ops<migraphx::op::argmax, -1, false, 0>;
template struct test_arg_ops<migraphx::op::argmax, -2, true, 0>;
template struct test_arg_ops<migraphx::op::argmax, -2, false, 0>;
// transpose argmin tests // transpose argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0, 0>; template struct test_arg_ops<migraphx::op::argmin, 0, true, 0>;
template struct test_arg_ops<migraphx::op::argmin, 1, 0>; template struct test_arg_ops<migraphx::op::argmin, 0, false, 0>;
template struct test_arg_ops<migraphx::op::argmin, 2, 0>; template struct test_arg_ops<migraphx::op::argmin, 1, true, 0>;
template struct test_arg_ops<migraphx::op::argmin, 3, 0>; template struct test_arg_ops<migraphx::op::argmin, 1, false, 0>;
template struct test_arg_ops<migraphx::op::argmin, -3, 0>; template struct test_arg_ops<migraphx::op::argmin, 2, true, 0>;
template struct test_arg_ops<migraphx::op::argmin, -4, 0>; template struct test_arg_ops<migraphx::op::argmin, 2, false, 0>;
template struct test_arg_ops<migraphx::op::argmin, 3, true, 0>;
template struct test_arg_ops<migraphx::op::argmin, 3, false, 0>;
template struct test_arg_ops<migraphx::op::argmin, -3, true, 0>;
template struct test_arg_ops<migraphx::op::argmin, -3, false, 0>;
template struct test_arg_ops<migraphx::op::argmin, -4, true, 0>;
template struct test_arg_ops<migraphx::op::argmin, -4, false, 0>;
// broadcast argmax tests // broadcast argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0, 1>; template struct test_arg_ops<migraphx::op::argmax, 0, true, 1>;
template struct test_arg_ops<migraphx::op::argmax, 1, 1>; template struct test_arg_ops<migraphx::op::argmax, 0, false, 1>;
template struct test_arg_ops<migraphx::op::argmax, 2, 1>; template struct test_arg_ops<migraphx::op::argmax, 1, true, 1>;
template struct test_arg_ops<migraphx::op::argmax, 3, 1>; template struct test_arg_ops<migraphx::op::argmax, 1, false, 1>;
template struct test_arg_ops<migraphx::op::argmax, -1, 1>; template struct test_arg_ops<migraphx::op::argmax, 2, true, 1>;
template struct test_arg_ops<migraphx::op::argmax, -2, 1>; template struct test_arg_ops<migraphx::op::argmax, 2, false, 1>;
template struct test_arg_ops<migraphx::op::argmax, 3, true, 1>;
template struct test_arg_ops<migraphx::op::argmax, 3, false, 1>;
template struct test_arg_ops<migraphx::op::argmax, -1, true, 1>;
template struct test_arg_ops<migraphx::op::argmax, -1, false, 1>;
template struct test_arg_ops<migraphx::op::argmax, -2, true, 1>;
template struct test_arg_ops<migraphx::op::argmax, -2, false, 1>;
// broadcast argmin tests // broadcast argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0, 1>; template struct test_arg_ops<migraphx::op::argmin, 0, true, 1>;
template struct test_arg_ops<migraphx::op::argmin, 1, 1>; template struct test_arg_ops<migraphx::op::argmin, 0, false, 1>;
template struct test_arg_ops<migraphx::op::argmin, 2, 1>; template struct test_arg_ops<migraphx::op::argmin, 1, true, 1>;
template struct test_arg_ops<migraphx::op::argmin, 3, 1>; template struct test_arg_ops<migraphx::op::argmin, 1, false, 1>;
template struct test_arg_ops<migraphx::op::argmin, -3, 1>; template struct test_arg_ops<migraphx::op::argmin, 2, true, 1>;
template struct test_arg_ops<migraphx::op::argmin, -4, 1>; template struct test_arg_ops<migraphx::op::argmin, 2, false, 1>;
template struct test_arg_ops<migraphx::op::argmin, 3, true, 1>;
template struct test_arg_ops<migraphx::op::argmin, 3, false, 1>;
template struct test_arg_ops<migraphx::op::argmin, -3, true, 1>;
template struct test_arg_ops<migraphx::op::argmin, -3, false, 1>;
template struct test_arg_ops<migraphx::op::argmin, -4, true, 1>;
template struct test_arg_ops<migraphx::op::argmin, -4, false, 1>;
// slice argmax tests // slice argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0, 2>; template struct test_arg_ops<migraphx::op::argmax, 0, true, 2>;
template struct test_arg_ops<migraphx::op::argmax, 1, 2>; template struct test_arg_ops<migraphx::op::argmax, 0, false, 2>;
template struct test_arg_ops<migraphx::op::argmax, 2, 2>; template struct test_arg_ops<migraphx::op::argmax, 1, true, 2>;
template struct test_arg_ops<migraphx::op::argmax, 3, 2>; template struct test_arg_ops<migraphx::op::argmax, 1, false, 2>;
template struct test_arg_ops<migraphx::op::argmax, -1, 2>; template struct test_arg_ops<migraphx::op::argmax, 2, true, 2>;
template struct test_arg_ops<migraphx::op::argmax, -2, 2>; template struct test_arg_ops<migraphx::op::argmax, 2, false, 2>;
template struct test_arg_ops<migraphx::op::argmax, 3, true, 2>;
template struct test_arg_ops<migraphx::op::argmax, 3, false, 2>;
template struct test_arg_ops<migraphx::op::argmax, -1, true, 2>;
template struct test_arg_ops<migraphx::op::argmax, -1, false, 2>;
template struct test_arg_ops<migraphx::op::argmax, -2, true, 2>;
template struct test_arg_ops<migraphx::op::argmax, -2, false, 2>;
// slice argmin tests // slice argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0, 2>; template struct test_arg_ops<migraphx::op::argmin, 0, true, 2>;
template struct test_arg_ops<migraphx::op::argmin, 1, 2>; template struct test_arg_ops<migraphx::op::argmin, 0, false, 2>;
template struct test_arg_ops<migraphx::op::argmin, 2, 2>; template struct test_arg_ops<migraphx::op::argmin, 1, true, 2>;
template struct test_arg_ops<migraphx::op::argmin, 3, 2>; template struct test_arg_ops<migraphx::op::argmin, 1, false, 2>;
template struct test_arg_ops<migraphx::op::argmin, -3, 2>; template struct test_arg_ops<migraphx::op::argmin, 2, true, 2>;
template struct test_arg_ops<migraphx::op::argmin, -4, 2>; template struct test_arg_ops<migraphx::op::argmin, 2, false, 2>;
template struct test_arg_ops<migraphx::op::argmin, 3, true, 2>;
template struct test_arg_ops<migraphx::op::argmin, 3, false, 2>;
template struct test_arg_ops<migraphx::op::argmin, -3, true, 2>;
template struct test_arg_ops<migraphx::op::argmin, -3, false, 2>;
template struct test_arg_ops<migraphx::op::argmin, -4, true, 2>;
template struct test_arg_ops<migraphx::op::argmin, -4, false, 2>;
// default case, standard shape argmax tests // default case, standard shape argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0, 3>; template struct test_arg_ops<migraphx::op::argmax, 0, true, 3>;
template struct test_arg_ops<migraphx::op::argmax, 1, 3>; template struct test_arg_ops<migraphx::op::argmax, 0, false, 3>;
template struct test_arg_ops<migraphx::op::argmax, 2, 3>; template struct test_arg_ops<migraphx::op::argmax, 1, true, 3>;
template struct test_arg_ops<migraphx::op::argmax, 3, 3>; template struct test_arg_ops<migraphx::op::argmax, 1, false, 3>;
template struct test_arg_ops<migraphx::op::argmax, -1, 3>; template struct test_arg_ops<migraphx::op::argmax, 2, true, 3>;
template struct test_arg_ops<migraphx::op::argmax, -2, 3>; template struct test_arg_ops<migraphx::op::argmax, 2, false, 3>;
template struct test_arg_ops<migraphx::op::argmax, 3, true, 3>;
template struct test_arg_ops<migraphx::op::argmax, 3, false, 3>;
template struct test_arg_ops<migraphx::op::argmax, -1, true, 3>;
template struct test_arg_ops<migraphx::op::argmax, -1, false, 3>;
template struct test_arg_ops<migraphx::op::argmax, -2, true, 3>;
template struct test_arg_ops<migraphx::op::argmax, -2, false, 3>;
// default case, standard shape argmin tests // default case, standard shape argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0, 3>; template struct test_arg_ops<migraphx::op::argmin, 0, true, 3>;
template struct test_arg_ops<migraphx::op::argmin, 1, 3>; template struct test_arg_ops<migraphx::op::argmin, 0, false, 3>;
template struct test_arg_ops<migraphx::op::argmin, 2, 3>; template struct test_arg_ops<migraphx::op::argmin, 1, true, 3>;
template struct test_arg_ops<migraphx::op::argmin, 3, 3>; template struct test_arg_ops<migraphx::op::argmin, 1, false, 3>;
template struct test_arg_ops<migraphx::op::argmin, -3, 3>; template struct test_arg_ops<migraphx::op::argmin, 2, true, 3>;
template struct test_arg_ops<migraphx::op::argmin, -4, 3>; template struct test_arg_ops<migraphx::op::argmin, 2, false, 3>;
template struct test_arg_ops<migraphx::op::argmin, 3, true, 3>;
template struct test_arg_ops<migraphx::op::argmin, 3, false, 3>;
template struct test_arg_ops<migraphx::op::argmin, -3, true, 3>;
template struct test_arg_ops<migraphx::op::argmin, -3, false, 3>;
template struct test_arg_ops<migraphx::op::argmin, -4, true, 3>;
template struct test_arg_ops<migraphx::op::argmin, -4, false, 3>;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment