Unverified Commit c72d53ba authored by Brian Pickrell's avatar Brian Pickrell Committed by GitHub
Browse files

Dyn gather (#1513)

Dynamic shape support for gather op.
parent 5bcb7ce8
......@@ -26,6 +26,7 @@
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
......@@ -61,35 +62,59 @@ struct gather
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2);
auto lens = inputs[0].lens();
auto type = inputs[0].type();
lens.erase(lens.begin() + axis);
if(not inputs[1].scalar())
check_shapes{inputs, *this, true}.has(2);
shape data = inputs[0];
shape indices = inputs[1];
auto type = data.type();
// If index_dims is dynamic, convert the data to dynamic too.
if(indices.dynamic())
{
auto ind_lens = inputs[1].lens();
lens.insert(lens.begin() + axis, ind_lens.begin(), ind_lens.end());
data = data.to_dynamic();
}
// for scalar output
if(lens.empty())
if(data.dynamic())
{
return {type};
auto dims = data.dyn_dims();
dims.erase(dims.begin() + axis);
if(not indices.scalar())
{
auto index_dims = indices.to_dynamic().dyn_dims();
dims.insert(dims.begin() + axis, index_dims.begin(), index_dims.end());
}
return {type, dims};
}
else
{
// Both data and indices are static. indices may be scalar
auto lens = data.lens();
lens.erase(lens.begin() + axis);
return {type, lens};
if(not indices.scalar())
{
auto ind_lens = indices.lens();
lens.insert(lens.begin() + axis, ind_lens.begin(), ind_lens.end());
}
// for scalar output
if(lens.empty())
{
return {type};
}
return {type, lens};
}
}
argument compute(const shape& output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
argument result{output_shape};
argument result{dyn_out.computed_shape};
// negative axis means counting dimensions from back
auto lens = args[0].get_shape().lens();
std::size_t axis_dim_size = lens[axis];
// max dimension in axis
visit_all(result, args[0])([&](auto output, auto data) {
args[1].visit([&](auto indices) {
if(output_shape.scalar())
if(dyn_out.computed_shape.scalar())
{
auto in_index = indices.front();
in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
......
......@@ -2053,6 +2053,40 @@ def gather_test():
return ([node], [x, i], [y])
@onnx_test()
def gather_scalar_test():
x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 4, 5, 6])
i = helper.make_tensor_value_info('indices', TensorProto.INT32, [])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4, 5, 6])
node = onnx.helper.make_node(
'Gather',
inputs=['data', 'indices'],
outputs=['y'],
axis=1,
)
return ([node], [x, i], [y])
@onnx_test()
def gather_dyn_test():
x = helper.make_tensor_value_info('data', TensorProto.FLOAT,
[None, 4, 5, 6])
i = helper.make_tensor_value_info('indices', TensorProto.INT32,
[None, 3, 4, 5])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3, 4, 5])
node = onnx.helper.make_node(
'Gather',
inputs=['data', 'indices'],
outputs=['y'],
axis=1,
)
return ([node], [x, i], [y])
@onnx_test()
def gather_elements_axis0_test():
x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 4])
......
......@@ -2048,6 +2048,46 @@ TEST_CASE(gather_test)
EXPECT(p == prog);
}
TEST_CASE(gather_scalar_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
std::vector<size_t> idims{1};
auto l1 =
mm->add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, idims, {0}});
int axis = 1;
mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), l0, l1);
auto prog = optimize_onnx("gather_scalar_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(gather_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"data",
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {5, 5, 0}, {6, 6, 0}}});
auto l1 = mm->add_parameter(
"indices",
migraphx::shape{migraphx::shape::int32_type, {{1, 4, 0}, {3, 3, 0}, {4, 4, 0}, {5, 5, 0}}});
auto cont_l0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
auto cont_l1 = mm->add_instruction(migraphx::make_op("contiguous"), l1);
int axis = 1;
auto gather_op = migraphx::make_op("gather", {{"axis", axis}});
auto ret = mm->add_instruction(gather_op, cont_l0, cont_l1);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
auto prog = parse_onnx("gather_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(gather_elements_axis0_test)
{
migraphx::program p;
......
......@@ -831,6 +831,77 @@ TEST_CASE(gather)
}
}
TEST_CASE(gather_dyn0)
{
// Insert dynamic index into dynamic shape
migraphx::shape input{migraphx::shape::float_type,
{{2, 3, 2}, {3, 4, 3}, {6, 9, 7}, {12, 14, 13}}};
migraphx::shape indices{migraphx::shape::int32_type, {{2, 7, 3}, {3, 3, 0}}};
int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type,
{{2, 3, 2}, {2, 7, 3}, {3, 3, 0}, {6, 9, 7}, {12, 14, 13}}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
}
TEST_CASE(gather_dyn1)
{
// Insert static index into dynamic shape
migraphx::shape input{migraphx::shape::float_type,
{{2, 3, 2}, {3, 4, 3}, {6, 9, 7}, {12, 14, 13}}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type,
{{2, 3, 2}, {2, 2, 0}, {3, 3, 0}, {6, 9, 7}, {12, 14, 13}}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
}
TEST_CASE(gather_dyn2)
{
// Insert scalar (static) index into dynamic shape
migraphx::shape input{migraphx::shape::float_type,
{{2, 3, 2}, {3, 4, 3}, {6, 9, 7}, {12, 14, 13}}};
std::vector<std::size_t> mins;
std::vector<std::size_t> maxes;
std::vector<std::size_t> opts;
migraphx::shape indices{migraphx::shape::int32_type, mins, maxes, opts};
int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type, {{2, 3, 2}, {6, 9, 7}, {12, 14, 13}}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
}
TEST_CASE(gather_dyn3)
{
// Insert dynamic index into static shape, axis 1
migraphx::shape input{migraphx::shape::float_type, {2, 3, 6, 12}};
migraphx::shape indices{migraphx::shape::int32_type, {{2, 3, 2}, {3, 4, 3}}};
int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type,
{{2, 2, 0}, {2, 3, 2}, {3, 4, 3}, {6, 6, 0}, {12, 12, 0}}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
}
TEST_CASE(gather_dyn4)
{
// Insert dynamic index into static shape, axis 0
migraphx::shape input{migraphx::shape::float_type, {2, 3, 6, 12}};
migraphx::shape indices{migraphx::shape::int32_type, {{2, 3, 2}, {3, 4, 3}}};
int axis = 0;
expect_shape(migraphx::shape{migraphx::shape::float_type,
{{2, 3, 2}, {3, 4, 3}, {3, 3, 0}, {6, 6, 0}, {12, 12, 0}}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
}
TEST_CASE(get_tuple_elem_test)
{
migraphx::shape s0{migraphx::shape::bool_type, {1, 1}};
......
......@@ -2524,6 +2524,78 @@ TEST_CASE(gather_test)
}
}
TEST_CASE(gather_dyn_test0)
{
// Dynamic data, static indices
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {{2, 5, 0}, {3, 3, 0}}};
auto x = mm->add_parameter("x", s);
std::vector<int> indices{1, 2};
migraphx::shape s_ind{migraphx::shape::int32_type, {1, 2}};
auto ind = mm->add_parameter("indices", s_ind);
mm->add_instruction(migraphx::make_op("gather", {{"axis", 1}}), x, ind);
migraphx::shape sresult{migraphx::shape::int32_type, {{2, 5, 0}, {1, 1, 0}, {2, 2, 0}}};
EXPECT(p.get_output_shapes().back() == sresult);
p.compile(migraphx::ref::target{});
migraphx::shape input_fixed_shape{migraphx::shape::int32_type, {2, 3}};
migraphx::shape input_indices{migraphx::shape::int32_type, {1, 2}};
migraphx::parameter_map params;
std::vector<int> data(2 * 3);
std::iota(data.begin(), data.end(), 0);
params["x"] = migraphx::argument(input_fixed_shape, data.data());
params["indices"] = migraphx::argument(input_indices, indices.data());
auto result = p.eval(params).back();
std::vector<int> gold = {1, 2, 4, 5};
std::vector<int> results_vector(2 * 1 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, gold));
migraphx::shape sfinal{migraphx::shape::int32_type, {2, 1, 2}};
EXPECT(result.get_shape() == sfinal);
}
TEST_CASE(gather_dyn_test1)
{
// Dynamic data, dynamic indices
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {{2, 5, 0}, {4, 4, 0}}};
auto x = mm->add_parameter("x", s);
migraphx::shape s_ind{migraphx::shape::int32_type, {{1, 8, 7}, {2, 3, 3}}};
auto ind = mm->add_parameter("indices", s_ind);
mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), x, ind);
migraphx::shape sresult{migraphx::shape::int32_type, {{1, 8, 7}, {2, 3, 3}, {4, 4, 0}}};
EXPECT(p.get_output_shapes().back() == sresult);
p.compile(migraphx::ref::target{});
migraphx::shape input_fixed_shape{migraphx::shape::int32_type, {3, 4}};
migraphx::shape input_indices_shape{migraphx::shape::int32_type, {1, 2}};
std::vector<int> indices{2, 0};
migraphx::parameter_map params;
std::vector<int> data(3 * 4);
std::iota(data.begin(), data.end(), 0);
params["x"] = migraphx::argument(input_fixed_shape, data.data());
params["indices"] = migraphx::argument(input_indices_shape, indices.data());
auto result = p.eval(params).back();
std::vector<int> gold = {8, 9, 10, 11, 0, 1, 2, 3};
std::vector<int> results_vector(1 * 2 * 4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, gold));
migraphx::shape sfinal{migraphx::shape::int32_type, {1, 2, 4}};
EXPECT(result.get_shape() == sfinal);
}
TEST_CASE(gathernd_test)
{
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment