"vscode:/vscode.git/clone" did not exist on "ea3003f6efead767cf8381f1d3517ae8c00c24f6"
Unverified Commit d478675c authored by Brian Pickrell's avatar Brian Pickrell Committed by GitHub
Browse files

Dynamic gathernd (#1480)

Dynamic shape support for gathernd op.
parent dee20c6c
/*
* The MIT License (MIT)
*
 * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/argument.hpp>
......@@ -47,33 +48,103 @@ struct gathernd
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2);
auto r = inputs.front().lens().size();
auto q = inputs.back().lens().size();
auto k = inputs.back().lens().back();
check_shapes{inputs, *this, true}.has(2);
auto i_shape = inputs.back();
auto data_shape = inputs.front();
auto r = data_shape.ndim();
auto q = i_shape.ndim();
size_t k;
if(i_shape.dynamic())
{
// the rank of the output is a function of k, so it must be fixed.
if(not i_shape.dyn_dims().back().is_fixed())
{
MIGRAPHX_THROW(
"GATHERND: last dimension of indices tensor must be fixed (min=max)");
}
k = i_shape.dyn_dims().back().min;
}
else
k = i_shape.lens().back();
// Begin input validation checks.
int output_ndim = int(q) + r - k - batch_dims - 1;
if(k > r - batch_dims)
{
MIGRAPHX_THROW("GATHERND: Indices of length " + std::to_string(k) +
" cannot be used to access data of rank " +
std::to_string(r - batch_dims));
}
auto indices_lens_iter = inputs.back().lens().begin();
auto output_lens_size = q + r - k - batch_dims - 1;
std::vector<std::size_t> output_lens(output_lens_size);
std::copy(indices_lens_iter, indices_lens_iter + (q - 1), output_lens.begin());
if(k < r - batch_dims)
if(batch_dims >= q or batch_dims >= r)
{
MIGRAPHX_THROW("GATHERND: rank of an input cannot be less than batch_dims=" +
std::to_string(batch_dims));
}
if(output_ndim < 0)
{
MIGRAPHX_THROW("GATHERND: Indices too large for static data input: k=" +
std::to_string(k));
}
if(migraphx::none_of(inputs, [](auto v) { return v.dynamic(); }))
{
auto indices_lens_iter = i_shape.lens().begin();
// A rank 0 output is a scalar
if(output_ndim == 0)
return shape{data_shape.type(), {1}};
// Part of the output shape comes from indices tensor, part from data tensor
std::vector<std::size_t> output_lens(output_ndim);
std::copy(indices_lens_iter, indices_lens_iter + (q - 1), output_lens.begin());
// fill the rest of output shape from data tensor
if(k + batch_dims < r)
{
auto data_lens = data_shape.lens();
std::copy(data_lens.begin() + batch_dims + k,
data_lens.end(),
output_lens.begin() + q - 1);
}
shape output_shape{data_shape.type(), output_lens};
return output_shape;
}
else
{
auto data_lens = inputs.front().lens();
std::copy(
data_lens.begin() + batch_dims + k, data_lens.end(), output_lens.begin() + q - 1);
// If one or both inputs are dynamic shapes, the output is dynamic.
// Make both inputs dynamic to simplify computations.
data_shape = data_shape.to_dynamic();
i_shape = i_shape.to_dynamic();
// A rank 0 output is a scalar
if(output_ndim == 0)
return shape(data_shape.type(), {shape::dynamic_dimension({1, 1, 0})});
// Part of the output shape comes from indices tensor, part from data tensor
std::vector<shape::dynamic_dimension> output_dims(output_ndim);
std::copy(i_shape.dyn_dims().begin(),
i_shape.dyn_dims().begin() + q - 1,
output_dims.begin());
// fill the rest of output shape from data tensor
if(k + batch_dims < r)
{
auto data_dims = data_shape.dyn_dims();
std::copy(data_dims.begin() + batch_dims + k,
data_dims.begin() + r,
output_dims.begin() + q - 1);
}
shape output_shape(data_shape.type(), output_dims);
return output_shape;
}
shape output_shape{inputs.front().type(), output_lens};
return output_shape;
}
argument compute(const shape& output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
argument result{output_shape};
argument result{dyn_out.computed_shape};
visit_all(result, args[0])([&](auto output, auto data) {
args[1].visit([&](auto indices) {
auto indices_shape = indices.get_shape();
......
......@@ -336,7 +336,8 @@ std::vector<argument> generic_eval(const module* mod,
if(not ins->get_shape().dynamic() and param.get_shape() != ins->get_shape())
{
MIGRAPHX_THROW("Incorrect shape {" + to_string(param.get_shape()) +
"} for parameter: " + param_name);
"} for parameter: " + param_name +
" should be: " + to_string(ins->get_shape()));
}
return param;
}));
......
......@@ -2132,6 +2132,19 @@ def gathernd_test():
return ([node], [x, i], [y])
@onnx_test()
def gathernd_dyn_test():
    # GatherND model whose data input has a dynamic (unspecified) first dimension.
    data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [None, 2])
    indices = helper.make_tensor_value_info('indices', TensorProto.INT64, [2, 2])
    out = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2])
    node = onnx.helper.make_node('GatherND',
                                 inputs=['data', 'indices'],
                                 outputs=['y'])
    return ([node], [data, indices], [out])
@onnx_test()
def gathernd_batch_dims_test():
x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 2, 2])
......
......@@ -2158,6 +2158,24 @@ TEST_CASE(gathernd_test)
EXPECT(p == prog);
}
TEST_CASE(gathernd_dyn_test)
{
    // Build the expected program by hand, then compare it to the parsed model.
    migraphx::program expected;
    auto* mm = expected.get_main_module();
    migraphx::shape data_shape{migraphx::shape::float_type, {{2, 4, 2}, {2, 4}}};
    migraphx::shape index_shape{migraphx::shape::int64_type, {{1, 3}, {2, 2}}};
    auto data    = mm->add_parameter("data", data_shape);
    auto indices = mm->add_parameter("indices", index_shape);
    auto gnd     = mm->add_instruction(migraphx::make_op("gathernd"), data, indices);
    mm->add_return({gnd});

    // Parse with matching dynamic input dimensions.
    migraphx::onnx_options options;
    options.map_dyn_input_dims["data"]    = {{2, 4, 2}, {2, 4}};
    options.map_dyn_input_dims["indices"] = {{1, 3}, {2, 2}};
    auto parsed = migraphx::parse_onnx("gathernd_dyn_test.onnx", options);
    EXPECT(expected == parsed);
}
TEST_CASE(gathernd_batch_dims_test)
{
migraphx::program p;
......
......@@ -2477,6 +2477,220 @@ TEST_CASE(test_scalar_nelemnts)
throws_shape(migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 3, 4, 5}}}), input);
}
TEST_CASE(test_gathernd)
{
    // Static-shape validation and output-shape checks for gathernd.
    const auto flt = migraphx::shape::float_type;
    const auto idx = migraphx::shape::int64_type;
    {
        // index tuples longer than the data rank are rejected
        migraphx::shape indices{idx, {2, 4}};
        migraphx::shape data{flt, {8}};
        throws_shape(migraphx::make_op("gathernd", {{"batch_dims", 1}}), data, indices);
    }
    {
        // index tuples longer than r - batch_dims are rejected
        migraphx::shape indices{idx, {2, 4}};
        migraphx::shape data{flt, {2}};
        throws_shape(migraphx::make_op("gathernd", {{"batch_dims", 1}}), data, indices);
    }
    {
        // batch_dims too large for the rank of the indices input
        migraphx::shape indices{idx, {2, 1}};
        migraphx::shape data{flt, {2, 5, 6, 7}};
        throws_shape(migraphx::make_op("gathernd", {{"batch_dims", 3}}), data, indices);
    }
    {
        // output rank q + r - k - batch_dims - 1 == 0 => scalar result
        migraphx::shape indices{idx, {1}};
        migraphx::shape data{flt, {2}};
        expect_shape(migraphx::shape{flt, {1}}, migraphx::make_op("gathernd"), data, indices);
    }
    {
        // See Example 4 at https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherND
        migraphx::shape indices{idx, {2, 2}};
        migraphx::shape data{flt, {2, 2}};
        expect_shape(migraphx::shape{flt, {2}}, migraphx::make_op("gathernd"), data, indices);
    }
    {
        // See Example 5 at https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherND
        migraphx::shape indices{idx, {2, 1}};
        migraphx::shape data{flt, {2, 2, 2}};
        expect_shape(migraphx::shape{flt, {2, 2}},
                     migraphx::make_op("gathernd", {{"batch_dims", 1}}),
                     data,
                     indices);
    }
}
TEST_CASE(test_gathernd_dynamic0)
{
    // index tuples longer than the data rank are rejected for dynamic data too
    migraphx::shape indices{migraphx::shape::int64_type, {2, 4}};
    std::vector<migraphx::shape::dynamic_dimension> data_dims{{8, 8, 0}};
    migraphx::shape data{migraphx::shape::float_type, data_dims};
    throws_shape(migraphx::make_op("gathernd", {{"batch_dims", 1}}), data, indices);
}
TEST_CASE(test_gathernd_dynamic1)
{
    // index tuples longer than r - batch_dims are rejected for dynamic data
    migraphx::shape indices{migraphx::shape::int64_type, {2, 4}};
    std::vector<migraphx::shape::dynamic_dimension> data_dims{{2, 2, 0}};
    migraphx::shape data{migraphx::shape::float_type, data_dims};
    throws_shape(migraphx::make_op("gathernd", {{"batch_dims", 1}}), data, indices);
}
TEST_CASE(test_gathernd_dynamic2)
{
    // batch_dims too large for the rank of the indices input (dynamic data)
    migraphx::shape indices{migraphx::shape::int64_type, {2, 1}};
    migraphx::shape data{migraphx::shape::float_type,
                         {{2, 3, 3}, {5, 6, 5}, {6, 9, 7}, {7, 8, 8}}};
    throws_shape(migraphx::make_op("gathernd", {{"batch_dims", 3}}), data, indices);
}
TEST_CASE(test_gathernd_dynamic3)
{
    // output rank q + r - k - batch_dims - 1 == 0 => scalar (dynamic) result
    migraphx::shape indices{migraphx::shape::int64_type, {1}};
    std::vector<migraphx::shape::dynamic_dimension> data_dims{{2, 2, 0}};
    migraphx::shape data{migraphx::shape::float_type, data_dims};
    migraphx::shape expected{migraphx::shape::float_type,
                             {migraphx::shape::dynamic_dimension{1, 1, 0}}};
    expect_shape(expected, migraphx::make_op("gathernd"), data, indices);
}
TEST_CASE(test_gathernd_dynamic4)
{
    // See Example 1 at https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherND
    migraphx::shape indices{migraphx::shape::int64_type, {2, 2}};
    std::vector<migraphx::shape::dynamic_dimension> data_dims{{2, 2, 0}, {2, 2, 0}};
    migraphx::shape data{migraphx::shape::float_type, data_dims};
    migraphx::shape expected{migraphx::shape::float_type,
                             {migraphx::shape::dynamic_dimension{2, 2, 0}}};
    expect_shape(expected, migraphx::make_op("gathernd"), data, indices);
}
TEST_CASE(test_gathernd_dynamic5)
{
    // See Example 5 at https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherND
    // static indices, dynamic data
    migraphx::shape indices{migraphx::shape::int64_type, {2, 1}};
    std::vector<migraphx::shape::dynamic_dimension> data_dims{{2, 2, 0}, {2, 2, 0}, {2, 2, 0}};
    migraphx::shape data{migraphx::shape::float_type, data_dims};
    std::vector<migraphx::shape::dynamic_dimension> out_dims{{2, 2, 0}, {2, 2, 0}};
    expect_shape(migraphx::shape{migraphx::shape::float_type, out_dims},
                 migraphx::make_op("gathernd", {{"batch_dims", 1}}),
                 data,
                 indices);
}
TEST_CASE(test_gathernd_dynamic6)
{
    // See Example 5 at https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherND
    // dynamic indices, static data
    std::vector<migraphx::shape::dynamic_dimension> index_dims{{2, 3, 0}, {1, 1, 0}};
    migraphx::shape indices{migraphx::shape::int64_type, index_dims};
    migraphx::shape data{migraphx::shape::float_type, {2, 2, 2}};
    std::vector<migraphx::shape::dynamic_dimension> out_dims{{2, 3, 0}, {2, 2, 0}};
    expect_shape(migraphx::shape{migraphx::shape::float_type, out_dims},
                 migraphx::make_op("gathernd", {{"batch_dims", 1}}),
                 data,
                 indices);
}
TEST_CASE(test_gathernd_dynamic6a)
{
    // the last (k) dimension of dynamic indices is not fixed => must throw
    std::vector<migraphx::shape::dynamic_dimension> index_dims{{2, 2, 0}, {1, 3, 0}};
    migraphx::shape indices{migraphx::shape::int64_type, index_dims};
    migraphx::shape data{migraphx::shape::float_type, {2, 2, 2}};
    throws_shape(migraphx::make_op("gathernd", {{"batch_dims", 1}}), data, indices);
}
TEST_CASE(test_gathernd_dynamic7)
{
    // See Example 5 at https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherND
    // both indices and data are dynamic
    std::vector<migraphx::shape::dynamic_dimension> index_dims{{2, 5, 0}, {1, 1, 0}};
    migraphx::shape indices{migraphx::shape::int64_type, index_dims};
    std::vector<migraphx::shape::dynamic_dimension> data_dims{{1, 2, 0}, {1, 2, 0}, {1, 2, 0}};
    migraphx::shape data{migraphx::shape::float_type, data_dims};
    std::vector<migraphx::shape::dynamic_dimension> out_dims{{2, 5, 0}, {1, 2, 0}};
    expect_shape(migraphx::shape{migraphx::shape::float_type, out_dims},
                 migraphx::make_op("gathernd", {{"batch_dims", 1}}),
                 data,
                 indices);
}
TEST_CASE(test_gathernd_dynamic8)
{
    // Same shapes as ref_ops_test gathernd_dynamic:
    // static indices, dynamic data
    migraphx::shape indices{migraphx::shape::int64_type, {2, 5, 1}};
    std::vector<migraphx::shape::dynamic_dimension> data_dims{{6, 7, 7}, {3, 3, 0}, {1, 4, 0}};
    migraphx::shape data{migraphx::shape::float_type, data_dims};
    std::vector<migraphx::shape::dynamic_dimension> out_dims{{2, 2, 0}, {5, 5, 0}, {1, 4, 0}};
    expect_shape(migraphx::shape{migraphx::shape::float_type, out_dims},
                 migraphx::make_op("gathernd", {{"batch_dims", 1}}),
                 data,
                 indices);
}
TEST_CASE(test_scatternd)
{
{
......
......@@ -2746,6 +2746,187 @@ TEST_CASE(gathernd_test)
}
}
TEST_CASE(gathernd_dynamic0)
{
    // dynamic data, all dimensions fixed
    migraphx::program prog;
    auto* mm = prog.get_main_module();
    migraphx::shape data_shape{migraphx::shape::float_type, {{2, 2, 2}, {3, 3, 0}, {1, 1, 0}}};
    migraphx::shape index_shape{migraphx::shape::int64_type, {2, 2, 1}};
    auto data_param  = mm->add_parameter("X", data_shape);
    auto index_param = mm->add_parameter("I", index_shape);
    auto gnd         = mm->add_instruction(migraphx::make_op("gathernd"), data_param, index_param);
    mm->add_return({gnd});
    prog.compile(migraphx::ref::target{});

    // Evaluate with a concrete (static) instantiation of the dynamic shape.
    migraphx::shape fixed_data_shape{migraphx::shape::float_type, {2, 3, 1}};
    migraphx::shape fixed_index_shape{migraphx::shape::int64_type, {2, 2, 1}};
    std::vector<float> data_vals(2 * 3 * 1);
    std::iota(data_vals.begin(), data_vals.end(), 0);
    std::vector<int64_t> index_vals{1, 0, 0, 1};
    migraphx::parameter_map params;
    params["X"] = migraphx::argument(fixed_data_shape, data_vals.data());
    params["I"] = migraphx::argument(fixed_index_shape, index_vals.data());

    auto result = prog.eval(params).back();
    std::vector<float> gold{3, 4, 5, 0, 1, 2, 0, 1, 2, 3, 4, 5};
    std::vector<float> out;
    result.visit([&](auto v) { out.assign(v.begin(), v.end()); });
    EXPECT(migraphx::verify_range(out, gold));
}
TEST_CASE(gathernd_dynamic1)
{
    // dynamic data, dims not fixed
    migraphx::program prog;
    auto* mm = prog.get_main_module();
    migraphx::shape data_shape{migraphx::shape::float_type, {{2, 5, 2}, {1, 5, 0}, {1, 5, 0}}};
    migraphx::shape index_shape{migraphx::shape::int64_type, {2, 2, 1}};
    auto data_param  = mm->add_parameter("X", data_shape);
    auto index_param = mm->add_parameter("I", index_shape);
    auto gnd         = mm->add_instruction(migraphx::make_op("gathernd"), data_param, index_param);
    mm->add_return({gnd});
    prog.compile(migraphx::ref::target{});

    // Evaluate with a concrete (static) instantiation of the dynamic shape.
    migraphx::shape fixed_data_shape{migraphx::shape::float_type, {2, 3, 1}};
    migraphx::shape fixed_index_shape{migraphx::shape::int64_type, {2, 2, 1}};
    std::vector<float> data_vals(2 * 3 * 1);
    std::iota(data_vals.begin(), data_vals.end(), 0);
    std::vector<int64_t> index_vals{1, 0, 0, 1};
    migraphx::parameter_map params;
    params["X"] = migraphx::argument(fixed_data_shape, data_vals.data());
    params["I"] = migraphx::argument(fixed_index_shape, index_vals.data());

    auto result = prog.eval(params).back();
    std::vector<float> gold{3, 4, 5, 0, 1, 2, 0, 1, 2, 3, 4, 5};
    std::vector<float> out;
    result.visit([&](auto v) { out.assign(v.begin(), v.end()); });
    EXPECT(migraphx::verify_range(out, gold));
}
TEST_CASE(gathernd_dynamic2)
{
    // dynamic both index and data
    migraphx::program prog;
    auto* mm = prog.get_main_module();
    migraphx::shape data_shape{migraphx::shape::float_type, {{2, 5, 2}, {1, 5, 0}, {1, 5, 0}}};
    migraphx::shape index_shape{migraphx::shape::int64_type, {{2, 5, 3}, {2, 3, 3}, {1, 1}}};
    auto data_param  = mm->add_parameter("X", data_shape);
    auto index_param = mm->add_parameter("I", index_shape);
    auto gnd         = mm->add_instruction(migraphx::make_op("gathernd"), data_param, index_param);
    mm->add_return({gnd});
    prog.compile(migraphx::ref::target{});

    // Evaluate with concrete (static) instantiations of both dynamic shapes.
    migraphx::shape fixed_data_shape{migraphx::shape::float_type, {2, 3, 1}};
    migraphx::shape fixed_index_shape{migraphx::shape::int64_type, {2, 2, 1}};
    std::vector<float> data_vals(2 * 3 * 1);
    std::iota(data_vals.begin(), data_vals.end(), 0);
    std::vector<int64_t> index_vals{1, 0, 0, 1};
    migraphx::parameter_map params;
    params["X"] = migraphx::argument(fixed_data_shape, data_vals.data());
    params["I"] = migraphx::argument(fixed_index_shape, index_vals.data());

    auto result = prog.eval(params).back();
    std::vector<float> gold{3, 4, 5, 0, 1, 2, 0, 1, 2, 3, 4, 5};
    std::vector<float> out;
    result.visit([&](auto v) { out.assign(v.begin(), v.end()); });
    EXPECT(migraphx::verify_range(out, gold));
}
TEST_CASE(gathernd_dynamic3)
{
    // dynamic index, static data and a batch_dims input
    migraphx::program prog;
    auto* mm = prog.get_main_module();
    migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 1}};
    migraphx::shape index_shape{migraphx::shape::int64_type, {{2, 5, 3}, {2, 3, 3}, {1, 1}}};
    auto data_param  = mm->add_parameter("X", data_shape);
    auto index_param = mm->add_parameter("I", index_shape);
    auto gnd         = mm->add_instruction(
        migraphx::make_op("gathernd", {{"batch_dims", 1}}), data_param, index_param);
    mm->add_return({gnd});
    prog.compile(migraphx::ref::target{});

    // Evaluate with a concrete (static) instantiation of the dynamic index shape.
    migraphx::shape fixed_data_shape{migraphx::shape::float_type, {2, 3, 1}};
    migraphx::shape fixed_index_shape{migraphx::shape::int64_type, {2, 2, 1}};
    std::vector<float> data_vals(2 * 3 * 1);
    std::iota(data_vals.begin(), data_vals.end(), 0);
    std::vector<int64_t> index_vals{1, 0, 0, 1};
    migraphx::parameter_map params;
    params["X"] = migraphx::argument(fixed_data_shape, data_vals.data());
    params["I"] = migraphx::argument(fixed_index_shape, index_vals.data());

    auto result = prog.eval(params).back();
    std::vector<float> gold{1, 0, 3, 4};
    std::vector<float> out;
    result.visit([&](auto v) { out.assign(v.begin(), v.end()); });
    EXPECT(migraphx::verify_range(out, gold));
}
TEST_CASE(gathernd_dynamic4)
{
    // int(q) + r - k - batch_dims - 1 = 0 => returns a scalar
    migraphx::program prog;
    auto* mm = prog.get_main_module();
    migraphx::shape data_shape{migraphx::shape::float_type,
                               {migraphx::shape::dynamic_dimension({2, 2, 0})}};
    migraphx::shape index_shape{migraphx::shape::int64_type, {1}};
    auto data_param  = mm->add_parameter("X", data_shape);
    auto index_param = mm->add_parameter("I", index_shape);
    auto gnd         = mm->add_instruction(migraphx::make_op("gathernd"), data_param, index_param);
    mm->add_return({gnd});
    prog.compile(migraphx::ref::target{});

    // Evaluate with a concrete (static) instantiation of the dynamic data shape.
    migraphx::shape fixed_data_shape{migraphx::shape::float_type, {2}};
    migraphx::shape fixed_index_shape{migraphx::shape::int64_type, {1}};
    std::vector<float> data_vals(2);
    std::iota(data_vals.begin(), data_vals.end(), 4);
    std::vector<int64_t> index_vals{1};
    migraphx::parameter_map params;
    params["X"] = migraphx::argument(fixed_data_shape, data_vals.data());
    params["I"] = migraphx::argument(fixed_index_shape, index_vals.data());

    auto result = prog.eval(params).back();
    std::vector<float> gold{5};
    std::vector<float> out;
    result.visit([&](auto v) { out.assign(v.begin(), v.end()); });
    EXPECT(migraphx::verify_range(out, gold));
}
TEST_CASE(gathernd_negative_index_test)
{
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment