Commit d15edcb6 authored by Paul's avatar Paul
Browse files

Formatting

parent ce3048d4
...@@ -13,22 +13,22 @@ namespace device { ...@@ -13,22 +13,22 @@ namespace device {
argument gather(hipStream_t stream, argument result, argument arg1, argument arg2, int axis) argument gather(hipStream_t stream, argument result, argument arg1, argument arg2, int axis)
{ {
auto axis_index = (axis < 0) ? (axis + arg1.get_shape().lens().size()) : axis; auto axis_index = (axis < 0) ? (axis + arg1.get_shape().lens().size()) : axis;
auto& input_shape = arg1.get_shape(); auto& input_shape = arg1.get_shape();
auto lens = input_shape.lens(); auto lens = input_shape.lens();
lens[axis_index] = arg2.get_shape().elements(); lens[axis_index] = arg2.get_shape().elements();
shape out_comp_shape{result.get_shape().type(), lens}; shape out_comp_shape{result.get_shape().type(), lens};
std::size_t nelements = result.get_shape().elements(); std::size_t nelements = result.get_shape().elements();
visit_all(result, arg1)([&](auto output, auto input_v) { visit_all(result, arg1)([&](auto output, auto input_v) {
hip_visit_views(input_v, out_comp_shape)([&](auto input, auto out_comp) { hip_visit_views(input_v, out_comp_shape)([&](auto input, auto out_comp) {
arg2.visit([&](auto indices) { arg2.visit([&](auto indices) {
const auto* indices_ptr = device_cast(indices.data()); const auto* indices_ptr = device_cast(indices.data());
auto* output_ptr = device_cast(output.data()); auto* output_ptr = device_cast(output.data());
gs_launch(stream, nelements)([=](auto i) { gs_launch(stream, nelements)([=](auto i) {
auto idx = out_comp.multi(i); auto idx = out_comp.multi(i);
idx[axis_index] = indices_ptr[idx[axis_index]]; idx[axis_index] = indices_ptr[idx[axis_index]];
output_ptr[i] = input[idx]; output_ptr[i] = input[idx];
}); });
}); });
}); });
......
...@@ -55,10 +55,12 @@ template <class V, class F, class... Ts> ...@@ -55,10 +55,12 @@ template <class V, class F, class... Ts>
void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs) void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
{ {
std::initializer_list<migraphx::shape::type_t> types = {get_shape(xs).type()...}; std::initializer_list<migraphx::shape::type_t> types = {get_shape(xs).type()...};
if(!std::all_of(types.begin(), types.end(), [&](migraphx::shape::type_t t) { return t == s.type(); })) if(!std::all_of(
types.begin(), types.end(), [&](migraphx::shape::type_t t) { return t == s.type(); }))
MIGRAPHX_THROW("Types must be the same"); MIGRAPHX_THROW("Types must be the same");
std::initializer_list<std::size_t> ranks = {get_shape(xs).lens().size()...}; std::initializer_list<std::size_t> ranks = {get_shape(xs).lens().size()...};
if(!std::all_of(ranks.begin(), ranks.end(), [&](std::size_t r) { return r == s.lens().size(); })) if(!std::all_of(
ranks.begin(), ranks.end(), [&](std::size_t r) { return r == s.lens().size(); }))
MIGRAPHX_THROW("Ranks must be the same"); MIGRAPHX_THROW("Ranks must be the same");
visit_tensor_size(s.lens().size(), visit_tensor_size(s.lens().size(),
[&](auto ndim) { s.visit_type([&](auto as) { v(f(xs, ndim, as)...); }); }); [&](auto ndim) { s.visit_type([&](auto as) { v(f(xs, ndim, as)...); }); });
...@@ -68,10 +70,10 @@ template <class V, class F, class... Ts> ...@@ -68,10 +70,10 @@ template <class V, class F, class... Ts>
void hip_visit_views_impl(const shape& s, F f, V&& v, Ts&&... xs) void hip_visit_views_impl(const shape& s, F f, V&& v, Ts&&... xs)
{ {
std::initializer_list<std::size_t> ranks = {get_shape(xs).lens().size()...}; std::initializer_list<std::size_t> ranks = {get_shape(xs).lens().size()...};
if(!std::all_of(ranks.begin(), ranks.end(), [&](std::size_t r) { return r == s.lens().size(); })) if(!std::all_of(
ranks.begin(), ranks.end(), [&](std::size_t r) { return r == s.lens().size(); }))
MIGRAPHX_THROW("Ranks must be the same"); MIGRAPHX_THROW("Ranks must be the same");
visit_tensor_size(s.lens().size(), visit_tensor_size(s.lens().size(), [&](auto ndim) { v(f(xs, ndim)...); });
[&](auto ndim) { v(f(xs, ndim)...); });
} }
template <class F> template <class F>
...@@ -152,8 +154,11 @@ template <class T, class... Ts> ...@@ -152,8 +154,11 @@ template <class T, class... Ts>
auto hip_visit_views(T&& x, Ts&&... xs) auto hip_visit_views(T&& x, Ts&&... xs)
{ {
return [&](auto f) { return [&](auto f) {
hip_visit_views_impl( hip_visit_views_impl(get_shape(x),
get_shape(x), make_hip_convert_view([](auto v) { return device_cast(v); }), f, x, xs...); make_hip_convert_view([](auto v) { return device_cast(v); }),
f,
x,
xs...);
}; };
} }
......
...@@ -12,9 +12,7 @@ shape hip_gather::compute_shape(std::vector<shape> inputs) const ...@@ -12,9 +12,7 @@ shape hip_gather::compute_shape(std::vector<shape> inputs) const
return op.compute_shape(inputs); return op.compute_shape(inputs);
} }
argument hip_gather::compute(context& ctx, argument hip_gather::compute(context& ctx, const shape&, const std::vector<argument>& args) const
const shape&,
const std::vector<argument>& args) const
{ {
return device::gather(ctx.get_stream().get(), args.back(), args[0], args[1], op.axis); return device::gather(ctx.get_stream().get(), args.back(), args[0], args[1], op.axis);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment