"include/vscode:/vscode.git/clone" did not exist on "6916e3e41267ce2cfbaa59c09a0d1ced538ba7f9"
Commit ce3048d4 authored by Paul

Refactor gather operator

parent 72c188be
@@ -17,20 +17,18 @@ argument gather(hipStream_t stream, argument result, argument arg1, argument arg
     auto& input_shape = arg1.get_shape();
     auto lens         = input_shape.lens();
     lens[axis_index]  = arg2.get_shape().elements();
+    shape out_comp_shape{result.get_shape().type(), lens};
     std::size_t nelements = result.get_shape().elements();
-    visit_all(result, arg1)([&](auto output, auto input) {
-        arg2.visit([&](auto indices) {
-            const auto* indices_ptr = device_cast(indices.data());
-            auto* out_ptr           = device_cast(output.data());
-            const auto* in_ptr      = device_cast(input.data());
-            migraphx::shape out_comp_shape{result.get_shape().type(), lens};
-            visit_tensor_size(out_comp_shape.lens().size(), [&](auto n_out_dim) {
-                hip_tensor_descriptor<n_out_dim> desc_input(input_shape);
-                hip_tensor_descriptor<n_out_dim> desc_output(out_comp_shape);
-                gs_launch(stream, nelements)([=](auto ii) {
-                    auto in_idx        = desc_output.multi(ii);
-                    in_idx[axis_index] = indices_ptr[in_idx[axis_index]];
-                    out_ptr[ii]        = in_ptr[desc_input.linear(in_idx)];
+    visit_all(result, arg1)([&](auto output, auto input_v) {
+        hip_visit_views(input_v, out_comp_shape)([&](auto input, auto out_comp) {
+            arg2.visit([&](auto indices) {
+                const auto* indices_ptr = device_cast(indices.data());
+                auto* output_ptr        = device_cast(output.data());
+                gs_launch(stream, nelements)([=](auto i) {
+                    auto idx        = out_comp.multi(i);
+                    idx[axis_index] = indices_ptr[idx[axis_index]];
+                    output_ptr[i]   = input[idx];
                 });
             });
         });
...
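The refactor replaces the per-kernel `hip_tensor_descriptor` plumbing with `hip_visit_views`, which hands the kernel a fixed-rank input view (indexable by multi-index) and a fixed-rank shape for the computed output. For reference, here is a host-side sketch of the remapping the kernel performs; the `multi`/`linear` helpers below are simplified stand-ins for the view machinery, not MIGraphX APIs.

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Convert a linear offset into a row-major multi-index for the given dims.
std::vector<std::size_t> multi(std::size_t i, const std::vector<std::size_t>& lens)
{
    std::vector<std::size_t> idx(lens.size());
    for(std::size_t d = lens.size(); d > 0; d--)
    {
        idx[d - 1] = i % lens[d - 1];
        i /= lens[d - 1];
    }
    return idx;
}

// Convert a multi-index back into a row-major linear offset.
std::size_t linear(const std::vector<std::size_t>& idx, const std::vector<std::size_t>& lens)
{
    std::size_t i = 0;
    for(std::size_t d = 0; d < lens.size(); d++)
        i = i * lens[d] + idx[d];
    return i;
}

int main()
{
    // input: 3x2 matrix; gather rows {2, 0} along axis 0 -> 2x2 output
    std::vector<int> input            = {1, 2, 3, 4, 5, 6};
    std::vector<std::size_t> in_lens  = {3, 2};
    std::vector<std::size_t> indices  = {2, 0};
    std::size_t axis                  = 0;

    std::vector<std::size_t> out_lens = in_lens;
    out_lens[axis]                    = indices.size();

    std::vector<int> output(4);
    for(std::size_t i = 0; i < output.size(); i++)
    {
        auto idx  = multi(i, out_lens);          // out_comp.multi(i)
        idx[axis] = indices[idx[axis]];          // remap along the gather axis
        output[i] = input[linear(idx, in_lens)]; // input[idx] on the device view
    }

    for(int v : output)
        std::cout << v << " "; // prints: 5 6 1 2
    std::cout << "\n";
}
```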
@@ -44,6 +44,12 @@ hip_tensor_view<T, N> make_hip_view(const shape& s, T* x)
     return {x, s};
 }
 
+template <std::size_t N, class T>
+hip_tensor_view<T, N> make_hip_view(tensor_view<T> x)
+{
+    return {x};
+}
+
 } // namespace device
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
...
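The new `make_hip_view` overload constructs a fixed-rank view directly from a `tensor_view`, complementing the existing shape-plus-pointer overload. A rough sketch of why the rank-fixing wrapper matters, using illustrative types rather than the MIGraphX definitions:

```cpp
#include <array>
#include <cstddef>
#include <vector>

template <class T>
struct tensor_view_like // dynamic rank: lens stored on the heap
{
    T* data;
    std::vector<std::size_t> lens;
};

template <class T, std::size_t N>
struct fixed_view // fixed rank: lens fit in registers, usable inside a kernel
{
    T* data;
    std::array<std::size_t, N> lens;

    fixed_view(const tensor_view_like<T>& x) : data(x.data)
    {
        for(std::size_t d = 0; d < N; d++)
            lens[d] = x.lens[d];
    }
};

template <std::size_t N, class T>
fixed_view<T, N> make_view(const tensor_view_like<T>& x)
{
    return {x};
}

int main()
{
    std::vector<float> buf(6);
    tensor_view_like<float> tv{buf.data(), {2, 3}};
    auto v = make_view<2>(tv); // rank fixed at compile time
    (void)v;
}
```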
@@ -54,10 +54,26 @@ auto get_shape(const T& x) -> decltype(x.get_shape())
 template <class V, class F, class... Ts>
 void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
 {
+    std::initializer_list<migraphx::shape::type_t> types = {get_shape(xs).type()...};
+    if(!std::all_of(types.begin(), types.end(), [&](migraphx::shape::type_t t) { return t == s.type(); }))
+        MIGRAPHX_THROW("Types must be the same");
+    std::initializer_list<std::size_t> ranks = {get_shape(xs).lens().size()...};
+    if(!std::all_of(ranks.begin(), ranks.end(), [&](std::size_t r) { return r == s.lens().size(); }))
+        MIGRAPHX_THROW("Ranks must be the same");
     visit_tensor_size(s.lens().size(),
                       [&](auto ndim) { s.visit_type([&](auto as) { v(f(xs, ndim, as)...); }); });
 }
 
+template <class V, class F, class... Ts>
+void hip_visit_views_impl(const shape& s, F f, V&& v, Ts&&... xs)
+{
+    std::initializer_list<std::size_t> ranks = {get_shape(xs).lens().size()...};
+    if(!std::all_of(ranks.begin(), ranks.end(), [&](std::size_t r) { return r == s.lens().size(); }))
+        MIGRAPHX_THROW("Ranks must be the same");
+    visit_tensor_size(s.lens().size(),
+                      [&](auto ndim) { v(f(xs, ndim)...); });
+}
+
 template <class F>
 struct hip_convert
 {
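The added validation expands the parameter pack into a braced `std::initializer_list` so every argument's type and rank can be checked with `std::all_of` before dispatching, rather than recursing over the pack. A minimal self-contained sketch of the idiom, with a hypothetical `rank` helper in place of the shape accessors:

```cpp
#include <algorithm>
#include <cstddef>
#include <initializer_list>
#include <stdexcept>
#include <vector>

std::size_t rank(const std::vector<std::size_t>& lens) { return lens.size(); }

// Check every argument in one pass: the pack expands into an
// initializer_list, which std::all_of can iterate without recursion.
template <class... Ts>
void check_same_rank(std::size_t expected, const Ts&... xs)
{
    std::initializer_list<std::size_t> ranks = {rank(xs)...};
    if(!std::all_of(ranks.begin(), ranks.end(), [&](std::size_t r) { return r == expected; }))
        throw std::runtime_error("Ranks must be the same");
}

int main()
{
    std::vector<std::size_t> a = {2, 3}, b = {4, 5};
    check_same_rank(2, a, b); // ok: both rank 2
    // check_same_rank(2, a, std::vector<std::size_t>{1, 2, 3}); // would throw
}
```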
@@ -82,6 +98,29 @@ hip_convert<F> make_hip_convert(F f)
     return {f};
 }
 
+template <class F>
+struct hip_convert_view
+{
+    F f;
+    template <class T, class N>
+    auto operator()(tensor_view<T> x, N ndim) const
+    {
+        return make_hip_view<ndim>(f(x));
+    }
+
+    template <class N>
+    auto operator()(const shape& s, N ndim) const
+    {
+        return make_hip_shape<ndim>(s);
+    }
+};
+
+template <class F>
+hip_convert_view<F> make_hip_convert_view(F f)
+{
+    return {f};
+}
+
 template <class T, class... Ts>
 auto hip_visit_all(T&& x, Ts&&... xs)
 {
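`hip_convert_view` dispatches by overload: `tensor_view` arguments become fixed-rank device views, while bare `shape` arguments become fixed-rank shape descriptors, so a single visit can mix data and shapes. That is how the gather kernel above passes `input_v` together with `out_comp_shape`. A toy sketch of the dispatch pattern, with illustrative stand-in names:

```cpp
#include <iostream>

struct tensor_like { };
struct shape_like { };

// One function object maps each visited argument to its device-side
// counterpart; overload resolution picks the right conversion per argument.
struct to_device
{
    const char* operator()(tensor_like) const { return "fixed-rank view"; }
    const char* operator()(shape_like) const { return "fixed-rank shape"; }
};

template <class F, class... Ts>
void visit_views(F f, Ts... xs)
{
    to_device convert;
    f(convert(xs)...); // each pack element selects its own overload
}

int main()
{
    visit_views([](auto a, auto b) { std::cout << a << ", " << b << "\n"; },
                tensor_like{}, shape_like{});
    // prints: fixed-rank view, fixed-rank shape
}
```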
@@ -109,6 +148,15 @@ auto hip_pointer_visit_all(T&& x, Ts&&... xs)
     return [&](auto f) { visit_all(x, xs...)([&](auto... vs) { f(device_cast(vs.data())...); }); };
 }
 
+template <class T, class... Ts>
+auto hip_visit_views(T&& x, Ts&&... xs)
+{
+    return [&](auto f) {
+        hip_visit_views_impl(
+            get_shape(x), make_hip_convert_view([](auto v) { return device_cast(v); }), f, x, xs...);
+    };
+}
+
 } // namespace device
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
...
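Note that `hip_visit_views_impl` dispatches only on rank, not on element type: unlike `hip_visit_all_impl` there is no `s.visit_type` call, since each `tensor_view` argument already carries its element type `T`. The rank lifting works along these lines; this is a simplified stand-in for `visit_tensor_size`, not the MIGraphX implementation:

```cpp
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <type_traits>

// Lift a runtime rank to a compile-time constant over a small fixed set,
// so the callback can instantiate fixed-rank types.
template <class F>
void visit_rank(std::size_t rank, F f)
{
    switch(rank)
    {
    case 1: f(std::integral_constant<std::size_t, 1>{}); break;
    case 2: f(std::integral_constant<std::size_t, 2>{}); break;
    case 3: f(std::integral_constant<std::size_t, 3>{}); break;
    default: throw std::runtime_error("Unsupported rank");
    }
}

int main()
{
    visit_rank(2, [](auto ndim) {
        // ndim's value is available at compile time inside the lambda
        std::cout << "compile-time rank: " << decltype(ndim)::value << "\n";
    });
}
```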
@@ -13,7 +13,7 @@ shape hip_gather::compute_shape(std::vector<shape> inputs) const
 }
 
 argument hip_gather::compute(context& ctx,
-                             const shape& output_shape,
+                             const shape&,
                              const std::vector<argument>& args) const
 {
     return device::gather(ctx.get_stream().get(), args.back(), args[0], args[1], op.axis);
...
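The `output_shape` parameter becomes unnamed since the body never uses it; a plausible motivation (not stated in the commit) is that leaving it unnamed keeps the interface intact while avoiding unused-parameter warnings. For illustration:

```cpp
#include <iostream>

// The int argument is required by the interface but not needed here;
// omitting its name documents that and silences -Wunused-parameter.
void log_event(const char* msg, int /*verbosity*/)
{
    std::cout << msg << "\n";
}

int main() { log_event("gather complete", 0); }
```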