Commit 973cafd4 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

gpu implementation of the scatter operator

parent a884f396
......@@ -72,6 +72,7 @@ add_library(migraphx_device
device/rnn_variable_seq_lens.cpp
device/round.cpp
device/rsqrt.cpp
device/scatter.cpp
device/sigmoid.cpp
device/sign.cpp
device/sin.cpp
......@@ -144,8 +145,9 @@ add_library(migraphx_gpu
reverse.cpp
rnn_variable_seq_lens.cpp
rocblas.cpp
softmax.cpp
scatter.cpp
schedule_model.cpp
softmax.cpp
sync_device.cpp
target.cpp
write_literals.cpp
......@@ -202,6 +204,7 @@ register_migraphx_gpu_ops(hip_
reverse
round
rsqrt
scatter
sigmoid
sign
sinh
......
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/scatter.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument scatter(hipStream_t stream, argument result, argument arg0, argument arg1, argument arg2, int64_t axis)
{
auto ds = arg0.get_shape();
auto inds = arg1.get_shape();
hip_visit_all(result, arg0, inds)([&](auto output, auto data, auto s1) {
// hip_visit_all(result, arg0, arg2, ds)([&](auto output, auto data, auto update, auto s) {
auto* output_ptr = device_cast(output.data());
const auto* data_ptr = device_cast(data.data());
gs_launch(stream, ds.elements(), 256)([=](auto i) __device__ {
output_ptr[i] = data_ptr[i];
});
hip_visit_all(arg1, arg2)([&](auto indices, auto update) {
const auto* upd_ptr = device_cast(update.data());
const auto* indices_ptr = device_cast(indices.data());
gs_launch(stream, inds.elements(), 256)([=](auto i) __device__ {
auto out_idx = s1.multi(i);
out_idx[axis] = indices_ptr[i];
output[out_idx] = upd_ptr[i];
});
});
});
return result;
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_SCATTER_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_SCATTER_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument scatter(hipStream_t stream, argument result, argument arg0, argument arg1, argument arg2, int64_t axis);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_SCATTER_HPP
#define MIGRAPHX_GUARD_RTGLIB_SCATTER_HPP
#include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/scatter.hpp>
#include <migraphx/gpu/miopen.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_scatter
{
op::scatter op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::scatter"; }
shape compute_shape(std::vector<shape> inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -172,6 +172,7 @@ struct miopen_apply
add_extend_op("rnn_var_sl_last_output");
add_extend_op("rnn_var_sl_shift_output");
add_extend_op("rnn_var_sl_shift_sequence");
add_extend_op("scatter");
add_extend_op("softmax");
add_gemm_op<op::dot>("dot");
......
#include <migraphx/gpu/scatter.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/scatter.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_scatter::compute_shape(std::vector<shape> inputs) const
{
inputs.pop_back();
return op.normalize_compute_shape(inputs);
}
argument hip_scatter::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
return device::scatter(ctx.get_stream().get(), args.back(), args[0], args[1], args[2], op.axis);
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment