"template/git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "93a8daf285af45ed71544e79aae0cb15245e75f4"
Unverified Commit 68032c62 authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Merge pull request #869 from ROCmSoftwarePlatform/scatter-op

Scatter op
parents ed091b14 a197c02e
...@@ -151,6 +151,7 @@ register_migraphx_ops( ...@@ -151,6 +151,7 @@ register_migraphx_ops(
round round
rsqrt rsqrt
scalar scalar
scatter
sigmoid sigmoid
sign sign
sinh sinh
......
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTER_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTER_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
// Reference implementation of the ONNX Scatter / ScatterElements operator:
// the output starts as a copy of the data tensor, then the elements
// addressed by `indices` (along `axis`) are overwritten with the
// corresponding values from the updates tensor.
struct scatter
{
// Axis along which an index value selects the output position. Negative
// values are permitted and are normalized before compute runs (see
// attributes() below).
int64_t axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
// Requests axis normalization from the normalize_attributes machinery so
// that compute() only ever sees a non-negative, in-range axis.
value attributes() const
{
value normalize;
normalize["axis"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
std::string name() const { return "scatter"; }
// Expects exactly 3 inputs — data, indices, updates — all with standard
// (packed, non-permuted) layouts; the output shape equals the data shape.
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(3).standard();
return inputs.front();
}
// args[0] = data, args[1] = indices, args[2] = updates.
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
// max dimension in axis
auto axis_dim_size = output_shape.lens()[axis];
// visit_all dispatches output/data/updates together (they share an
// element type); indices are visited separately since they carry an
// integer type of their own.
visit_all(result, args[0], args[2])([&](auto output, auto data, auto update) {
// seed the output with a full copy of the data tensor
std::copy(data.begin(), data.end(), output.begin());
args[1].visit([&](auto indices) {
auto ind_s = indices.get_shape();
// For every multi-index of the indices tensor, replace the
// `axis` coordinate with the stored index value and write the
// matching update element to that output location.
shape_for_each(ind_s, [&](const auto& idx) {
auto out_idx = idx;
auto index = indices[ind_s.index(idx)];
// wrap negative indices into [0, axis_dim_size)
index = (index < 0) ? index + axis_dim_size : index;
out_idx[axis] = index;
output[output_shape.index(out_idx)] = update[ind_s.index(idx)];
});
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -81,6 +81,7 @@ ...@@ -81,6 +81,7 @@
#include <migraphx/op/round.hpp> #include <migraphx/op/round.hpp>
#include <migraphx/op/rsqrt.hpp> #include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp> #include <migraphx/op/scalar.hpp>
#include <migraphx/op/scatter.hpp>
#include <migraphx/op/sigmoid.hpp> #include <migraphx/op/sigmoid.hpp>
#include <migraphx/op/sign.hpp> #include <migraphx/op/sign.hpp>
#include <migraphx/op/sinh.hpp> #include <migraphx/op/sinh.hpp>
......
...@@ -35,6 +35,8 @@ struct parse_generic_op : op_parser<parse_generic_op> ...@@ -35,6 +35,8 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Reciprocal", "recip"}, {"Reciprocal", "recip"},
{"Relu", "relu"}, {"Relu", "relu"},
{"Round", "round"}, {"Round", "round"},
{"Scatter", "scatter"},
{"ScatterElements", "scatter"},
{"Sigmoid", "sigmoid"}, {"Sigmoid", "sigmoid"},
{"Sign", "sign"}, {"Sign", "sign"},
{"Sin", "sin"}, {"Sin", "sin"},
...@@ -47,7 +49,7 @@ struct parse_generic_op : op_parser<parse_generic_op> ...@@ -47,7 +49,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
bool needs_contiguous(const std::string& op_name) const bool needs_contiguous(const std::string& op_name) const
{ {
return contains({"flatten", "gather"}, op_name); return contains({"flatten", "gather", "scatter"}, op_name);
} }
instruction_ref parse(const op_desc& opd, instruction_ref parse(const op_desc& opd,
......
...@@ -73,6 +73,7 @@ add_library(migraphx_device ...@@ -73,6 +73,7 @@ add_library(migraphx_device
device/rnn_variable_seq_lens.cpp device/rnn_variable_seq_lens.cpp
device/round.cpp device/round.cpp
device/rsqrt.cpp device/rsqrt.cpp
device/scatter.cpp
device/sigmoid.cpp device/sigmoid.cpp
device/sign.cpp device/sign.cpp
device/sin.cpp device/sin.cpp
...@@ -145,8 +146,9 @@ add_library(migraphx_gpu ...@@ -145,8 +146,9 @@ add_library(migraphx_gpu
reverse.cpp reverse.cpp
rnn_variable_seq_lens.cpp rnn_variable_seq_lens.cpp
rocblas.cpp rocblas.cpp
softmax.cpp scatter.cpp
schedule_model.cpp schedule_model.cpp
softmax.cpp
sync_device.cpp sync_device.cpp
target.cpp target.cpp
write_literals.cpp write_literals.cpp
...@@ -204,6 +206,7 @@ register_migraphx_gpu_ops(hip_ ...@@ -204,6 +206,7 @@ register_migraphx_gpu_ops(hip_
reverse reverse
round round
rsqrt rsqrt
scatter
sigmoid sigmoid
sign sign
sinh sinh
......
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/scatter.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// GPU implementation of scatter: a first kernel copies arg0 (data) into
// `result`, then a second kernel writes arg2 (updates) at the positions
// selected by arg1 (indices) along `axis`. Both kernels are launched on the
// same stream, so the copy is ordered before the scatter writes.
// NOTE(review): `axis` is assumed to be already normalized (non-negative) by
// the op's normalize_axes pass — confirm against the caller.
argument scatter(
hipStream_t stream, argument result, argument arg0, argument arg1, argument arg2, int64_t axis)
{
auto ds = arg0.get_shape();
auto inds = arg1.get_shape();
// dimension size along the scatter axis, used to wrap negative indices
auto axis_dim_size = ds.lens()[axis];
hip_visit_all(result, arg0, inds)([&](auto output, auto data, auto s1) {
auto* output_ptr = device_cast(output.data());
const auto* data_ptr = device_cast(data.data());
// kernel 1: element-wise copy of data into the output buffer
gs_launch(stream, ds.elements())([=](auto i) __device__ { output_ptr[i] = data_ptr[i]; });
hip_visit_all(arg1, arg2)([&](auto indices, auto update) {
const auto* upd_ptr = device_cast(update.data());
const auto* indices_ptr = device_cast(indices.data());
// kernel 2: one work item per element of the indices tensor
gs_launch(stream, inds.elements())([=](auto i) __device__ {
// multi-index into the indices/updates shape
auto out_idx = s1.multi(i);
auto index = indices_ptr[i];
// wrap negative index values into [0, axis_dim_size)
index = index < 0 ? index + axis_dim_size : index;
out_idx[axis] = index;
output[out_idx] = upd_ptr[i];
});
});
});
return result;
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_SCATTER_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_SCATTER_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument scatter(
hipStream_t stream, argument result, argument arg0, argument arg1, argument arg2, int64_t axis);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_SCATTER_HPP
#define MIGRAPHX_GUARD_RTGLIB_SCATTER_HPP
#include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/scatter.hpp>
#include <migraphx/gpu/miopen.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
// GPU wrapper around op::scatter: attribute reflection and shape
// computation delegate to the reference op, while execution is forwarded to
// device::scatter. By GPU-lowering convention the final input argument is
// the pre-allocated output buffer.
struct hip_scatter
{
op::scatter op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::scatter"; }
shape compute_shape(std::vector<shape> inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
// The result aliases the trailing allocation argument, so the runtime
// does not create a separate buffer for the output.
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -173,6 +173,7 @@ struct miopen_apply ...@@ -173,6 +173,7 @@ struct miopen_apply
add_extend_op("rnn_var_sl_last_output"); add_extend_op("rnn_var_sl_last_output");
add_extend_op("rnn_var_sl_shift_output"); add_extend_op("rnn_var_sl_shift_output");
add_extend_op("rnn_var_sl_shift_sequence"); add_extend_op("rnn_var_sl_shift_sequence");
add_extend_op("scatter");
add_extend_op("softmax"); add_extend_op("softmax");
add_gemm_op<op::dot>("dot"); add_gemm_op<op::dot>("dot");
......
#include <migraphx/gpu/scatter.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/scatter.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// Computes the output shape of gpu::scatter by reusing the reference op's
// shape logic on the real inputs (data, indices, updates).
shape hip_scatter::compute_shape(std::vector<shape> inputs) const
{
    // The trailing input is the output allocation added by GPU lowering;
    // drop it before delegating so the reference op sees exactly 3 shapes.
    inputs.resize(inputs.size() - 1);
    return op.normalize_compute_shape(inputs);
}
// Runs the scatter kernel on this context's stream.
// args layout: {data, indices, updates, output_allocation}.
argument hip_scatter::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
    auto stream         = ctx.get_stream().get();
    const auto& out_buf = args.back();
    return device::scatter(stream, out_buf, args[0], args[1], args[2], op.axis);
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -3424,6 +3424,25 @@ def resize_upsample_pc_test(): ...@@ -3424,6 +3424,25 @@ def resize_upsample_pc_test():
return ([node], [X], [Y], [scale_tensor]) return ([node], [X], [Y], [scale_tensor])
@onnx_test
def scatter_test():
    # Builds a Scatter node over a 4-d float tensor with a negative axis
    # (-2) to exercise axis normalization in the parser.
    data = helper.make_tensor_value_info('data', TensorProto.FLOAT,
                                         [3, 4, 5, 6])
    indices = helper.make_tensor_value_info('indices', TensorProto.INT32,
                                            [2, 3, 4, 5])
    updates = helper.make_tensor_value_info('update', TensorProto.FLOAT,
                                            [2, 3, 4, 5])
    out = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 5, 6])

    node = onnx.helper.make_node('Scatter',
                                 inputs=['data', 'indices', 'update'],
                                 outputs=['y'],
                                 axis=-2)

    return ([node], [data, indices, updates], [out])
@onnx_test @onnx_test
def selu_test(): def selu_test():
x = helper.make_tensor_value_info('x', TensorProto.DOUBLE, [2, 3]) x = helper.make_tensor_value_info('x', TensorProto.DOUBLE, [2, 3])
......
...@@ -3266,6 +3266,23 @@ TEST_CASE(round_test) ...@@ -3266,6 +3266,23 @@ TEST_CASE(round_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(scatter_test)
{
    // Parse scatter_test.onnx and verify it yields a scatter instruction
    // with the negative axis attribute carried through from the ONNX node.
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto data =
        mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
    auto indices =
        mm->add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3, 4, 5}});
    auto updates =
        mm->add_parameter("update", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
    auto ret =
        mm->add_instruction(migraphx::make_op("scatter", {{"axis", -2}}), data, indices, updates);
    mm->add_return({ret});

    auto prog = migraphx::parse_onnx("scatter_test.onnx");
    EXPECT(p == prog);
}
TEST_CASE(selu_test) TEST_CASE(selu_test)
{ {
migraphx::program p; migraphx::program p;
......
 scatter_test:
9
data
indices
updatey"Scatter*
axis scatter_testZ
data




Z!
indices




Z
update




b
y




B
\ No newline at end of file
...@@ -173,6 +173,8 @@ def create_backend_test(testname=None, target_device=None): ...@@ -173,6 +173,8 @@ def create_backend_test(testname=None, target_device=None):
backend_test.include(r'.*test_reduce.*') backend_test.include(r'.*test_reduce.*')
backend_test.include(r'.*test_ReLU*') backend_test.include(r'.*test_ReLU*')
backend_test.include(r'.*test_relu.*') backend_test.include(r'.*test_relu.*')
backend_test.include(r'.*test_scatter.*')
backend_test.include(r'.*test_Scatter.*')
backend_test.include(r'.*test_selu.*') backend_test.include(r'.*test_selu.*')
backend_test.include(r'.*test_shape.*') backend_test.include(r'.*test_shape.*')
backend_test.include(r'.*test_Sigmoid*') backend_test.include(r'.*test_Sigmoid*')
...@@ -276,6 +278,7 @@ def create_backend_test(testname=None, target_device=None): ...@@ -276,6 +278,7 @@ def create_backend_test(testname=None, target_device=None):
backend_test.exclude(r'test_mean_one_input_cpu') backend_test.exclude(r'test_mean_one_input_cpu')
backend_test.exclude(r'test_mean_two_inputs_cpu') backend_test.exclude(r'test_mean_two_inputs_cpu')
backend_test.exclude(r'test_negative_log_likelihood_loss_*') backend_test.exclude(r'test_negative_log_likelihood_loss_*')
backend_test.exclude(r'test_scatternd_*')
# all reduce ops have dynamic axes inputs # all reduce ops have dynamic axes inputs
backend_test.exclude(r'test_size_cpu') backend_test.exclude(r'test_size_cpu')
......
...@@ -3831,6 +3831,84 @@ TEST_CASE(rsqrt_test) ...@@ -3831,6 +3831,84 @@ TEST_CASE(rsqrt_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(scatter_test)
{
    // Evaluates the ref scatter op on a 3x3 zero data tensor with 2x3 int32
    // indices and 2x3 float updates, and checks the result against gold.
    // The three cases below were previously three near-identical copies of
    // the same 20-line body differing only in axis, indices, and gold; the
    // shared driver is factored into a local helper.
    auto run_case =
        [](int axis, const std::vector<int>& vi, const std::vector<float>& gold) {
            migraphx::program p;
            auto* mm = p.get_main_module();
            migraphx::shape sd{migraphx::shape::float_type, {3, 3}};
            std::vector<float> vd(sd.elements(), 0.0f);
            migraphx::shape si{migraphx::shape::int32_type, {2, 3}};
            migraphx::shape su{migraphx::shape::float_type, {2, 3}};
            std::vector<float> vu = {1.0, 1.1, 1.2, 2.0, 2.1, 2.2};
            auto ld = mm->add_literal(migraphx::literal{sd, vd});
            auto li = mm->add_literal(migraphx::literal{si, vi});
            auto lu = mm->add_literal(migraphx::literal{su, vu});
            auto r = mm->add_instruction(migraphx::make_op("scatter", {{"axis", axis}}), ld, li, lu);
            mm->add_return({r});
            p.compile(migraphx::ref::target{});
            auto result = p.eval({}).back();
            std::vector<float> results_vector;
            result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
            EXPECT(migraphx::verify_range(results_vector, gold));
        };

    // axis 0, all indices non-negative
    run_case(0, {1, 0, 2, 0, 2, 1}, {2.0, 1.1, 0.0, 1.0, 0.0, 2.2, 0.0, 2.1, 1.2});
    // axis -2 (normalizes to 0), with negative index values
    run_case(-2, {1, 0, -1, 0, 2, -2}, {2.0, 1.1, 0.0, 1.0, 0.0, 2.2, 0.0, 2.1, 1.2});
    // axis 1
    run_case(1, {1, 0, 2, 0, 2, 1}, {1.1, 1.0, 1.2, 2.0, 2.2, 2.1, 0.0, 0.0, 0.0});
}
TEST_CASE(sigmoid_test) TEST_CASE(sigmoid_test)
{ {
migraphx::program p; migraphx::program p;
......
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_scatter0 : verify_program<test_scatter0>
{
    migraphx::program create_program() const
    {
        // 3x3 float data and 2x3 float updates as runtime parameters; 2x3
        // int32 indices (all non-negative) as a literal; scatter along the
        // last axis (-1).
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape data_s{migraphx::shape::float_type, {3, 3}};
        migraphx::shape ind_s{migraphx::shape::int32_type, {2, 3}};
        migraphx::shape upd_s{migraphx::shape::float_type, {2, 3}};
        std::vector<int> ind_vals = {1, 0, 2, 0, 2, 1};
        auto data    = mm->add_parameter("data", data_s);
        auto indices = mm->add_literal(migraphx::literal{ind_s, ind_vals});
        auto updates = mm->add_parameter("update", upd_s);
        auto scattered =
            mm->add_instruction(migraphx::make_op("scatter", {{"axis", -1}}), data, indices, updates);
        mm->add_return({scattered});
        return p;
    }
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_scatter1 : verify_program<test_scatter1>
{
    migraphx::program create_program() const
    {
        // Like test_scatter0 but with negative index values and scatter
        // axis -2 (the first axis of a rank-2 tensor).
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape data_s{migraphx::shape::float_type, {3, 3}};
        migraphx::shape ind_s{migraphx::shape::int32_type, {2, 3}};
        migraphx::shape upd_s{migraphx::shape::float_type, {2, 3}};
        std::vector<int> ind_vals = {-2, 0, 2, 0, -1, 1};
        auto data    = mm->add_parameter("data", data_s);
        auto indices = mm->add_literal(migraphx::literal{ind_s, ind_vals});
        auto updates = mm->add_parameter("update", upd_s);
        auto scattered =
            mm->add_instruction(migraphx::make_op("scatter", {{"axis", -2}}), data, indices, updates);
        mm->add_return({scattered});
        return p;
    }
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment