Unverified Commit 701c2014 authored by bpickrel's avatar bpickrel Committed by GitHub
Browse files

scatter operator refactoring to include reduction (#1124)

Change the "scatter" struct and op to a base/child set of three: scatter_none, scatter_add, scatter_mul to mirror Onnx' ScatterElements op. and its three reduction options. (Onnx Scatter op is deprecated and is equivalent to scatter_none.)

Provides both a reference op. and update to Onnx parsing. Tests updated and new test case added.
parent 3c301efa
...@@ -162,10 +162,12 @@ register_migraphx_ops( ...@@ -162,10 +162,12 @@ register_migraphx_ops(
round round
rsqrt rsqrt
scalar scalar
scatter scatter_add
scatternd_none scatter_mul
scatter_none
scatternd_add scatternd_add
scatternd_mul scatternd_mul
scatternd_none
sigmoid sigmoid
sign sign
sinh sinh
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/name.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -16,7 +17,17 @@ namespace migraphx { ...@@ -16,7 +17,17 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct scatter // The scatter operator fetches a subset of data given by an index array and then performs a
// reduction operation (add, multiply, or just set the data) on each element returned. We implement
// it as a separate derived struct for each of the three reduction methods. The related operator
// scatterND is a generalization that works on a set of 3 tensors of different ranks. The
// complementary operations are gather/gatherND.
//
// This is a template for deriving child structs from. Each child needs to define
// only a reduction() method. Names are automatically handled by the op_name template.
template <class Derived>
struct scatter : op_name<Derived>
{ {
int64_t axis = 0; int64_t axis = 0;
...@@ -33,29 +44,44 @@ struct scatter ...@@ -33,29 +44,44 @@ struct scatter
return {{"normalize_axes", normalize}}; return {{"normalize_axes", normalize}};
} }
std::string name() const { return "scatter"; }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(3).standard(); check_shapes{inputs, *this}.has(3).standard();
return inputs.front(); // If non-packed, this converts to a packed output while preserving permutation of tensor
return inputs.front().with_lens(inputs.front().lens());
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
// max dimension in axis auto& self = static_cast<const Derived&>(*this);
// max dimension in each axis
auto axis_dim_size = output_shape.lens()[axis]; auto axis_dim_size = output_shape.lens()[axis];
// cast all arguments as correct type
visit_all(result, args[0], args[2])([&](auto output, auto data, auto update) { visit_all(result, args[0], args[2])([&](auto output, auto data, auto update) {
// copy all of data to output
std::copy(data.begin(), data.end(), output.begin()); std::copy(data.begin(), data.end(), output.begin());
args[1].visit([&](auto indices) { args[1].visit([&](auto indices) {
auto ind_s = indices.get_shape(); auto ind_s = indices.get_shape();
// iterate through items in shape
shape_for_each(ind_s, [&](const auto& idx) { shape_for_each(ind_s, [&](const auto& idx) {
auto out_idx = idx; auto out_idx = idx;
auto index = indices[ind_s.index(idx)];
// Overloaded tensor_view::() invokes indexing logic of
// std::size_t shape::index(std::size_t i) const
// which handles nonstandard shapes correctly
auto index = indices(idx.begin(), idx.end());
// normalize negative indexes (may be redundant after using
// normalize_compute_shape())
index = (index < 0) ? index + axis_dim_size : index; index = (index < 0) ? index + axis_dim_size : index;
out_idx[axis] = index; out_idx[axis] = index;
output[output_shape.index(out_idx)] = update[ind_s.index(idx)];
// look up the appropriate locations in output, using idx and out_idx.
// call reduction() method of derived struct to copy and reduce that element
self.reduction()(output(out_idx.begin(), out_idx.end()),
update(idx.begin(), idx.end()));
}); });
}); });
}); });
......
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTER_ADD_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTER_ADD_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
#include <migraphx/op/scatter.hpp>
// Scatter op. with "add" function as reduction.
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct scatter_add : scatter<scatter_add>
{
// reduction (pointwise operation) is called by the parent struct's compute() method.
// It works much like a virtual function overload.
// For the scatter methods, there are three different reduction functions.
auto reduction() const
{
return [](auto& x, const auto& y) { x += y; };
}
// name of this struct is automatically assigned by the op_name<>
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTER_MUL_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTER_MUL_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
#include <migraphx/op/scatter.hpp>
// Scatter op. with "multiply" as the reduction function.
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct scatter_mul : scatter<scatter_mul>
{
// reduction (pointwise operation) is called by the parent struct's compute() method.
// It works much like a virtual function overload.
// For the scatter operators, there are three different reduction functions.
auto reduction() const
{
return [](auto& x, const auto& y) { x *= y; };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTER_NONE_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTER_NONE_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/scatter.hpp>
#include <cmath>
#include <utility>
// Scatter op. with "none" as the reduction function (just copies the value). This is identical to
// the previously existing Scatter op.
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct scatter_none : scatter<scatter_none>
{
// reduction (pointwise operation) is called by the parent struct's compute() method.
// It works much like a virtual function overload.
// For the scatter operators, there are three different reduction functions.
auto reduction() const
{
return [](auto& x, const auto& y) { x = y; };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -86,7 +86,9 @@ ...@@ -86,7 +86,9 @@
#include <migraphx/op/round.hpp> #include <migraphx/op/round.hpp>
#include <migraphx/op/rsqrt.hpp> #include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp> #include <migraphx/op/scalar.hpp>
#include <migraphx/op/scatter.hpp> #include <migraphx/op/scatter_add.hpp>
#include <migraphx/op/scatter_mul.hpp>
#include <migraphx/op/scatter_none.hpp>
#include <migraphx/op/scatternd_add.hpp> #include <migraphx/op/scatternd_add.hpp>
#include <migraphx/op/scatternd_none.hpp> #include <migraphx/op/scatternd_none.hpp>
#include <migraphx/op/scatternd_mul.hpp> #include <migraphx/op/scatternd_mul.hpp>
......
...@@ -10,6 +10,7 @@ struct parse_generic_op : op_parser<parse_generic_op> ...@@ -10,6 +10,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{ {
std::vector<op_desc> operators() const std::vector<op_desc> operators() const
{ {
// clang-format off
return {{"Abs", "abs"}, return {{"Abs", "abs"},
{"Acos", "acos"}, {"Acos", "acos"},
{"Acosh", "acosh"}, {"Acosh", "acosh"},
...@@ -37,8 +38,6 @@ struct parse_generic_op : op_parser<parse_generic_op> ...@@ -37,8 +38,6 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Reciprocal", "recip"}, {"Reciprocal", "recip"},
{"Relu", "relu"}, {"Relu", "relu"},
{"Round", "round"}, {"Round", "round"},
{"Scatter", "scatter"},
{"ScatterElements", "scatter"},
{"Sigmoid", "sigmoid"}, {"Sigmoid", "sigmoid"},
{"Sign", "sign"}, {"Sign", "sign"},
{"Sin", "sin"}, {"Sin", "sin"},
...@@ -47,6 +46,7 @@ struct parse_generic_op : op_parser<parse_generic_op> ...@@ -47,6 +46,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Tan", "tan"}, {"Tan", "tan"},
{"Tanh", "tanh"}, {"Tanh", "tanh"},
{"Not", "not"}}; {"Not", "not"}};
// clang-format on
} }
bool needs_contiguous(const std::string& op_name) const bool needs_contiguous(const std::string& op_name) const
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_scatter : op_parser<parse_scatter>
{
std::vector<op_desc> operators() const { return {{"ScatterElements"}, {"Scatter"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
const std::vector<instruction_ref>& args) const
{
operation op;
std::string op_name = "scatter_none";
int axis = 0;
if(contains(info.attributes, "axis"))
axis = info.attributes.at("axis").i();
if(contains(info.attributes, "reduction"))
{
std::string reduction_att(info.attributes.at("reduction").s());
// check for a valid reduction attribute. We have an operator for each one.
if(not contains({"none", "add", "mul"}, reduction_att))
MIGRAPHX_THROW("PARSE_SCATTER: unsupported reduction mode " + reduction_att);
// merge scatter with reduction attribute to specify which scatter operation. Future
// reduction op names should follow this pattern and should also be added to the check
// above.
op_name = std::string("scatter_") + reduction_att;
}
op = migraphx::make_op(op_name, {{"axis", axis}});
return info.add_instruction(op, args);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp> #include <migraphx/reflect.hpp>
#include <migraphx/op/scatter.hpp> #include <migraphx/op/scatter_none.hpp>
#include <migraphx/gpu/miopen.hpp> #include <migraphx/gpu/miopen.hpp>
namespace migraphx { namespace migraphx {
...@@ -14,7 +14,9 @@ struct context; ...@@ -14,7 +14,9 @@ struct context;
struct hip_scatter struct hip_scatter
{ {
op::scatter op; // scatter_none is an exact replacement for previous op::scatter,
// renamed to match an Onnx option. Don't use base class op::scatter
op::scatter_none op;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -22,7 +24,7 @@ struct hip_scatter ...@@ -22,7 +24,7 @@ struct hip_scatter
return migraphx::reflect(self.op, f); return migraphx::reflect(self.op, f);
} }
std::string name() const { return "gpu::scatter"; } std::string name() const { return "gpu::scatter_none"; }
shape compute_shape(std::vector<shape> inputs) const; shape compute_shape(std::vector<shape> inputs) const;
argument argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const; compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
......
...@@ -190,7 +190,7 @@ struct miopen_apply ...@@ -190,7 +190,7 @@ struct miopen_apply
add_extend_op("rnn_var_sl_last_output"); add_extend_op("rnn_var_sl_last_output");
add_extend_op("rnn_var_sl_shift_output"); add_extend_op("rnn_var_sl_shift_output");
add_extend_op("rnn_var_sl_shift_sequence"); add_extend_op("rnn_var_sl_shift_sequence");
add_extend_op("scatter"); add_extend_op("scatter_none");
add_extend_op("softmax"); add_extend_op("softmax");
add_extend_op("topk"); add_extend_op("topk");
...@@ -381,6 +381,9 @@ struct miopen_apply ...@@ -381,6 +381,9 @@ struct miopen_apply
}); });
} }
// add_generic_op just constructs the operator with no fields whereas add_extend_op copies over
// the fields Since it doesn't have fields its default constructed
void add_generic_op(const std::string& name) { add_generic_op(name, "gpu::" + name); } void add_generic_op(const std::string& name) { add_generic_op(name, "gpu::" + name); }
void add_generic_op(const std::string& op_name, const std::string& gpu_name) void add_generic_op(const std::string& op_name, const std::string& gpu_name)
......
...@@ -4381,7 +4381,7 @@ def roialign_test(): ...@@ -4381,7 +4381,7 @@ def roialign_test():
@onnx_test @onnx_test
def scatter_test(): def scatter_add_test():
x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 4, 5, 6]) x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 4, 5, 6])
i = helper.make_tensor_value_info('indices', TensorProto.INT32, i = helper.make_tensor_value_info('indices', TensorProto.INT32,
[2, 3, 4, 5]) [2, 3, 4, 5])
...@@ -4390,7 +4390,48 @@ def scatter_test(): ...@@ -4390,7 +4390,48 @@ def scatter_test():
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 5, 6]) y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 5, 6])
node = onnx.helper.make_node( node = onnx.helper.make_node(
'Scatter', 'ScatterElements',
reduction='add',
inputs=['data', 'indices', 'update'],
outputs=['y'],
axis=-2,
)
return ([node], [x, i, u], [y])
@onnx_test
def scatter_mul_test():
x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 4, 5, 6])
i = helper.make_tensor_value_info('indices', TensorProto.INT32,
[2, 3, 4, 5])
u = helper.make_tensor_value_info('update', TensorProto.FLOAT,
[2, 3, 4, 5])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 5, 6])
node = onnx.helper.make_node(
'ScatterElements',
reduction='mul',
inputs=['data', 'indices', 'update'],
outputs=['y'],
axis=-2,
)
return ([node], [x, i, u], [y])
@onnx_test
def scatter_none_test():
x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 4, 5, 6])
i = helper.make_tensor_value_info('indices', TensorProto.INT32,
[2, 3, 4, 5])
u = helper.make_tensor_value_info('update', TensorProto.FLOAT,
[2, 3, 4, 5])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 5, 6])
node = onnx.helper.make_node(
'ScatterElements',
reduction='none',
inputs=['data', 'indices', 'update'], inputs=['data', 'indices', 'update'],
outputs=['y'], outputs=['y'],
axis=-2, axis=-2,
......
...@@ -4233,7 +4233,8 @@ TEST_CASE(round_test) ...@@ -4233,7 +4233,8 @@ TEST_CASE(round_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(scatter_test) // the ScatterElements op has 3 reduction modes, which map to separate reference ops
migraphx::program create_scatter_program(const std::string& scatter_mode, int axis)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
...@@ -4242,10 +4243,30 @@ TEST_CASE(scatter_test) ...@@ -4242,10 +4243,30 @@ TEST_CASE(scatter_test)
mm->add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3, 4, 5}}); mm->add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3, 4, 5}});
auto l2 = auto l2 =
mm->add_parameter("update", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); mm->add_parameter("update", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
int axis = -2; auto r = mm->add_instruction(migraphx::make_op(scatter_mode, {{"axis", axis}}), l0, l1, l2);
auto r = mm->add_instruction(migraphx::make_op("scatter", {{"axis", axis}}), l0, l1, l2);
mm->add_return({r}); mm->add_return({r});
auto prog = migraphx::parse_onnx("scatter_test.onnx"); return p;
}
TEST_CASE(scatter_add_test)
{
migraphx::program p = create_scatter_program("scatter_add", -2);
auto prog = migraphx::parse_onnx("scatter_add_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(scatter_mul_test)
{
migraphx::program p = create_scatter_program("scatter_mul", -2);
auto prog = migraphx::parse_onnx("scatter_mul_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(scatter_none_test)
{
migraphx::program p = create_scatter_program("scatter_none", -2);
auto prog = migraphx::parse_onnx("scatter_none_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
......
 scatter_test: scatter_add_test:
9 V
data data
indices indices
updatey"Scatter* updatey"ScatterElements*
axis scatter_testZ axis*
reduction"addscatter_add_testZ
data data
 
 
......
scatter_mul_test:
V
data
indices
updatey"ScatterElements*
axis*
reduction"mulscatter_mul_testZ
data




Z!
indices




Z
update




b
y




B
\ No newline at end of file
scatter_none_test:
W
data
indices
updatey"ScatterElements*
axis*
reduction"nonescatter_none_testZ
data




Z!
indices




Z
update




b
y




B
\ No newline at end of file
...@@ -4179,9 +4179,9 @@ TEST_CASE(rsqrt_test) ...@@ -4179,9 +4179,9 @@ TEST_CASE(rsqrt_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(scatter_test) // reduction_mode: "scatter_none", "scatter_add", "scatter_mul"
migraphx::program create_scatter_program(const std::string& reduction_mode, int axis)
{ {
{
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {3, 3}}; migraphx::shape sd{migraphx::shape::float_type, {3, 3}};
...@@ -4196,8 +4196,32 @@ TEST_CASE(scatter_test) ...@@ -4196,8 +4196,32 @@ TEST_CASE(scatter_test)
auto ld = mm->add_literal(migraphx::literal{sd, vd}); auto ld = mm->add_literal(migraphx::literal{sd, vd});
auto li = mm->add_literal(migraphx::literal{si, vi}); auto li = mm->add_literal(migraphx::literal{si, vi});
auto lu = mm->add_literal(migraphx::literal{su, vu}); auto lu = mm->add_literal(migraphx::literal{su, vu});
auto r = mm->add_instruction(migraphx::make_op("scatter", {{"axis", 0}}), ld, li, lu); // scatter_none, formerly the scatter op
auto r = mm->add_instruction(migraphx::make_op(reduction_mode, {{"axis", axis}}), ld, li, lu);
mm->add_return({r}); mm->add_return({r});
return p;
}
TEST_CASE(scatter_ax0_test)
{
// this tests what used to be the only scatter op, now changed to 3 sub-ops
// which have their own test case
{
migraphx::program p = create_scatter_program("scatter_none", 0);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {2.0, 1.1, 0.0, 1.0, 0.0, 2.2, 0.0, 2.1, 1.2};
EXPECT(migraphx::verify_range(results_vector, gold));
}
}
TEST_CASE(scatter_ax_neg_test)
{
{
migraphx::program p = create_scatter_program("scatter_none", -2);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector; std::vector<float> results_vector;
...@@ -4205,55 +4229,177 @@ TEST_CASE(scatter_test) ...@@ -4205,55 +4229,177 @@ TEST_CASE(scatter_test)
std::vector<float> gold = {2.0, 1.1, 0.0, 1.0, 0.0, 2.2, 0.0, 2.1, 1.2}; std::vector<float> gold = {2.0, 1.1, 0.0, 1.0, 0.0, 2.2, 0.0, 2.1, 1.2};
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
}
TEST_CASE(scatter_ax1_test)
{
{ {
migraphx::program p = create_scatter_program("scatter_none", 1);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.1, 1.0, 1.2, 2.0, 2.2, 2.1, 0.0, 0.0, 0.0};
EXPECT(migraphx::verify_range(results_vector, gold));
}
}
// similar to create_scatter_program but with different tensor values
// reduction_mode: "scatter_none", "scatter_add", "scatter_mul"
migraphx::program create_scatter_program2(const std::string& reduction_mode, int axis)
{
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {3, 3}}; migraphx::shape sd{migraphx::shape::float_type, {1, 5}};
std::vector<float> vd(sd.elements(), 0.0f); std::vector<float> vd({1., 2., 3., 4., 5.});
migraphx::shape si{migraphx::shape::int32_type, {2, 3}}; migraphx::shape si{migraphx::shape::int32_type, {1, 2}};
std::vector<int> vi = {1, 0, -1, 0, 2, -2}; std::vector<int> vi = {1, 3};
migraphx::shape su{migraphx::shape::float_type, {2, 3}}; migraphx::shape su{migraphx::shape::float_type, {1, 2}};
std::vector<float> vu = {1.0, 1.1, 1.2, 2.0, 2.1, 2.2}; std::vector<float> vu = {1.1, 2.1};
auto ld = mm->add_literal(migraphx::literal{sd, vd}); auto ld = mm->add_literal(migraphx::literal{sd, vd});
auto li = mm->add_literal(migraphx::literal{si, vi}); auto li = mm->add_literal(migraphx::literal{si, vi});
auto lu = mm->add_literal(migraphx::literal{su, vu}); auto lu = mm->add_literal(migraphx::literal{su, vu});
auto r = mm->add_instruction(migraphx::make_op("scatter", {{"axis", -2}}), ld, li, lu); auto r = mm->add_instruction(migraphx::make_op(reduction_mode, {{"axis", axis}}), ld, li, lu);
mm->add_return({r}); mm->add_return({r});
return p;
}
TEST_CASE(scatter_reduction1_test)
{
{
// Test sub-ops for the three reduction values scatter_none, scatter_add, scatter_mul
migraphx::program p = create_scatter_program2("scatter_none", 1);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {2.0, 1.1, 0.0, 1.0, 0.0, 2.2, 0.0, 2.1, 1.2}; std::vector<float> gold_none = {1.0, 1.1, 3.0, 2.1, 5.0};
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold_none));
} }
}
TEST_CASE(scatter_reduction2_test)
{
{
migraphx::program p = create_scatter_program2("scatter_mul", 1);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold_mul = {1.0, 2.2, 3.0, 8.4, 5.0};
EXPECT(migraphx::verify_range(results_vector, gold_mul));
}
}
TEST_CASE(scatter_reduction3_test)
{
{
migraphx::program p = create_scatter_program2("scatter_add", 1);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold_add = {1.0, 3.1, 3.0, 6.1, 5.0};
EXPECT(migraphx::verify_range(results_vector, gold_add));
}
}
TEST_CASE(scatter_reduction_3x3_test)
{
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {3, 3}}; migraphx::shape sd{migraphx::shape::float_type, {3, 3}};
std::vector<float> vd(sd.elements(), 0.0f); std::vector<float> vd(sd.elements(), 3.0f);
migraphx::shape si{migraphx::shape::int32_type, {2, 3}}; migraphx::shape si{migraphx::shape::int32_type, {2, 3}};
std::vector<int> vi = {1, 0, 2, 0, 2, 1}; std::vector<int> vi = {1, 0, 2, 0, 2, 1};
migraphx::shape su{migraphx::shape::float_type, {2, 3}}; migraphx::shape su{migraphx::shape::float_type, {2, 3}};
std::vector<float> vu = {1.0, 1.1, 1.2, 2.0, 2.1, 2.2}; std::vector<float> vu = {1.0, 1.1, 1.2, 7.0, 7.1, 7.2};
auto ld = mm->add_literal(migraphx::literal{sd, vd}); auto ld = mm->add_literal(migraphx::literal{sd, vd});
auto li = mm->add_literal(migraphx::literal{si, vi}); auto li = mm->add_literal(migraphx::literal{si, vi});
auto lu = mm->add_literal(migraphx::literal{su, vu}); auto lu = mm->add_literal(migraphx::literal{su, vu});
auto r = mm->add_instruction(migraphx::make_op("scatter", {{"axis", 1}}), ld, li, lu); auto r = mm->add_instruction(migraphx::make_op("scatter_add", {{"axis", 1}}), ld, li, lu);
mm->add_return({r}); mm->add_return({r});
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.1, 1.0, 1.2, 2.0, 2.2, 2.1, 0.0, 0.0, 0.0}; std::vector<float> gold_a2 = {4.1, 4.0, 4.2, 10.0, 10.2, 10.1, 3.0, 3.0, 3.0};
EXPECT(migraphx::verify_range(results_vector, gold));
EXPECT(migraphx::verify_range(results_vector, gold_a2));
}
}
// create a test scatter program with a 3x3 tensor;
// su and si are transposed from previous case
migraphx::program create_scatter_program_3x3(const std::string& reduction_mode, int axis)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {3, 3}};
std::vector<float> vd(sd.elements(), 3.0f);
migraphx::shape si{migraphx::shape::int32_type, {3, 2}};
std::vector<int> vi = {1, 0, 0, 2, 2, 1};
migraphx::shape su{migraphx::shape::float_type, {3, 2}};
std::vector<float> vu = {1.0, 7.0, 1.1, 7.1, 1.2, 7.2};
auto ld = mm->add_literal(migraphx::literal{sd, vd});
auto li = mm->add_literal(migraphx::literal{si, vi});
auto lu = mm->add_literal(migraphx::literal{su, vu});
auto r = mm->add_instruction(migraphx::make_op(reduction_mode, {{"axis", axis}}), ld, li, lu);
mm->add_return({r});
return p;
}
TEST_CASE(scatter_reduction_3x3_xpose1_test)
{
// test on vertical (0) axis. su and si are transposed from previous case
{
migraphx::program p = create_scatter_program_3x3("scatter_none", 0);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold_none2 = {1.1, 7.0, 3.0, 1.0, 7.2, 3.0, 1.2, 7.1, 3.0};
EXPECT(migraphx::verify_range(results_vector, gold_none2));
}
}
TEST_CASE(scatter_reduction_3x3_xpose2_test)
{
// test on vertical (0) axis.
{
migraphx::program p = create_scatter_program_3x3("scatter_add", 0);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold_a3 = {4.1, 10.0, 3.0, 4.0, 10.2, 3.0, 4.2, 10.1, 3.0};
EXPECT(migraphx::verify_range(results_vector, gold_a3));
}
}
TEST_CASE(scatter_reduction_3x3_xpose3_test)
{
{
migraphx::program p = create_scatter_program_3x3("scatter_mul", 0);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold_mul2 = {3.3, 21.0, 3.0, 3.0, 21.6, 3.0, 3.6, 21.3, 3.0};
EXPECT(migraphx::verify_range(results_vector, gold_mul2));
} }
} }
......
...@@ -18,7 +18,7 @@ struct test_scatter0 : verify_program<test_scatter0> ...@@ -18,7 +18,7 @@ struct test_scatter0 : verify_program<test_scatter0>
auto pd = mm->add_parameter("data", sd); auto pd = mm->add_parameter("data", sd);
auto li = mm->add_literal(migraphx::literal{si, vi}); auto li = mm->add_literal(migraphx::literal{si, vi});
auto pu = mm->add_parameter("update", su); auto pu = mm->add_parameter("update", su);
auto r = mm->add_instruction(migraphx::make_op("scatter", {{"axis", -1}}), pd, li, pu); auto r = mm->add_instruction(migraphx::make_op("scatter_none", {{"axis", -1}}), pd, li, pu);
mm->add_return({r}); mm->add_return({r});
return p; return p;
......
...@@ -19,7 +19,7 @@ struct test_scatter1 : verify_program<test_scatter1> ...@@ -19,7 +19,7 @@ struct test_scatter1 : verify_program<test_scatter1>
auto pd = mm->add_parameter("data", sd); auto pd = mm->add_parameter("data", sd);
auto li = mm->add_literal(migraphx::literal{si, vi}); auto li = mm->add_literal(migraphx::literal{si, vi});
auto pu = mm->add_parameter("update", su); auto pu = mm->add_parameter("update", su);
auto r = mm->add_instruction(migraphx::make_op("scatter", {{"axis", -2}}), pd, li, pu); auto r = mm->add_instruction(migraphx::make_op("scatter_none", {{"axis", -2}}), pd, li, pu);
mm->add_return({r}); mm->add_return({r});
return p; return p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment