Unverified commit 5d745540, authored by Brian Pickrell and committed by GitHub
Browse files

Dynamic shape support in scatterND ops (#1455)

* Implement dynamic shapes for scatterND operators.
parent d478675c
...@@ -28,44 +28,89 @@ ...@@ -28,44 +28,89 @@
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp> #include <migraphx/par_for.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
/**
* @brief
* N-dimensional scatter operations. This struct is the parent class of ops that differ only in the
* formula used to reduce (combine the old and new values of) each scattered element. It was
* originally based on the ONNX ScatterND operation (see
* https://github.com/onnx/onnx/blob/main/docs/Operators.md#ScatterND) and is also similar to NumPy's
* numpy.add.at().
*
* @tparam Derived a template parameter in the CRTP inheritance idiom, represents one of the child
* operations.
*/
template <class Derived> template <class Derived>
struct scatternd_op : op_name<Derived> struct scatternd_op : op_name<Derived>
{ {
/** Validate input shapes and return the correct output shape. For Scatter ops, the output
* is the same shape as the data tensor (first input), but cast to a standard shape.
*
*/
shape compute_shape(std::vector<shape> inputs) const
{
    // Inputs are, in order: data, indices, updates. Exactly three are required.
    // NOTE(review): the extra `true` argument to check_shapes presumably enables dynamic
    // shapes for this operator -- confirm against check_shapes.hpp.
    //
    // Returns the data shape: unchanged when it is dynamic, rebuilt from type and lens when
    // it is broadcasted, and otherwise via with_lens(lens).
    //
    // Throws (MIGRAPHX_THROW) when:
    //  - indices is dynamic and its last dimension is not fixed,
    //  - q + r != rank(updates) + k + 1,
    //  - k > r,
    //  - the leading q-1 dims of updates differ from those of indices, or the remaining
    //    dims of updates differ from data dims [k, r).
    check_shapes{inputs, *this, true}.has(3);
    auto data_shape  = inputs.front();
    auto index_shape = inputs.at(1);
    auto upd_shape   = inputs.back();

    // r: rank of data, q: rank of indices, k: length of each index tuple
    // (the last dimension of indices).
    auto r   = data_shape.ndim();
    auto q   = index_shape.ndim();
    size_t k = 0;
    if(index_shape.dynamic())
    {
        // the rank of the output is a function of k, so k must be fixed.
        if(not index_shape.dyn_dims().back().is_fixed())
        {
            MIGRAPHX_THROW(
                "ScatterND: last dimension of indices tensor must be fixed (min=max)");
        }
        k = index_shape.dyn_dims().back().min;
    }
    else
    {
        k = index_shape.lens().back();
    }

    // Checks on the sizes of input tensors. The comparison is written without subtraction
    // so that the unsigned rank arithmetic cannot wrap around.
    if(q + r != upd_shape.ndim() + k + 1)
        MIGRAPHX_THROW("ScatterND: ranks of inputs don't match. " + std::to_string(q) + " + " +
                       std::to_string(r) + " - " + std::to_string(k) +
                       " - 1 != " + std::to_string(upd_shape.ndim()));
    if(k > r)
        MIGRAPHX_THROW("ScatterND: index of size " + std::to_string(k) +
                       " is too large for tensor of rank " + std::to_string(r));

    // Convert all static shape dimensions to dynamic so they can be compared.
    // It's possible for some of the 3 inputs to be dynamic shapes and some static,
    // but any dynamic dimension that's compared to a static dimension must be fixed.
    auto ind_dims  = index_shape.to_dynamic().dyn_dims();
    auto upd_dims  = upd_shape.to_dynamic().dyn_dims();
    auto data_dims = data_shape.to_dynamic().dyn_dims();

    // Check that corresponding portions of tensor shapes match:
    // updates dims == indices dims [0, q-1) followed by data dims [k, r).
    if(not(std::equal(ind_dims.begin(), ind_dims.begin() + q - 1, upd_dims.begin()) and
           std::equal(data_dims.begin() + k, data_dims.end(), upd_dims.begin() + q - 1)))
        MIGRAPHX_THROW("ScatterND: incorrect update shape. Update dimensions must match "
                       "indices and data.");

    if(data_shape.dynamic())
        return data_shape;
    else if(data_shape.broadcasted())
    {
        return {data_shape.type(), data_shape.lens()};
    }
    else
    {
        return data_shape.with_lens(data_shape.lens());
    }
}
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{dyn_out.computed_shape};
auto& self = static_cast<const Derived&>(*this); auto& self = static_cast<const Derived&>(*this);
visit_all(result, args[0], args[2])([&](auto output, auto data, auto updates) { visit_all(result, args[0], args[2])([&](auto output, auto data, auto updates) {
std::copy(data.begin(), data.end(), output.begin()); std::copy(data.begin(), data.end(), output.begin());
...@@ -74,8 +119,8 @@ struct scatternd_op : op_name<Derived> ...@@ -74,8 +119,8 @@ struct scatternd_op : op_name<Derived>
auto updates_std = shape{updates_shape.type(), updates_shape.lens()}; auto updates_std = shape{updates_shape.type(), updates_shape.lens()};
auto indices_shape = indices.get_shape(); auto indices_shape = indices.get_shape();
auto k = indices_shape.lens().back(); auto k = indices_shape.lens().back();
auto q = indices_shape.lens().size(); auto q = indices_shape.ndim();
auto r = output_shape.lens().size(); auto r = dyn_out.computed_shape.ndim();
par_for(updates_shape.elements(), [&](const auto i) { par_for(updates_shape.elements(), [&](const auto i) {
auto updates_idx = updates_std.multi(i); auto updates_idx = updates_std.multi(i);
std::vector<std::size_t> indices_idx(q, 0); std::vector<std::size_t> indices_idx(q, 0);
...@@ -89,7 +134,7 @@ struct scatternd_op : op_name<Derived> ...@@ -89,7 +134,7 @@ struct scatternd_op : op_name<Derived>
std::copy(index_start, index_end, out_idx.begin()); std::copy(index_start, index_end, out_idx.begin());
std::copy(updates_idx.begin() + q - 1, updates_idx.end(), out_idx.begin() + k); std::copy(updates_idx.begin() + q - 1, updates_idx.end(), out_idx.begin() + k);
self.reduction()(output[output_shape.index(out_idx)], updates[i]); self.reduction()(output[dyn_out.computed_shape.index(out_idx)], updates[i]);
}); });
}); });
}); });
......
...@@ -5968,6 +5968,24 @@ def scatternd_test(): ...@@ -5968,6 +5968,24 @@ def scatternd_test():
return ([node], [data, indices, updates], [output]) return ([node], [data, indices, updates], [output])
@onnx_test()
def scatternd_dyn_test():
    """Build a ScatterND graph in which the first dimension of every tensor is left unspecified."""

    def value_info(name, elem_type, dims):
        # Thin local wrapper so the four tensor declarations read as a table.
        return helper.make_tensor_value_info(name, elem_type, dims)

    data = value_info('data', TensorProto.FLOAT, [None, 2, 2])
    indices = value_info('indices', TensorProto.INT64, [None, 1, 2])
    updates = value_info('updates', TensorProto.FLOAT, [None, 1, 2])
    output = value_info('output', TensorProto.FLOAT, [None, 2, 2])

    scatter_node = onnx.helper.make_node(
        'ScatterND',
        inputs=['data', 'indices', 'updates'],
        outputs=['output'],
    )

    return ([scatter_node], [data, indices, updates], [output])
@onnx_test() @onnx_test()
def selu_test(): def selu_test():
x = helper.make_tensor_value_info('x', TensorProto.DOUBLE, [2, 3]) x = helper.make_tensor_value_info('x', TensorProto.DOUBLE, [2, 3])
......
...@@ -5768,53 +5768,67 @@ TEST_CASE(scatter_none_test) ...@@ -5768,53 +5768,67 @@ TEST_CASE(scatter_none_test)
TEST_CASE(scatternd_test) TEST_CASE(scatternd_test)
{ {
{ migraphx::program p;
migraphx::program p; auto* mm = p.get_main_module();
auto* mm = p.get_main_module(); auto l0 = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}});
auto l0 = auto l1 = mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 1, 2}});
mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}}); auto l2 = mm->add_parameter("updates", migraphx::shape{migraphx::shape::float_type, {2, 1, 2}});
auto l1 = auto r = mm->add_instruction(migraphx::make_op("scatternd_none"), l0, l1, l2);
mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 1, 2}}); mm->add_return({r});
auto l2 = auto prog = migraphx::parse_onnx("scatternd_test.onnx");
mm->add_parameter("updates", migraphx::shape{migraphx::shape::float_type, {2, 1, 2}});
auto r = mm->add_instruction(migraphx::make_op("scatternd_none"), l0, l1, l2);
mm->add_return({r});
auto prog = migraphx::parse_onnx("scatternd_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
{ TEST_CASE(scatternd_dyn_test)
migraphx::program p; {
auto* mm = p.get_main_module(); // dynamic input.
auto l0 = migraphx::program p;
mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}}); auto* mm = p.get_main_module();
auto l1 = // parameters with dynamic dimensions
mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 1, 2}}); auto l0 = mm->add_parameter(
auto l2 = "data", migraphx::shape{migraphx::shape::float_type, {{1, 3, 2}, {2, 2}, {2, 2}}});
mm->add_parameter("updates", migraphx::shape{migraphx::shape::float_type, {2, 1, 2}}); auto l1 = mm->add_parameter(
auto r = mm->add_instruction(migraphx::make_op("scatternd_add"), l0, l1, l2); "indices", migraphx::shape{migraphx::shape::int64_type, {{2, 1, 2}, {1, 1}, {2, 2}}});
mm->add_return({r}); auto l2 = mm->add_parameter(
auto prog = migraphx::parse_onnx("scatternd_add_test.onnx"); "updates", migraphx::shape{migraphx::shape::float_type, {{2, 1, 2}, {1, 1}, {2, 2}}});
auto r = mm->add_instruction(migraphx::make_op("scatternd_none"), l0, l1, l2);
mm->add_return({r});
migraphx::onnx_options options;
options.map_dyn_input_dims["data"] = {{1, 3, 2}, {2, 2}, {2, 2}};
options.map_dyn_input_dims["indices"] = {{2, 1, 2}, {1, 1}, {2, 2}};
options.map_dyn_input_dims["updates"] = {{2, 1, 2}, {1, 1}, {2, 2}};
auto prog = migraphx::parse_onnx("scatternd_dyn_test.onnx", options);
EXPECT(p == prog); EXPECT(p == prog);
} }
{ TEST_CASE(scatternd_add_test)
migraphx::program p; {
auto* mm = p.get_main_module(); migraphx::program p;
auto l0 = auto* mm = p.get_main_module();
mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}}); auto l0 = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}});
auto l1 = auto l1 = mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 1, 2}});
mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 1, 2}}); auto l2 = mm->add_parameter("updates", migraphx::shape{migraphx::shape::float_type, {2, 1, 2}});
auto l2 = auto r = mm->add_instruction(migraphx::make_op("scatternd_add"), l0, l1, l2);
mm->add_parameter("updates", migraphx::shape{migraphx::shape::float_type, {2, 1, 2}}); mm->add_return({r});
auto r = mm->add_instruction(migraphx::make_op("scatternd_mul"), l0, l1, l2); auto prog = migraphx::parse_onnx("scatternd_add_test.onnx");
mm->add_return({r});
auto prog = migraphx::parse_onnx("scatternd_mul_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(scatternd_mul_test)
{
    // Hand-build the program expected from the ONNX file: three parameters feeding a
    // single scatternd_mul instruction whose result is returned.
    const migraphx::shape data_shape{migraphx::shape::float_type, {2, 2, 2}};
    const migraphx::shape indices_shape{migraphx::shape::int64_type, {2, 1, 2}};
    const migraphx::shape updates_shape{migraphx::shape::float_type, {2, 1, 2}};

    migraphx::program p;
    auto* mm     = p.get_main_module();
    auto data    = mm->add_parameter("data", data_shape);
    auto indices = mm->add_parameter("indices", indices_shape);
    auto updates = mm->add_parameter("updates", updates_shape);
    auto scatter = mm->add_instruction(migraphx::make_op("scatternd_mul"), data, indices, updates);
    mm->add_return({scatter});

    // The parsed model must be identical to the hand-built program.
    auto prog = migraphx::parse_onnx("scatternd_mul_test.onnx");

    EXPECT(p == prog);
}
TEST_CASE(selu_test) TEST_CASE(selu_test)
......
...@@ -2691,27 +2691,145 @@ TEST_CASE(test_gathernd_dynamic8) ...@@ -2691,27 +2691,145 @@ TEST_CASE(test_gathernd_dynamic8)
expect_shape(s0, migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), ds, is); expect_shape(s0, migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), ds, is);
} }
TEST_CASE(test_scatternd) TEST_CASE(test_scatternd0)
{ {
{ // good
// k > r auto dtype = migraphx::shape::float_type;
auto dtype = migraphx::shape::float_type; auto itype = migraphx::shape::int64_type;
auto itype = migraphx::shape::int64_type; migraphx::shape ds{dtype, {8}};
migraphx::shape ds{dtype, {8}}; migraphx::shape is{itype, {4, 1}};
migraphx::shape is{itype, {4, 2}}; migraphx::shape us{dtype, {4}};
migraphx::shape us{dtype, {4}}; expect_shape(ds, migraphx::make_op("scatternd_none"), ds, is, us);
throws_shape(migraphx::make_op("scatternd_none"), ds, is, us); }
}
{ TEST_CASE(test_scatternd1)
// update.lens != indices.lens[0:q-1] ++ data.lens[k:r-1] {
auto dtype = migraphx::shape::float_type; // good, broadcasted
auto itype = migraphx::shape::int64_type; auto dtype = migraphx::shape::float_type;
migraphx::shape ds{dtype, {8}}; auto itype = migraphx::shape::int64_type;
migraphx::shape is{itype, {4, 1}}; migraphx::shape ds{dtype, {8}};
migraphx::shape us{dtype, {2, 2}}; migraphx::shape is{itype, {4, 1}, {4, 0}};
throws_shape(migraphx::make_op("scatternd_none"), ds, is, us); migraphx::shape us{dtype, {4}};
} expect_shape(ds, migraphx::make_op("scatternd_none"), ds, is, us);
}
TEST_CASE(test_scatternd2)
{
    // Four input shapes are supplied where the operator accepts exactly three,
    // so shape computation must throw.
    const migraphx::shape data_shape{migraphx::shape::float_type, {8}};
    const migraphx::shape indices_shape{migraphx::shape::int64_type, {4, 1}};
    const migraphx::shape updates_shape{migraphx::shape::float_type, {4}};
    const migraphx::shape extra_shape{migraphx::shape::float_type, {4}};
    throws_shape(migraphx::make_op("scatternd_none"),
                 data_shape,
                 indices_shape,
                 updates_shape,
                 extra_shape);
}
TEST_CASE(test_scatternd3)
{
    // The rank relation holds (q + r - k - 1 equals the rank of updates), but the index
    // tuple length k = 2 is larger than the data rank r = 1, which must be rejected.
    const migraphx::shape data_shape{migraphx::shape::float_type, {8}};
    const migraphx::shape indices_shape{migraphx::shape::int64_type, {5, 4, 2}};
    const migraphx::shape updates_shape{migraphx::shape::float_type, {4}};
    throws_shape(migraphx::make_op("scatternd_none"), data_shape, indices_shape, updates_shape);
}
TEST_CASE(test_scatternd4)
{
    // Rank mismatch: q + r - k - 1 = 2 + 1 - 1 - 1 = 1, but updates has rank 2.
    const migraphx::shape data_shape{migraphx::shape::float_type, {8}};
    const migraphx::shape indices_shape{migraphx::shape::int64_type, {4, 1}};
    const migraphx::shape updates_shape{migraphx::shape::float_type, {2, 2}};
    throws_shape(migraphx::make_op("scatternd_none"), data_shape, indices_shape, updates_shape);
}
TEST_CASE(test_scatternd5)
{
    // Ranks are consistent, but the leading dimension of updates (2) differs from the
    // leading dimension of indices (4): update.lens != indices.lens[0:q-1].
    const migraphx::shape data_shape{migraphx::shape::float_type, {8, 3}};
    const migraphx::shape indices_shape{migraphx::shape::int64_type, {4, 1}};
    const migraphx::shape updates_shape{migraphx::shape::float_type, {2, 2}};
    throws_shape(migraphx::make_op("scatternd_none"), data_shape, indices_shape, updates_shape);
}
TEST_CASE(test_scatternd_dyn0)
{
    // Only the updates input is dynamic. The indices tensor declares 13-element index
    // tuples against rank-1 data, which is invalid, so shape computation must throw.
    const migraphx::shape data_shape{migraphx::shape::float_type, {4}};
    const migraphx::shape indices_shape{migraphx::shape::int64_type, {4, 13}};
    const migraphx::shape::dynamic_dimension fixed_four{4, 4, 0};
    const migraphx::shape updates_shape{migraphx::shape::float_type, {fixed_four}};
    throws_shape(migraphx::make_op("scatternd_none"), data_shape, indices_shape, updates_shape);
}
TEST_CASE(test_scatternd_dyn1)
{
    // Only the updates input is dynamic (with a fixed 4..4 dimension); the result is
    // expected to be the static data shape.
    const migraphx::shape data_shape{migraphx::shape::float_type, {8}};
    const migraphx::shape indices_shape{migraphx::shape::int64_type, {4, 1}};
    const migraphx::shape::dynamic_dimension fixed_four{4, 4, 0};
    const migraphx::shape updates_shape{migraphx::shape::float_type, {fixed_four}};
    expect_shape(data_shape,
                 migraphx::make_op("scatternd_none"),
                 data_shape,
                 indices_shape,
                 updates_shape);
}
TEST_CASE(test_scatternd_dyn2)
{
    // Dynamic updates combined with a data shape that carries explicit (non-standard)
    // strides; the expected output has the same lens with default strides.
    const migraphx::shape strided_data{migraphx::shape::float_type, {2, 3, 1, 4}, {0, 1, 1, 0}};
    const migraphx::shape standard_data{migraphx::shape::float_type, {2, 3, 1, 4}};
    const migraphx::shape indices_shape{migraphx::shape::int64_type, {4, 4}};
    const migraphx::shape::dynamic_dimension fixed_four{4, 4, 0};
    const migraphx::shape updates_shape{migraphx::shape::float_type, {fixed_four}};
    expect_shape(standard_data,
                 migraphx::make_op("scatternd_none"),
                 strided_data,
                 indices_shape,
                 updates_shape);
}
TEST_CASE(test_scatternd_dyn3)
{
    // Dynamic updates with static data in default layout; the output is expected to
    // equal the data shape.
    const migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 1, 4}};
    const migraphx::shape indices_shape{migraphx::shape::int64_type, {4, 4}};
    const migraphx::shape::dynamic_dimension fixed_four{4, 4, 0};
    const migraphx::shape updates_shape{migraphx::shape::float_type, {fixed_four}};
    expect_shape(data_shape,
                 migraphx::make_op("scatternd_none"),
                 data_shape,
                 indices_shape,
                 updates_shape);
}
TEST_CASE(test_scatternd_dyn4)
{
    // The indices tensor is dynamic and its last dimension ranges over 4..5 (not fixed),
    // which the operator rejects.
    const migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 1, 4}};
    const migraphx::shape::dynamic_dimension four_to_five{4, 5, 0};
    const migraphx::shape indices_shape{migraphx::shape::int64_type, {four_to_five, four_to_five}};
    const migraphx::shape updates_shape{migraphx::shape::float_type, {four_to_five}};
    throws_shape(migraphx::make_op("scatternd_none"), data_shape, indices_shape, updates_shape);
}
TEST_CASE(test_scatternd_dyn5)
{
    // Both indices and updates are dynamic, but the leading dimension of updates (2..3)
    // does not match that of indices (4..4): update.lens != indices.lens[0:q-1].
    const migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 1, 4}};
    const migraphx::shape::dynamic_dimension fixed_four{4, 4, 0};
    const migraphx::shape::dynamic_dimension mismatched{2, 3, 0};
    const migraphx::shape indices_shape{migraphx::shape::int64_type, {fixed_four, fixed_four}};
    const migraphx::shape updates_shape{migraphx::shape::float_type, {mismatched}};
    throws_shape(migraphx::make_op("scatternd_none"), data_shape, indices_shape, updates_shape);
}
TEST_CASE(test_squeeze) TEST_CASE(test_squeeze)
......
...@@ -7242,6 +7242,51 @@ TEST_CASE(scatternd_reduction_test) ...@@ -7242,6 +7242,51 @@ TEST_CASE(scatternd_reduction_test)
} }
} }
TEST_CASE(scatternd_reduction_dyn_test)
{
    // reduction = add, with dynamic input shapes.
    // The program is compiled with dynamic data/updates shapes (each dimension 3..6) and then
    // evaluated with concrete 4x4x4 data and 2x4x4 updates.
    migraphx::program p;
    auto* mm   = p.get_main_module();
    auto dtype = migraphx::shape::float_type;
    auto itype = migraphx::shape::int64_type;
    migraphx::shape::dynamic_dimension dd{3, 6, 0};
    migraphx::shape ds{dtype, {dd, dd, dd}};
    migraphx::shape is{itype, {2, 1}};
    migraphx::shape us{dtype, {{2, 2, 0}, dd, dd}};

    auto xdata    = mm->add_parameter("X", ds);
    auto xindex   = mm->add_parameter("I", is);
    auto xupdates = mm->add_parameter("U", us);

    auto scatternd_add_op = migraphx::make_op("scatternd_add");
    auto scatternd        = mm->add_instruction(scatternd_add_op, xdata, xindex, xupdates);
    mm->add_return({scatternd});
    p.compile(migraphx::ref::target{});

    migraphx::parameter_map params;
    migraphx::shape input_fixed_shape0{dtype, {4, 4, 4}}; // data
    std::vector<float> input_data{1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6,
                                  7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4,
                                  5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8};
    // Element type matches the int64_type declared for the index shape `is`.
    // Each index selects a 4x4 slice along the first axis of the data: slices 0 and 2.
    std::vector<int64_t> input_index{0, 2};
    migraphx::shape input_fixed_shape1{dtype, {2, 4, 4}}; // updates
    std::vector<float> input_updates{5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
                                     1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4};

    params["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
    params["I"] = migraphx::argument(is, input_index.data());
    params["U"] = migraphx::argument(input_fixed_shape1, input_updates.data());

    auto result = p.eval(params).back();
    std::vector<float> results_vector;
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    // Slices 0 and 2 are data + updates[0] and data + updates[1] respectively;
    // slices 1 and 3 are the untouched data.
    std::vector<float> gold{6, 7, 8, 9, 11, 12, 13, 14, 15, 14, 13, 12, 12, 11, 10, 9,
                            1, 2, 3, 4, 5,  6,  7,  8,  8,  7,  6,  5,  4,  3,  2,  1,
                            9, 8, 7, 6, 6,  5,  4,  3,  4,  5,  6,  7,  9,  10, 11, 12,
                            8, 7, 6, 5, 4,  3,  2,  1,  1,  2,  3,  4,  5,  6,  7,  8};

    EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(sigmoid_test) TEST_CASE(sigmoid_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment