Unverified commit 832f28c6, authored by turneram, committed by GitHub

Add ScatterND operator (#1074)

Add the ONNX parser plus reference and GPU implementations of the ONNX ScatterND operator.
parent bfedcd45
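For readers new to the op: below is a minimal numpy sketch of ScatterND semantics, mirroring the reference pseudocode in the ONNX operator spec. The three `reduction` modes correspond to the scatternd_none / scatternd_add / scatternd_mul operators introduced in this change; the helper name itself is illustrative, not part of the patch.

```python
import numpy as np

def scatternd_ref(data, indices, updates, reduction="none"):
    # output starts as a copy of data; each k-tuple indices[idx] addresses
    # a slice of output that receives (or accumulates) updates[idx].
    output = np.copy(data)
    for idx in np.ndindex(indices.shape[:-1]):
        target = tuple(indices[idx])
        if reduction == "add":
            output[target] += updates[idx]
        elif reduction == "mul":
            output[target] *= updates[idx]
        else:
            output[target] = updates[idx]
    return output
```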
@@ -162,6 +162,9 @@ register_migraphx_ops(
rsqrt
scalar
scatter
scatternd_add
scatternd_mul
scatternd_none
sigmoid
sign
sinh
......
@@ -10,8 +10,14 @@ inline namespace MIGRAPHX_INLINE_NS {
void eliminate_data_type::apply(module& m) const
{
    static const std::vector<std::string> skip_op_names = {"convert",
                                                           "get_tuple_elem",
                                                           "if",
                                                           "loop",
                                                           "roialign",
                                                           "scatternd_add",
                                                           "scatternd_mul",
                                                           "scatternd_none"};
    for(auto ins : iterator_for(m))
    {
        if(ins->name()[0] == '@')
......
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_ADD_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_ADD_HPP
#include <migraphx/op/scatternd_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct scatternd_add : scatternd_op<scatternd_add>
{
scatternd_add() {}
auto reduction() const
{
return [](auto& x, const auto& y) { x += y; };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_MUL_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_MUL_HPP
#include <migraphx/op/scatternd_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct scatternd_mul : scatternd_op<scatternd_mul>
{
scatternd_mul() {}
auto reduction() const
{
return [](auto& x, const auto& y) { x *= y; };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_NONE_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_NONE_HPP
#include <migraphx/op/scatternd_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct scatternd_none : scatternd_op<scatternd_none>
{
scatternd_none() {}
auto reduction() const
{
return [](auto& x, const auto& y) { x = y; };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_OP_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_OP_HPP
#include <migraphx/op/name.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/par_for.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
template <class Derived>
struct scatternd_op : op_name<Derived>
{
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(3);
auto r = inputs.front().lens().size();
auto q = inputs.at(1).lens().size();
auto k = inputs.at(1).lens().back();
auto ind_lens = inputs.at(1).lens();
auto upd_lens = inputs.back().lens();
auto data_lens = inputs.front().lens();
if(k > r)
MIGRAPHX_THROW("ScatterND: index of size " + std::to_string(k) +
" is too large for tensor of rank " + std::to_string(r));
if(not(std::equal(ind_lens.begin(), ind_lens.begin() + q - 1, upd_lens.begin()) and
std::equal(data_lens.begin() + k, data_lens.end(), upd_lens.begin() + q - 1)))
MIGRAPHX_THROW("ScatterND: incorrect update shape. update.lens != indices.lens[0:q-1] "
"++ data.lens[k:r-1]");
auto s = inputs.front();
if(s.broadcasted())
{
return {s.type(), s.lens()};
}
else
{
return s.with_lens(s.lens());
}
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto& self = static_cast<const Derived&>(*this);
visit_all(result, args[0], args[2])([&](auto output, auto data, auto updates) {
std::copy(data.begin(), data.end(), output.begin());
args[1].visit([&](auto indices) {
auto updates_shape = updates.get_shape();
auto updates_std = shape{updates_shape.type(), updates_shape.lens()};
auto indices_shape = indices.get_shape();
auto k = indices_shape.lens().back();
auto q = indices_shape.lens().size();
auto r = output_shape.lens().size();
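            // For each update element i: the first q-1 coordinates of its
            // multi-index select one k-tuple in `indices`; that k-tuple
            // supplies the leading k output coordinates, and the remaining
            // update coordinates fill the trailing r-k output dimensions.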
par_for(updates_shape.elements(), [&](const auto i) {
auto updates_idx = updates_std.multi(i);
std::vector<std::size_t> indices_idx(q, 0);
std::copy(
updates_idx.begin(), updates_idx.begin() + q - 1, indices_idx.begin());
auto index_start = indices.begin() +
indices_shape.index(indices_idx.begin(), indices_idx.end());
auto index_end = index_start + k;
std::vector<std::size_t> out_idx(r, 0);
std::copy(index_start, index_end, out_idx.begin());
std::copy(updates_idx.begin() + q - 1, updates_idx.end(), out_idx.begin() + k);
self.reduction()(output[output_shape.index(out_idx)], updates[i]);
});
});
});
return result;
}
auto init() const {}
scatternd_op() {}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
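A quick sanity check on the shape rule compute_shape enforces (updates.lens == indices.lens[0:q-1] ++ data.lens[k:r], with k = indices.lens.back() and k <= r). The helper below is a hypothetical illustration, not part of MIGraphX:

```python
def expected_updates_lens(data_lens, indices_lens):
    # k index components address the leading k dims of data; the update
    # shape is then indices.lens[0:q-1] ++ data.lens[k:r].
    k = indices_lens[-1]
    assert k <= len(data_lens), "ScatterND: k must not exceed data rank r"
    return list(indices_lens[:-1]) + list(data_lens[k:])

# Shapes used by the ONNX tests below: data [2,2,2], indices [2,1,2] (k=2),
# so updates must be [2,1] ++ [2] = [2,1,2].
assert expected_updates_lens([2, 2, 2], [2, 1, 2]) == [2, 1, 2]
```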
@@ -87,6 +87,9 @@
#include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp>
#include <migraphx/op/scatter.hpp>
#include <migraphx/op/scatternd_add.hpp>
#include <migraphx/op/scatternd_mul.hpp>
#include <migraphx/op/scatternd_none.hpp>
#include <migraphx/op/sigmoid.hpp>
#include <migraphx/op/sign.hpp>
#include <migraphx/op/sinh.hpp>
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_scatternd : op_parser<parse_scatternd>
{
std::vector<op_desc> operators() const { return {{"ScatterND"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref>& args) const
{
if(contains(info.attributes, "reduction"))
{
if(info.attributes.at("reduction").s() == "add")
return info.add_instruction(migraphx::make_op("scatternd_add"), args);
if(info.attributes.at("reduction").s() == "mul")
return info.add_instruction(migraphx::make_op("scatternd_mul"), args);
}
return info.add_instruction(migraphx::make_op("scatternd_none"), args);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
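An end-to-end sketch of the attribute mapping the parser implements: build a ScatterND node with reduction="mul", save it, and parse it back, expecting a scatternd_mul instruction. The file name is illustrative, and the snippet assumes MIGraphX was built with its Python bindings:

```python
import onnx
from onnx import helper, TensorProto
import migraphx  # assumes the migraphx Python module is available

data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 2, 2])
indices = helper.make_tensor_value_info('indices', TensorProto.INT64, [2, 1, 2])
updates = helper.make_tensor_value_info('updates', TensorProto.FLOAT, [2, 1, 2])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [2, 2, 2])
node = helper.make_node('ScatterND', ['data', 'indices', 'updates'],
                        ['output'], reduction="mul")
graph = helper.make_graph([node], 'scatternd_mul_example',
                          [data, indices, updates], [output])
onnx.save(helper.make_model(graph), 'scatternd_mul_example.onnx')

prog = migraphx.parse_onnx('scatternd_mul_example.onnx')
print(prog)  # should contain a scatternd_mul instruction
```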
@@ -133,6 +133,7 @@ add_library(migraphx_gpu
compile_hip_code_object.cpp
compile_pointwise.cpp
compile_roialign.cpp
compile_scatternd.cpp
concat.cpp
convert.cpp
convolution.cpp
......
#include <migraphx/gpu/compile_scatternd.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// NOLINTNEXTLINE
static const char* const scatternd_kernel = R"__migraphx__(
#include <migraphx/kernels/scatternd.hpp>
#include <migraphx/kernels/basic_ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void scatternd_kernel(void* in_indices, void* in_updates, void* output)
{
make_tensors()(in_indices, in_updates, output)([](auto&&... xs) {
scatternd(xs..., REDUCTION);
});
}
}
} // namespace migraphx
int main() {}
)__migraphx__";
operation
compile_scatternd(context&, const std::vector<shape>& io_shapes, const std::string& reduction)
{
hip_compile_options options;
auto out_s = io_shapes.back();
options.local = 1024;
options.global = compute_global(io_shapes.at(1).elements(), options.local);
options.inputs = io_shapes;
options.output = out_s;
options.kernel_name = "scatternd_kernel";
options.virtual_inputs = io_shapes;
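    // REDUCTION is substituted when the kernel source string above is
    // compiled, e.g. -DREDUCTION=assign_add{} instantiates scatternd()
    // with the assign_add functor from kernels/scatternd.hpp.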
options.params += " -DREDUCTION=assign_" + reduction + "{}";
return compile_hip_code_object(scatternd_kernel, options);
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_GPU_COMPILE_SCATTERND_HPP
#define MIGRAPHX_GUARD_GPU_COMPILE_SCATTERND_HPP
#include <migraphx/config.hpp>
#include <migraphx/operation.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
operation
compile_scatternd(context& ctx, const std::vector<shape>& io_shapes, const std::string& reduction);
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_COMPILE_SCATTERND_HPP
@@ -21,6 +21,16 @@ struct greater
    }
};
template <class InputIt, class OutputIt>
constexpr OutputIt copy(InputIt first, InputIt last, OutputIt d_first)
{
while(first != last)
{
*d_first++ = *first++;
}
return d_first;
}
template <class Iterator, class Compare>
constexpr Iterator is_sorted_until(Iterator first, Iterator last, Compare comp)
{
......
#ifndef MIGRAPHX_GUARD_KERNELS_SCATTERND_HPP
#define MIGRAPHX_GUARD_KERNELS_SCATTERND_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
namespace migraphx {
struct assign_none
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
x = y;
}
};
struct assign_add
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
x += y;
}
};
struct assign_mul
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
x *= y;
}
};
template <class T, class U, class V, class F>
__device__ void scatternd(const T& indices_t, const U& updates_t, const V& output_t, F f)
{
auto index = make_index();
auto updates_shape = updates_t.get_shape();
index.global_stride(updates_shape.elements(), [&](auto i) {
auto output_shape = output_t.get_shape();
auto indices_shape = indices_t.get_shape();
auto k = indices_shape.lens.back();
auto q = indices_shape.lens.size();
auto updates_idx = updates_shape.multi(i);
auto indices_idx = indices_shape.multi(0);
copy(updates_idx.begin(), updates_idx.begin() + q - 1, indices_idx.begin());
auto index_start = indices_t.begin() + indices_shape.index(indices_idx);
auto index_end = index_start + k;
auto out_idx = output_shape.multi(0);
copy(index_start, index_end, out_idx.begin());
copy(updates_idx.begin() + q - 1, updates_idx.end(), out_idx.begin() + k);
f(output_t[out_idx], updates_t[i]);
});
}
} // namespace migraphx
#endif
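To make the kernel's index arithmetic concrete, here is a hypothetical trace for the shapes used throughout the tests (data [2,2,2], indices [2,1,2], updates [2,1,2], so r=3, q=3, k=2), written in plain Python rather than the device code above:

```python
import itertools
import numpy as np

indices = np.array([[[0, 0]], [[1, 1]]])  # shape (2, 1, 2)
upd_lens = (2, 1, 2)
k = indices.shape[-1]  # 2
q = indices.ndim       # 3

# Row-major enumeration of update elements, matching shape::multi(i).
for i, updates_idx in enumerate(itertools.product(*map(range, upd_lens))):
    indices_idx = updates_idx[:q - 1]        # selects one k-tuple in indices
    leading = tuple(int(v) for v in indices[indices_idx])
    out_idx = leading + updates_idx[q - 1:]  # trailing dims come from updates
    print(f"updates element {i} -> output{out_idx}")
# -> (0,0,0), (0,0,1), (1,1,0), (1,1,1)
```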
@@ -21,6 +21,7 @@
#include <migraphx/gpu/abs.hpp>
#include <migraphx/gpu/batch_norm_inference.hpp>
#include <migraphx/gpu/compile_roialign.hpp>
#include <migraphx/gpu/compile_scatternd.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/deconvolution.hpp>
@@ -196,6 +197,7 @@ struct miopen_apply
add_nms_op();
add_quant_convolution_op();
add_roialign();
add_scatternd();
}

void copy_params()
@@ -503,6 +505,60 @@ struct miopen_apply
});
}
void add_scatternd()
{
    // All three reduction variants share one lowering path; only the
    // reduction string passed to the kernel compiler differs. The data
    // input is first copied into the output allocation with hip::copy,
    // then the kernel scatters the updates into that copy in place.
    for(const auto& reduction : {"none", "add", "mul"})
    {
        apply_map.emplace(
            "scatternd_" + std::string{reduction}, [=](instruction_ref ins) {
                auto s      = ins->get_shape();
                auto output = insert_allocation(ins, s);
                auto args   = ins->inputs();
                args.push_back(output);
                auto io_shapes = to_shapes(args);
                io_shapes.erase(io_shapes.begin());
                auto co   = compile_scatternd(get_context(), io_shapes, reduction);
                auto copy = mod->insert_instruction(ins, make_op("hip::copy"), args.front(), output);
                args.back() = copy;
                args.erase(args.begin());
                return mod->replace_instruction(ins, co, args);
            });
    }
}
// replace the loop operator with gpu_loop operator
void add_loop_op()
{
......
@@ -4152,6 +4152,59 @@ def scatter_test():
    return ([node], [x, i, u], [y])
@onnx_test
def scatternd_add_test():
data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 2, 2])
indices = helper.make_tensor_value_info('indices', TensorProto.INT64,
[2, 1, 2])
updates = helper.make_tensor_value_info('updates', TensorProto.FLOAT,
[2, 1, 2])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT,
[2, 2, 2])
node = onnx.helper.make_node('ScatterND',
inputs=['data', 'indices', 'updates'],
outputs=['output'],
reduction="add")
return ([node], [data, indices, updates], [output])
@onnx_test
def scatternd_mul_test():
data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 2, 2])
indices = helper.make_tensor_value_info('indices', TensorProto.INT64,
[2, 1, 2])
updates = helper.make_tensor_value_info('updates', TensorProto.FLOAT,
[2, 1, 2])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT,
[2, 2, 2])
node = onnx.helper.make_node('ScatterND',
inputs=['data', 'indices', 'updates'],
outputs=['output'],
reduction="mul")
return ([node], [data, indices, updates], [output])
@onnx_test
def scatternd_test():
data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 2, 2])
indices = helper.make_tensor_value_info('indices', TensorProto.INT64,
[2, 1, 2])
updates = helper.make_tensor_value_info('updates', TensorProto.FLOAT,
[2, 1, 2])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT,
[2, 2, 2])
node = onnx.helper.make_node('ScatterND',
inputs=['data', 'indices', 'updates'],
outputs=['output'])
return ([node], [data, indices, updates], [output])
@onnx_test
def selu_test():
    x = helper.make_tensor_value_info('x', TensorProto.DOUBLE, [2, 3])
......
@@ -3996,6 +3996,57 @@ TEST_CASE(scatter_test)
    EXPECT(p == prog);
}
TEST_CASE(scatternd_test)
{
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 =
mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}});
auto l1 =
mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 1, 2}});
auto l2 =
mm->add_parameter("updates", migraphx::shape{migraphx::shape::float_type, {2, 1, 2}});
auto r = mm->add_instruction(migraphx::make_op("scatternd_none"), l0, l1, l2);
mm->add_return({r});
auto prog = migraphx::parse_onnx("scatternd_test.onnx");
EXPECT(p == prog);
}
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 =
mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}});
auto l1 =
mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 1, 2}});
auto l2 =
mm->add_parameter("updates", migraphx::shape{migraphx::shape::float_type, {2, 1, 2}});
auto r = mm->add_instruction(migraphx::make_op("scatternd_add"), l0, l1, l2);
mm->add_return({r});
auto prog = migraphx::parse_onnx("scatternd_add_test.onnx");
EXPECT(p == prog);
}
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 =
mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}});
auto l1 =
mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 1, 2}});
auto l2 =
mm->add_parameter("updates", migraphx::shape{migraphx::shape::float_type, {2, 1, 2}});
auto r = mm->add_instruction(migraphx::make_op("scatternd_mul"), l0, l1, l2);
mm->add_return({r});
auto prog = migraphx::parse_onnx("scatternd_mul_test.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(selu_test)
{
    migraphx::program p;
......
scatternd_add_test.onnx: new binary ONNX file (ScatterND node with reduction="add"; data float [2,2,2], indices int64 [2,1,2], updates float [2,1,2], output float [2,2,2]).
scatternd_mul_test.onnx: new binary ONNX file (ScatterND node with reduction="mul"; same tensor shapes).
scatternd_test.onnx: new binary ONNX file (ScatterND node without a reduction attribute; same tensor shapes).
@@ -1432,6 +1432,29 @@ TEST_CASE(test_scalar_nelemnts)
    throws_shape(migraphx::make_op("scalar", {{"scalar_bcst_dims", {2, 3, 4, 5}}}), input);
}
TEST_CASE(test_scatternd)
{
{
// k > r
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {8}};
migraphx::shape is{itype, {4, 2}};
migraphx::shape us{dtype, {4}};
throws_shape(migraphx::make_op("scatternd_none"), ds, is, us);
}
{
// update.lens != indices.lens[0:q-1] ++ data.lens[k:r-1]
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {8}};
migraphx::shape is{itype, {4, 1}};
migraphx::shape us{dtype, {2, 2}};
throws_shape(migraphx::make_op("scatternd_none"), ds, is, us);
}
}
TEST_CASE(test_squeeze)
{
    migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
......