Unverified Commit 0e2ea59d authored by Brian Pickrell's avatar Brian Pickrell Committed by GitHub
Browse files

Merge branch 'develop' into dynamic_reduce

parents d6cd850b 231d60a2
......@@ -30,6 +30,7 @@
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -56,13 +57,21 @@ struct argmax
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto lens = inputs[0].lens();
check_shapes{inputs, *this, true}.has(1);
const auto& s0 = inputs[0];
if(s0.dynamic())
{
auto dyn_dims = s0.dyn_dims();
dyn_dims[axis] = {1, 1, 0};
return {shape::int64_type, dyn_dims};
}
else
{
auto lens = s0.lens();
lens[axis] = 1;
return {shape::int64_type, lens};
}
}
template <class T>
int64_t calc_argmax(T& input, std::vector<std::size_t>& indices, size_t item_num) const
......@@ -79,19 +88,18 @@ struct argmax
max_index = i;
}
}
return max_index;
}
argument compute(const shape& output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
argument result{output_shape};
argument result{dyn_out.computed_shape};
auto batch_item_num = args.front().get_shape().lens()[axis];
result.visit([&](auto output) {
args[0].visit([&](auto input) {
par_for(output_shape.elements(), [&](auto i) {
auto data_idx = output_shape.multi(i);
par_for(dyn_out.computed_shape.elements(), [&](auto i) {
auto data_idx = dyn_out.computed_shape.multi(i);
output[i] = this->calc_argmax(input, data_idx, batch_item_num);
});
});
......
......@@ -140,6 +140,20 @@ def argmax_test():
return ([node], [x], [y])
@onnx_test
def argmax_dyn_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [None, 4, 5, 6])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [None, 4, 6])
node = onnx.helper.make_node('ArgMax',
inputs=['x'],
outputs=['y'],
axis=2,
keepdims=0)
return ([node], [x], [y])
@onnx_test
def argmin_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6])
......
......@@ -181,6 +181,24 @@ TEST_CASE(argmax_test)
EXPECT(p == prog);
}
TEST_CASE(argmax_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"x",
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {5, 5, 0}, {6, 6, 0}}});
auto ins = mm->add_instruction(migraphx::make_op("argmax", {{"axis", 2}}), l0);
auto ret = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), ins);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
auto prog = parse_onnx("argmax_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(argmin_test)
{
migraphx::program p;
......
......@@ -81,6 +81,64 @@ void throws_shape(const migraphx::shape&, Ts...)
"An expected shape should not be passed to throws_shape function");
}
TEST_CASE(argmax_axis0)
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}},
migraphx::make_op("argmax", {{"axis", 0}}),
input);
}
TEST_CASE(argmax_axis1)
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}},
migraphx::make_op("argmax", {{"axis", 1}}),
input);
}
TEST_CASE(argmax_axis2)
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}},
migraphx::make_op("argmax", {{"axis", 2}}),
input);
}
TEST_CASE(argmax_axis_neg)
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}},
migraphx::make_op("argmax", {{"axis", -1}}),
input);
}
TEST_CASE(argmax_axis_outofbounds)
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(migraphx::make_op("argmax", {{"axis", 4}}), input);
}
TEST_CASE(argmax_dyn0)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {3, 3, 0}, {4, 4, 0}, {5, 5, 0}}};
expect_shape(
migraphx::shape{migraphx::shape::int64_type, {{1, 4, 0}, {1, 1, 0}, {4, 4, 0}, {5, 5, 0}}},
migraphx::make_op("argmax", {{"axis", 1}}),
input);
}
TEST_CASE(argmax_dyn1)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {3, 3, 0}, {4, 6, 0}, {4, 6, 0}}};
expect_shape(
migraphx::shape{migraphx::shape::int64_type, {{1, 4, 0}, {3, 3, 0}, {1, 1, 0}, {4, 6, 0}}},
migraphx::make_op("argmax", {{"axis", 2}}),
input);
}
TEST_CASE(binary_dyn_static_error)
{
migraphx::shape a_shape{migraphx::shape::float_type, {1, 4, 4}};
......@@ -2059,42 +2117,6 @@ TEST_CASE(slice_shape)
TEST_CASE(softmax) { test_softmax_variations<migraphx::op::softmax>(); }
TEST_CASE(test_argmax)
{
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}},
migraphx::make_op("argmax", {{"axis", 0}}),
input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}},
migraphx::make_op("argmax", {{"axis", 1}}),
input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}},
migraphx::make_op("argmax", {{"axis", 2}}),
input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}},
migraphx::make_op("argmax", {{"axis", 3}}),
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(migraphx::make_op("argmax", {{"axis", 4}}), input);
}
}
TEST_CASE(test_argmin)
{
{
......
......@@ -326,6 +326,28 @@ TEST_CASE(argmax_test_neg_2)
EXPECT(migraphx::verify_range(result_vec, res_gold));
}
TEST_CASE(argmax_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{2, 2, 0}, {3, 6, 0}, {3, 6, 0}}};
auto dl = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("argmax", {{"axis", 0}}), dl);
p.compile(migraphx::ref::target{});
std::vector<float> data = {1.2255, 1.6834, -2.0305, -0.3221, 0.4701, 0.2583, 0.7545, 2.5758,
-1.6849, 0.0928, 0.9022, -0.8765, -0.4090, 0.9301, 2.0724, -1.5706,
0.4867, -0.1493, 0.6957, -0.2179, 0.7142, 0.7177, 0.0183, 1.3497};
migraphx::parameter_map params;
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {2, 3, 4}};
params["X"] = migraphx::argument(input_fixed_shape, data.data());
auto result = p.eval(params).back();
std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<int64_t> res_gold = {0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1};
EXPECT(migraphx::verify_range(result_vec, res_gold));
}
TEST_CASE(argmin_test_0)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment