Commit d7a28300 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge develop branch to branch_for_ort2

parents bcb2c0a4 a930f1d5
...@@ -54,6 +54,10 @@ function(py_add_module NAME) ...@@ -54,6 +54,10 @@ function(py_add_module NAME)
endfunction() endfunction()
set(PYTHON_SEARCH_VERSIONS 2.7 3.5 3.6 3.7 3.8 3.9) set(PYTHON_SEARCH_VERSIONS 2.7 3.5 3.6 3.7 3.8 3.9)
set(PYTHON_DISABLE_VERSIONS "" CACHE STRING "")
foreach(PYTHON_DISABLE_VERSION ${PYTHON_DISABLE_VERSIONS})
list(REMOVE_ITEM PYTHON_SEARCH_VERSIONS ${PYTHON_DISABLE_VERSION})
endforeach()
set(_PYTHON_VERSIONS) set(_PYTHON_VERSIONS)
foreach(PYTHON_VERSION ${PYTHON_SEARCH_VERSIONS}) foreach(PYTHON_VERSION ${PYTHON_SEARCH_VERSIONS})
......
...@@ -106,8 +106,8 @@ ...@@ -106,8 +106,8 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"if not path.exists(\"./resnet50.onnx\"):\n", "if not path.exists(\"./resnet50.onnx\"):\n",
" !wget https://github.com/onnx/models/blob/master/vision/classification/resnet/model/resnet50-v2-7.onnx?raw=true\n", " !wget https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet50-v2-7.onnx",
" !mv 'resnet50-v2-7.onnx?raw=true' resnet50.onnx" " !mv resnet50-v2-7.onnx resnet50.onnx"
] ]
}, },
{ {
......
...@@ -162,10 +162,12 @@ register_migraphx_ops( ...@@ -162,10 +162,12 @@ register_migraphx_ops(
round round
rsqrt rsqrt
scalar scalar
scatter scatter_add
scatternd_none scatter_mul
scatter_none
scatternd_add scatternd_add
scatternd_mul scatternd_mul
scatternd_none
sigmoid sigmoid
sign sign
sinh sinh
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/par_for.hpp>
#include <migraphx/tensor_view.hpp> #include <migraphx/tensor_view.hpp>
namespace migraphx { namespace migraphx {
...@@ -20,8 +20,10 @@ void gemm(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha ...@@ -20,8 +20,10 @@ void gemm(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha
assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]); assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]); assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]); assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]);
auto cs = cmat.get_shape();
shape_for_each(cmat.get_shape(), [&](const auto& c_idx) { par_for(cs.elements(), [&](auto i) {
auto c_idx = cs.multi(i);
auto a_idx = c_idx; auto a_idx = c_idx;
auto b_idx = c_idx; auto b_idx = c_idx;
double s = 0.0; double s = 0.0;
......
...@@ -88,16 +88,16 @@ struct xorshift_generator ...@@ -88,16 +88,16 @@ struct xorshift_generator
template <class T> template <class T>
auto generate_tensor_data(const migraphx::shape& s, unsigned long seed = 0) auto generate_tensor_data(const migraphx::shape& s, unsigned long seed = 0)
{ {
auto result = make_shared_array<T>(s.elements()); auto result = make_shared_array<T>(s.element_space());
std::generate(result.get(), result.get() + s.elements(), xorshf96_generator<T>{seed}); std::generate(result.get(), result.get() + s.element_space(), xorshf96_generator<T>{seed});
return result; return result;
} }
template <class T> template <class T>
auto fill_tensor_data(const migraphx::shape& s, unsigned long value = 0) auto fill_tensor_data(const migraphx::shape& s, unsigned long value = 0)
{ {
auto result = make_shared_array<T>(s.elements()); auto result = make_shared_array<T>(s.element_space());
std::generate(result.get(), result.get() + s.elements(), [=] { return value; }); std::generate(result.get(), result.get() + s.element_space(), [=] { return value; });
return result; return result;
} }
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/name.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -16,7 +17,17 @@ namespace migraphx { ...@@ -16,7 +17,17 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct scatter // The scatter operator fetches a subset of data given by an index array and then performs a
// reduction operation (add, multiply, or just set the data) on each element returned. We implement
// it as a separate derived struct for each of the three reduction methods. The related operator
// scatterND is a generalization that works on a set of 3 tensors of different ranks. The
// complementary operations are gather/gatherND.
//
// This is a template for deriving child structs from. Each child needs to define
// only a reduction() method. Names are automatically handled by the op_name template.
template <class Derived>
struct scatter : op_name<Derived>
{ {
int64_t axis = 0; int64_t axis = 0;
...@@ -33,29 +44,44 @@ struct scatter ...@@ -33,29 +44,44 @@ struct scatter
return {{"normalize_axes", normalize}}; return {{"normalize_axes", normalize}};
} }
std::string name() const { return "scatter"; }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(3).standard(); check_shapes{inputs, *this}.has(3).standard();
return inputs.front(); // If non-packed, this converts to a packed output while preserving permutation of tensor
return inputs.front().with_lens(inputs.front().lens());
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
// max dimension in axis auto& self = static_cast<const Derived&>(*this);
// max dimension in each axis
auto axis_dim_size = output_shape.lens()[axis]; auto axis_dim_size = output_shape.lens()[axis];
// cast all arguments as correct type
visit_all(result, args[0], args[2])([&](auto output, auto data, auto update) { visit_all(result, args[0], args[2])([&](auto output, auto data, auto update) {
// copy all of data to output
std::copy(data.begin(), data.end(), output.begin()); std::copy(data.begin(), data.end(), output.begin());
args[1].visit([&](auto indices) { args[1].visit([&](auto indices) {
auto ind_s = indices.get_shape(); auto ind_s = indices.get_shape();
// iterate through items in shape
shape_for_each(ind_s, [&](const auto& idx) { shape_for_each(ind_s, [&](const auto& idx) {
auto out_idx = idx; auto out_idx = idx;
auto index = indices[ind_s.index(idx)];
// Overloaded tensor_view::() invokes indexing logic of
// std::size_t shape::index(std::size_t i) const
// which handles nonstandard shapes correctly
auto index = indices(idx.begin(), idx.end());
// normalize negative indexes (may be redundant after using
// normalize_compute_shape())
index = (index < 0) ? index + axis_dim_size : index; index = (index < 0) ? index + axis_dim_size : index;
out_idx[axis] = index; out_idx[axis] = index;
output[output_shape.index(out_idx)] = update[ind_s.index(idx)];
// look up the appropriate locations in output, using idx and out_idx.
// call reduction() method of derived struct to copy and reduce that element
self.reduction()(output(out_idx.begin(), out_idx.end()),
update(idx.begin(), idx.end()));
}); });
}); });
}); });
......
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTER_ADD_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTER_ADD_HPP

#include <migraphx/config.hpp>
#include <migraphx/op/scatter.hpp>

// Scatter op. with "add" function as reduction.
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

// CRTP child of op::scatter. The parent struct's compute() performs the
// element lookup/copy and invokes reduction() on each scattered element;
// the operator name "scatter_add" is assigned automatically by op_name<>.
struct scatter_add : scatter<scatter_add>
{
    // reduction (pointwise operation) is called by the parent struct's compute() method.
    // It works much like a virtual function overload.
    // For the scatter methods, there are three different reduction functions;
    // this one accumulates the update into the output by addition.
    auto reduction() const
    {
        return [](auto& x, const auto& y) { x += y; };
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTER_MUL_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTER_MUL_HPP

#include <migraphx/config.hpp>
#include <migraphx/op/scatter.hpp>

// Scatter op. with "multiply" as the reduction function.
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

// CRTP child of op::scatter. The parent struct's compute() performs the
// element lookup/copy and invokes reduction() on each scattered element;
// the operator name "scatter_mul" is assigned automatically by op_name<>.
struct scatter_mul : scatter<scatter_mul>
{
    // reduction (pointwise operation) is called by the parent struct's compute() method.
    // It works much like a virtual function overload.
    // For the scatter operators, there are three different reduction functions;
    // this one multiplies the update into the output.
    auto reduction() const
    {
        return [](auto& x, const auto& y) { x *= y; };
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTER_NONE_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTER_NONE_HPP

#include <migraphx/config.hpp>
#include <migraphx/op/scatter.hpp>

// Scatter op. with "none" as the reduction function (just copies the value). This is identical to
// the previously existing Scatter op.
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

// CRTP child of op::scatter. The parent struct's compute() performs the
// element lookup/copy and invokes reduction() on each scattered element;
// the operator name "scatter_none" is assigned automatically by op_name<>.
struct scatter_none : scatter<scatter_none>
{
    // reduction (pointwise operation) is called by the parent struct's compute() method.
    // It works much like a virtual function overload.
    // For the scatter operators, there are three different reduction functions;
    // this one simply assigns the update value (no accumulation).
    auto reduction() const
    {
        return [](auto& x, const auto& y) { x = y; };
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -86,7 +86,9 @@ ...@@ -86,7 +86,9 @@
#include <migraphx/op/round.hpp> #include <migraphx/op/round.hpp>
#include <migraphx/op/rsqrt.hpp> #include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp> #include <migraphx/op/scalar.hpp>
#include <migraphx/op/scatter.hpp> #include <migraphx/op/scatter_add.hpp>
#include <migraphx/op/scatter_mul.hpp>
#include <migraphx/op/scatter_none.hpp>
#include <migraphx/op/scatternd_add.hpp> #include <migraphx/op/scatternd_add.hpp>
#include <migraphx/op/scatternd_none.hpp> #include <migraphx/op/scatternd_none.hpp>
#include <migraphx/op/scatternd_mul.hpp> #include <migraphx/op/scatternd_mul.hpp>
......
...@@ -10,6 +10,7 @@ struct parse_generic_op : op_parser<parse_generic_op> ...@@ -10,6 +10,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{ {
std::vector<op_desc> operators() const std::vector<op_desc> operators() const
{ {
// clang-format off
return {{"Abs", "abs"}, return {{"Abs", "abs"},
{"Acos", "acos"}, {"Acos", "acos"},
{"Acosh", "acosh"}, {"Acosh", "acosh"},
...@@ -36,8 +37,6 @@ struct parse_generic_op : op_parser<parse_generic_op> ...@@ -36,8 +37,6 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Reciprocal", "recip"}, {"Reciprocal", "recip"},
{"Relu", "relu"}, {"Relu", "relu"},
{"Round", "round"}, {"Round", "round"},
{"Scatter", "scatter"},
{"ScatterElements", "scatter"},
{"Sigmoid", "sigmoid"}, {"Sigmoid", "sigmoid"},
{"Sign", "sign"}, {"Sign", "sign"},
{"Sin", "sin"}, {"Sin", "sin"},
...@@ -46,6 +45,7 @@ struct parse_generic_op : op_parser<parse_generic_op> ...@@ -46,6 +45,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Tan", "tan"}, {"Tan", "tan"},
{"Tanh", "tanh"}, {"Tanh", "tanh"},
{"Not", "not"}}; {"Not", "not"}};
// clang-format on
} }
bool needs_contiguous(const std::string& op_name) const bool needs_contiguous(const std::string& op_name) const
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Parses the Onnx Scatter/ScatterElements nodes into one of the three
// reference scatter operators, selected by the node's "reduction" attribute.
struct parse_scatter : op_parser<parse_scatter>
{
    std::vector<op_desc> operators() const { return {{"ScatterElements"}, {"Scatter"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& /*parser*/,
                          const onnx_parser::node_info& info,
                          const std::vector<instruction_ref>& args) const
    {
        int axis = 0;
        if(contains(info.attributes, "axis"))
            axis = info.attributes.at("axis").i();

        // When the reduction attribute is absent, default to plain assignment ("none").
        std::string reduction = "none";
        if(contains(info.attributes, "reduction"))
        {
            reduction = info.attributes.at("reduction").s();
            // check for a valid reduction attribute. We have an operator for each one.
            // Future reduction op names should follow the "scatter_<mode>" pattern and
            // should also be added to this check.
            if(not contains({"none", "add", "mul"}, reduction))
                MIGRAPHX_THROW("PARSE_SCATTER: unsupported reduction mode " + reduction);
        }

        // The reduction mode selects the operator: scatter_none / scatter_add / scatter_mul.
        return info.add_instruction(
            migraphx::make_op(std::string("scatter_") + reduction, {{"axis", axis}}), args);
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <migraphx/operation.hpp> #include <migraphx/operation.hpp>
#include <migraphx/quantization.hpp> #include <migraphx/quantization.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/ref/target.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/tf.hpp> #include <migraphx/tf.hpp>
......
...@@ -178,10 +178,15 @@ auto hip_vec_visit_all(T&& x, Ts&&... xs) ...@@ -178,10 +178,15 @@ auto hip_vec_visit_all(T&& x, Ts&&... xs)
return [&](auto f) { return [&](auto f) {
auto sx = get_shape(x); auto sx = get_shape(x);
auto lens = sx.lens(); auto lens = sx.lens();
assert(lens.back() % N == 0);
assert(sx.strides().back() == 1);
lens.back() /= N; lens.back() /= N;
shape ssx{sx.type(), lens}; shape vec_sx{sx.type(), lens};
hip_visit_all_impl( hip_visit_all_impl(vec_sx,
ssx, make_hip_convert([](auto* p) { return as_vec<N>(device_cast(p)); }), f, x, xs...); make_hip_convert([](auto* p) { return as_vec<N>(device_cast(p)); }),
f,
x,
xs...);
}; };
} }
......
...@@ -27,6 +27,15 @@ using hip_host_ptr = MIGRAPHX_MANAGE_PTR(void, hipHostUnregister); ...@@ -27,6 +27,15 @@ using hip_host_ptr = MIGRAPHX_MANAGE_PTR(void, hipHostUnregister);
std::string hip_error(int error) { return hipGetErrorString(static_cast<hipError_t>(error)); } std::string hip_error(int error) { return hipGetErrorString(static_cast<hipError_t>(error)); }
bool is_device_ptr(const void* ptr)
{
hipPointerAttribute_t attr;
auto status = hipPointerGetAttributes(&attr, ptr);
if(status != hipSuccess)
return false;
return attr.memoryType == hipMemoryTypeDevice;
}
std::size_t get_available_gpu_memory() std::size_t get_available_gpu_memory()
{ {
size_t free; size_t free;
...@@ -50,8 +59,8 @@ hip_ptr allocate_gpu(std::size_t sz, bool host = false) ...@@ -50,8 +59,8 @@ hip_ptr allocate_gpu(std::size_t sz, bool host = false)
{ {
if(sz > get_available_gpu_memory()) if(sz > get_available_gpu_memory())
MIGRAPHX_THROW("Memory not available to allocate buffer: " + std::to_string(sz)); MIGRAPHX_THROW("Memory not available to allocate buffer: " + std::to_string(sz));
void* result; void* result = nullptr;
auto status = host ? hipHostMalloc(&result, sz) : hipMalloc(&result, sz); auto status = host ? hipHostMalloc(&result, sz) : hipMalloc(&result, sz);
if(status != hipSuccess) if(status != hipSuccess)
{ {
if(host) if(host)
...@@ -59,6 +68,7 @@ hip_ptr allocate_gpu(std::size_t sz, bool host = false) ...@@ -59,6 +68,7 @@ hip_ptr allocate_gpu(std::size_t sz, bool host = false)
else else
return allocate_gpu(sz, true); return allocate_gpu(sz, true);
} }
assert(result != nullptr);
return hip_ptr{result}; return hip_ptr{result};
} }
...@@ -75,6 +85,8 @@ std::vector<T> read_from_gpu(const void* x, std::size_t sz) ...@@ -75,6 +85,8 @@ std::vector<T> read_from_gpu(const void* x, std::size_t sz)
{ {
gpu_sync(); gpu_sync();
std::vector<T> result(sz); std::vector<T> result(sz);
assert(not is_device_ptr(result.data()));
assert(is_device_ptr(x));
auto status = hipMemcpy(result.data(), x, sz * sizeof(T), hipMemcpyDeviceToHost); auto status = hipMemcpy(result.data(), x, sz * sizeof(T), hipMemcpyDeviceToHost);
if(status != hipSuccess) if(status != hipSuccess)
MIGRAPHX_THROW("Copy from gpu failed: " + hip_error(status)); // NOLINT MIGRAPHX_THROW("Copy from gpu failed: " + hip_error(status)); // NOLINT
...@@ -85,6 +97,8 @@ hip_ptr write_to_gpu(const void* x, std::size_t sz, bool host = false) ...@@ -85,6 +97,8 @@ hip_ptr write_to_gpu(const void* x, std::size_t sz, bool host = false)
{ {
gpu_sync(); gpu_sync();
auto result = allocate_gpu(sz, host); auto result = allocate_gpu(sz, host);
assert(is_device_ptr(result.get()));
assert(not is_device_ptr(x));
auto status = hipMemcpy(result.get(), x, sz, hipMemcpyHostToDevice); auto status = hipMemcpy(result.get(), x, sz, hipMemcpyHostToDevice);
if(status != hipSuccess) if(status != hipSuccess)
MIGRAPHX_THROW("Copy to gpu failed: " + hip_error(status)); MIGRAPHX_THROW("Copy to gpu failed: " + hip_error(status));
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp> #include <migraphx/reflect.hpp>
#include <migraphx/op/scatter.hpp> #include <migraphx/op/scatter_none.hpp>
#include <migraphx/gpu/miopen.hpp> #include <migraphx/gpu/miopen.hpp>
namespace migraphx { namespace migraphx {
...@@ -14,7 +14,9 @@ struct context; ...@@ -14,7 +14,9 @@ struct context;
struct hip_scatter struct hip_scatter
{ {
op::scatter op; // scatter_none is an exact replacement for previous op::scatter,
// renamed to match an Onnx option. Don't use base class op::scatter
op::scatter_none op;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -22,7 +24,7 @@ struct hip_scatter ...@@ -22,7 +24,7 @@ struct hip_scatter
return migraphx::reflect(self.op, f); return migraphx::reflect(self.op, f);
} }
std::string name() const { return "gpu::scatter"; } std::string name() const { return "gpu::scatter_none"; }
shape compute_shape(std::vector<shape> inputs) const; shape compute_shape(std::vector<shape> inputs) const;
argument argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const; compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
......
...@@ -190,7 +190,7 @@ struct miopen_apply ...@@ -190,7 +190,7 @@ struct miopen_apply
add_extend_op("rnn_var_sl_last_output"); add_extend_op("rnn_var_sl_last_output");
add_extend_op("rnn_var_sl_shift_output"); add_extend_op("rnn_var_sl_shift_output");
add_extend_op("rnn_var_sl_shift_sequence"); add_extend_op("rnn_var_sl_shift_sequence");
add_extend_op("scatter"); add_extend_op("scatter_none");
add_extend_op("softmax"); add_extend_op("softmax");
add_extend_op("topk"); add_extend_op("topk");
...@@ -381,6 +381,9 @@ struct miopen_apply ...@@ -381,6 +381,9 @@ struct miopen_apply
}); });
} }
// add_generic_op just constructs the operator with no fields, whereas add_extend_op copies over
// the fields. Since it doesn't have fields, it is default-constructed.
void add_generic_op(const std::string& name) { add_generic_op(name, "gpu::" + name); } void add_generic_op(const std::string& name) { add_generic_op(name, "gpu::" + name); }
void add_generic_op(const std::string& op_name, const std::string& gpu_name) void add_generic_op(const std::string& op_name, const std::string& gpu_name)
......
...@@ -4381,7 +4381,7 @@ def roialign_test(): ...@@ -4381,7 +4381,7 @@ def roialign_test():
@onnx_test @onnx_test
def scatter_test(): def scatter_add_test():
x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 4, 5, 6]) x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 4, 5, 6])
i = helper.make_tensor_value_info('indices', TensorProto.INT32, i = helper.make_tensor_value_info('indices', TensorProto.INT32,
[2, 3, 4, 5]) [2, 3, 4, 5])
...@@ -4390,7 +4390,48 @@ def scatter_test(): ...@@ -4390,7 +4390,48 @@ def scatter_test():
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 5, 6]) y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 5, 6])
node = onnx.helper.make_node( node = onnx.helper.make_node(
'Scatter', 'ScatterElements',
reduction='add',
inputs=['data', 'indices', 'update'],
outputs=['y'],
axis=-2,
)
return ([node], [x, i, u], [y])
@onnx_test
def scatter_mul_test():
    # ScatterElements node with reduction='mul' along axis -2:
    # exercises parsing into the scatter_mul reference operator.
    data = helper.make_tensor_value_info('data', TensorProto.FLOAT,
                                         [3, 4, 5, 6])
    indices = helper.make_tensor_value_info('indices', TensorProto.INT32,
                                            [2, 3, 4, 5])
    update = helper.make_tensor_value_info('update', TensorProto.FLOAT,
                                           [2, 3, 4, 5])
    out = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 5, 6])

    scatter_node = onnx.helper.make_node('ScatterElements',
                                         reduction='mul',
                                         inputs=['data', 'indices', 'update'],
                                         outputs=['y'],
                                         axis=-2)

    return ([scatter_node], [data, indices, update], [out])
@onnx_test
def scatter_none_test():
x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 4, 5, 6])
i = helper.make_tensor_value_info('indices', TensorProto.INT32,
[2, 3, 4, 5])
u = helper.make_tensor_value_info('update', TensorProto.FLOAT,
[2, 3, 4, 5])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 5, 6])
node = onnx.helper.make_node(
'ScatterElements',
reduction='none',
inputs=['data', 'indices', 'update'], inputs=['data', 'indices', 'update'],
outputs=['y'], outputs=['y'],
axis=-2, axis=-2,
......
...@@ -4233,7 +4233,8 @@ TEST_CASE(round_test) ...@@ -4233,7 +4233,8 @@ TEST_CASE(round_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(scatter_test) // the ScatterElements op has 3 reduction modes, which map to separate reference ops
migraphx::program create_scatter_program(const std::string& scatter_mode, int axis)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
...@@ -4242,10 +4243,30 @@ TEST_CASE(scatter_test) ...@@ -4242,10 +4243,30 @@ TEST_CASE(scatter_test)
mm->add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3, 4, 5}}); mm->add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3, 4, 5}});
auto l2 = auto l2 =
mm->add_parameter("update", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); mm->add_parameter("update", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
int axis = -2; auto r = mm->add_instruction(migraphx::make_op(scatter_mode, {{"axis", axis}}), l0, l1, l2);
auto r = mm->add_instruction(migraphx::make_op("scatter", {{"axis", axis}}), l0, l1, l2);
mm->add_return({r}); mm->add_return({r});
auto prog = migraphx::parse_onnx("scatter_test.onnx"); return p;
}
// Parsing an ONNX ScatterElements node with reduction="add" should produce a
// program matching the scatter_add reference operator (axis normalized to -2).
TEST_CASE(scatter_add_test)
{
migraphx::program p = create_scatter_program("scatter_add", -2);
auto prog = migraphx::parse_onnx("scatter_add_test.onnx");
EXPECT(p == prog);
}
// Parsing an ONNX ScatterElements node with reduction="mul" should produce a
// program matching the scatter_mul reference operator (axis normalized to -2).
TEST_CASE(scatter_mul_test)
{
migraphx::program p = create_scatter_program("scatter_mul", -2);
auto prog = migraphx::parse_onnx("scatter_mul_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(scatter_none_test)
{
migraphx::program p = create_scatter_program("scatter_none", -2);
auto prog = migraphx::parse_onnx("scatter_none_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
......
 scatter_test: scatter_add_test:
9 V
data data
indices indices
updatey"Scatter* updatey"ScatterElements*
axis scatter_testZ axis*
reduction"addscatter_add_testZ
data data
 
 
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment